Make TPUMirroredVariable a subclass of MirroredVariable.
PiperOrigin-RevId: 252370482
parent f070a4b01c
commit 8da25cbfe7
@@ -485,23 +485,25 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
       return cross_device_ops_lib.reduce_non_distributed_value(
           reduce_op, self._device_map, value, destinations)
 
-    devices = cross_device_ops_lib.get_devices_from(destinations)
-    if len(devices) != 1:
-      raise ValueError("Multiple devices are not supported for TPUStrategy")
-
+    # TODO(cjfj): Detect when it is possible to use `cross_replica_sum`.
     # Always performs the reduction on the TPU host.
     with ops.device(self._host_device):
       output = math_ops.add_n(value.values)
       if reduce_op == reduce_util.ReduceOp.MEAN:
         output *= (1. / len(value.values))
 
-    # If necessary, copy to requested destination.
-    dest_canonical = device_util.canonicalize(devices[0])
-    host_canonical = device_util.canonicalize(self._host_device)
+    devices = cross_device_ops_lib.get_devices_from(destinations)
 
-    if dest_canonical != host_canonical:
-      with ops.device(dest_canonical):
-        output = array_ops.identity(output)
+    if len(devices) == 1:
+      # If necessary, copy to requested destination.
+      dest_canonical = device_util.canonicalize(devices[0])
+      host_canonical = device_util.canonicalize(self._host_device)
+
+      if dest_canonical != host_canonical:
+        with ops.device(dest_canonical):
+          output = array_ops.identity(output)
+    else:
+      output = cross_device_ops_lib.simple_broadcast(output, destinations)
 
     return output
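Aside from the control-flow reshuffle, the host-side reduction above is just sum-then-scale, and the new `if len(devices) == 1: ... else:` branch broadcasts to multi-device destinations instead of raising ValueError. A minimal pure-Python sketch of the arithmetic, with plain floats standing in for the per-replica tensors and a boolean standing in for `reduce_op == ReduceOp.MEAN` (illustrative only, not the TF implementation):

def host_reduce(values, mean=False):
  # math_ops.add_n sums the per-replica values elementwise.
  output = sum(values)
  if mean:
    # ReduceOp.MEAN rescales the sum by the replica count.
    output *= 1. / len(values)
  return output

assert host_reduce([1., 2., 3.]) == 6.
assert host_reduce([1., 2., 3.], mean=True) == 2.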
@@ -388,63 +388,73 @@ class DistributedDelegate(DistributedValues):
     # __getattr__ and @property. See b/120402273.
     return getattr(self.get(), name)
 
+  def _get_as_operand(self):
+    """Returns the value for operations for the current device.
+
+    Some implementations, e.g. `TPUMirroredVariable`, are not able to return the
+    value type within a replica context. They can, however, return a value that
+    can be used by the operations below.
+    """
+    return self.get()
+
   # pylint: disable=multiple-statements
-  def __add__(self, o): return self.get() + o
-  def __radd__(self, o): return o + self.get()
-  def __sub__(self, o): return self.get() - o
-  def __rsub__(self, o): return o - self.get()
-  def __mul__(self, o): return self.get() * o
-  def __rmul__(self, o): return o * self.get()
-  def __truediv__(self, o): return self.get() / o
-  def __rtruediv__(self, o): return o / self.get()
+  def __add__(self, o): return self._get_as_operand() + o
+  def __radd__(self, o): return o + self._get_as_operand()
+  def __sub__(self, o): return self._get_as_operand() - o
+  def __rsub__(self, o): return o - self._get_as_operand()
+  def __mul__(self, o): return self._get_as_operand() * o
+  def __rmul__(self, o): return o * self._get_as_operand()
+  def __truediv__(self, o): return self._get_as_operand() / o
+  def __rtruediv__(self, o): return o / self._get_as_operand()
 
   def __floordiv__(self, o):
-    return self.get() // o
+    return self._get_as_operand() // o
 
-  def __rfloordiv__(self, o): return o // self.get()
-  def __mod__(self, o): return self.get() % o
-  def __rmod__(self, o): return o % self.get()
-  def __lt__(self, o): return self.get() < o
-  def __le__(self, o): return self.get() <= o
-  def __gt__(self, o): return self.get() > o
-  def __ge__(self, o): return self.get() >= o
-  def __and__(self, o): return self.get() & o
-  def __rand__(self, o): return o & self.get()
-  def __or__(self, o): return self.get() | o
-  def __ror__(self, o): return o | self.get()
-  def __xor__(self, o): return self.get() ^ o
-  def __rxor__(self, o): return o ^ self.get()
-  def __getitem__(self, o): return self.get()[o]
-  def __pow__(self, o, modulo=None): return pow(self.get(), o, modulo)
-  def __rpow__(self, o): return pow(o, self.get())
-  def __invert__(self): return ~self.get()
-  def __neg__(self): return -self.get()
-  def __abs__(self): return abs(self.get())
+  def __rfloordiv__(self, o): return o // self._get_as_operand()
+  def __mod__(self, o): return self._get_as_operand() % o
+  def __rmod__(self, o): return o % self._get_as_operand()
+  def __lt__(self, o): return self._get_as_operand() < o
+  def __le__(self, o): return self._get_as_operand() <= o
+  def __gt__(self, o): return self._get_as_operand() > o
+  def __ge__(self, o): return self._get_as_operand() >= o
+  def __and__(self, o): return self._get_as_operand() & o
+  def __rand__(self, o): return o & self._get_as_operand()
+  def __or__(self, o): return self._get_as_operand() | o
+  def __ror__(self, o): return o | self._get_as_operand()
+  def __xor__(self, o): return self._get_as_operand() ^ o
+  def __rxor__(self, o): return o ^ self._get_as_operand()
+  def __getitem__(self, o): return self._get_as_operand()[o]
+  def __pow__(self, o, modulo=None):
+    return pow(self._get_as_operand(), o, modulo)
+  def __rpow__(self, o): return pow(o, self._get_as_operand())
+  def __invert__(self): return ~self._get_as_operand()
+  def __neg__(self): return -self._get_as_operand()
+  def __abs__(self): return abs(self._get_as_operand())
 
   def __div__(self, o):
     try:
-      return self.get().__div__(o)
+      return self._get_as_operand().__div__(o)
     except AttributeError:
       # See https://docs.python.org/3/library/constants.html#NotImplemented
       return NotImplemented
 
   def __rdiv__(self, o):
     try:
-      return self.get().__rdiv__(o)
+      return self._get_as_operand().__rdiv__(o)
     except AttributeError:
       # See https://docs.python.org/3/library/constants.html#NotImplemented
       return NotImplemented
 
   def __matmul__(self, o):
     try:
-      return self.get().__matmul__(o)
+      return self._get_as_operand().__matmul__(o)
     except AttributeError:
       # See https://docs.python.org/3/library/constants.html#NotImplemented
       return NotImplemented
 
   def __rmatmul__(self, o):
     try:
-      return self.get().__rmatmul__(o)
+      return self._get_as_operand().__rmatmul__(o)
     except AttributeError:
       # See https://docs.python.org/3/library/constants.html#NotImplemented
       return NotImplemented
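The refactor above funnels every operator through one overridable hook, `_get_as_operand`, so a subclass such as `TPUMirroredVariable` can change what all of the operators see by overriding a single method. A standalone sketch of the pattern (the class names here are illustrative, not the real TF classes):

class Delegate(object):
  """Wraps a value; all arithmetic goes through one hook."""

  def __init__(self, value):
    self._value = value

  def get(self):
    return self._value

  def _get_as_operand(self):
    # The single point of customization for every operator below.
    return self.get()

  def __add__(self, o): return self._get_as_operand() + o
  def __mul__(self, o): return self._get_as_operand() * o


class DoublingDelegate(Delegate):

  def _get_as_operand(self):
    # Override once; __add__ and __mul__ both pick it up.
    return 2 * self.get()


assert Delegate(3) + 1 == 4
assert DoublingDelegate(3) + 1 == 7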
@@ -920,15 +930,19 @@ class MirroredVariable(DistributedVariable, Mirrored):
       return _MirroredSaveable(self, self.primary, name)
     return {trackable.VARIABLE_VALUE_KEY: _saveable_factory}
 
+  def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
+    """Converts a variable to a tensor."""
+    # Try to avoid assignments to and other mutations of MirroredVariable
+    # state except through a DistributionStrategy.extended.update() call.
+    assert not as_ref
+    return ops.internal_convert_to_tensor(
+        self.get(), dtype=dtype, name=name, as_ref=as_ref)
+
 
 # Register a conversion function which reads the value of the variable,
 # allowing instances of the class to be used as tensors.
 def _tensor_conversion_mirrored(var, dtype=None, name=None, as_ref=False):
-  # Try to avoid assignments to and other mutations of MirroredVariable
-  # state except through a DistributionStrategy.extended.update() call.
-  assert not as_ref
-  return ops.internal_convert_to_tensor(
-      var.get(), dtype=dtype, name=name, as_ref=as_ref)
+  return var._dense_var_to_tensor(dtype=dtype, name=name, as_ref=as_ref)  # pylint: disable=protected-access
 
 
 ops.register_tensor_conversion_function(MirroredVariable,
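Hoisting the conversion logic into `_dense_var_to_tensor` leaves the registered function as a one-line trampoline, so subclasses can override the conversion without registering a new function. For context, this is roughly how a tensor conversion function works in general; a hedged sketch, assuming the public `tf.register_tensor_conversion_function` alias available in TF releases of this era, with `Boxed` as a made-up wrapper type:

import tensorflow as tf

class Boxed(object):
  """A toy wrapper that should be accepted wherever a tensor is."""

  def __init__(self, value):
    self.value = value

def _boxed_to_tensor(box, dtype=None, name=None, as_ref=False):
  del as_ref  # Only value semantics here.
  return tf.convert_to_tensor(box.value, dtype=dtype, name=name)

tf.register_tensor_conversion_function(Boxed, _boxed_to_tensor)

# tf.add now converts Boxed implicitly, like MirroredVariable above.
print(tf.add(Boxed(1.), 2.))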
@@ -947,33 +961,17 @@ def _enclosing_tpu_context():
 
 def is_distributed_variable(v):
   """Determine if a variable is ds variable or TPU mirrored variable."""
-  return (isinstance(v, DistributedVariable)
-          or isinstance(v, TPUMirroredVariable))
+  return isinstance(v, DistributedVariable)
 
 
-# TODO(jhseu): Deduplicate code. We copy code because we don't want to
-# inherit from DistributedDelegate. DistributedDelegate will not work in a
-# tpu.replicate() because it assumes that you're in a device context where you
-# can operate on a single version of the variable, but a tpu.replicate()
-# operates on all variables and is replicated during a rewrite pass.
-class TPUMirroredVariable(variables_lib.Variable):
+class TPUMirroredVariable(MirroredVariable):
   """Holds a map from device to TPU variables whose values are kept in sync."""
 
   def __init__(
       self, strategy, device_map, values, aggregation, logical_device=None):
-    assert isinstance(device_map, DeviceMap)
-    self._distribute_strategy = strategy
-    self._device_map = device_map
-    self._values = tuple(values)
-    if logical_device is None:
-      logical_device = device_map.logical_device_from_values(self._values)
-    self._logical_device = logical_device
-
-    # Use a weakref to make it easy to map from the contained values
-    # to the container without introducing a reference cycle.
-    for v in self._values:
-      v._mirrored_container = weakref.ref(self)  # pylint: disable=protected-access
-    self._common_name = self.primary.name.split(":")[0]
+    super(TPUMirroredVariable, self).__init__(
+        strategy=strategy, device_map=device_map, values=values,
+        aggregation=aggregation, logical_device=logical_device)
 
     # Handle id is needed for get_replicated_var_handle to cache the variables
     # correctly since in eager mode different variables can have the same name.
@@ -982,28 +980,21 @@ class TPUMirroredVariable(variables_lib.Variable):
     else:
       self._handle_id = self._common_name
 
-    self._aggregation = aggregation
-    # Needed for GradientTape
-    self._trainable = self.primary.trainable
-    # Typically like `DistributedVariable`, a `TPUMirroredVariable`'s
-    # initializer is composed of the initializers of the components variables.
-    # However, in some cases, such as when restoring from a checkpoint, we may
-    # set the _initializer_op property on the entire `TPUMirroredVariable`.
-    self._initializer_op = None
+  def get(self, device=None):
+    if (_enclosing_tpu_context() is None) or (device is not None):
+      return super(TPUMirroredVariable, self).get(device=device)
+    else:
+      raise NotImplementedError(
+          "`TPUMirroredVariable.get()` is not supported within a TPU context.")
 
-  def _get(self, device=None):
-    """Returns the value for the current device or raises a ValueError."""
-    if device is None:
-      replica_context = distribution_strategy_context.get_replica_context()
-      if replica_context:
-        return self._device_map.select_for_current_replica(
-            self._values, replica_context)
-      else:
-        device = distribute_lib.get_update_device()
-        if device is None:
-          return self._get_cross_replica()
-    device = device_util.canonicalize(device)
-    return self._device_map.select_for_device(self._values, device)
+  def _get_as_operand(self):
+    return self.read_value()
+
+  def _get_closest(self):
+    if _enclosing_tpu_context() is None:
+      return super(TPUMirroredVariable, self)._get_closest()
+    else:
+      return self.primary
 
   def numpy(self):
     if context.executing_eagerly():
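The `get()` override above is a context guard: outside a TPU context the variable behaves exactly like its new parent class, while inside `tpu.replicate()` a single per-device read is not meaningful, so it refuses. A toy sketch of the guard shape (names invented; `_in_tpu_context` stands in for `_enclosing_tpu_context() is not None`):

class Base(object):
  def get(self):
    return "per-device value"

class TpuAware(Base):
  _in_tpu_context = False  # flipped by the replication machinery in TF

  def get(self):
    if not self._in_tpu_context:
      return super(TpuAware, self).get()
    raise NotImplementedError(
        "get() is not supported within a TPU context.")

assert TpuAware().get() == "per-device value"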
@@ -1011,174 +1002,20 @@ class TPUMirroredVariable(variables_lib.Variable):
       raise NotImplementedError(
           "numpy() is only available when eager execution is enabled.")
 
-  def initialized_value(self):
-    return self.primary.initialized_value()
-
-  @property
-  def initial_value(self):
-    return self.primary.initial_value
-
-  @property
-  def primary(self):
-    """Returns a representative component."""
-    return self._values[0]
-
-  @property
-  def devices(self):
-    return self._device_map.logical_to_actual_devices(self._logical_device)
-
-  @property
-  def logical_device(self):
-    return self._logical_device
-
-  @property
-  def device_map(self):
-    return self._device_map
-
-  # TODO(josh11b): Replace experimental_local_results with this?
-  @property
-  def values(self):
-    return self._values
-
-  @property
-  def distribute_strategy(self):
-    return self._distribute_strategy
-
-  # pylint: disable=multiple-statements
-  def __add__(self, o): return self.read_value() + o
-  def __radd__(self, o): return o + self.read_value()
-  def __sub__(self, o): return self.read_value() - o
-  def __rsub__(self, o): return o - self.read_value()
-  def __mul__(self, o): return self.read_value() * o
-  def __rmul__(self, o): return o * self.read_value()
-  def __truediv__(self, o): return self.read_value() / o
-  def __rtruediv__(self, o): return o / self.read_value()
-  def __floordiv__(self, o): return self.read_value() // o
-  def __rfloordiv__(self, o): return o // self.read_value()
-  def __mod__(self, o): return self.read_value() % o
-  def __rmod__(self, o): return o % self.read_value()
-  def __lt__(self, o): return self.read_value() < o
-  def __le__(self, o): return self.read_value() <= o
-  def __gt__(self, o): return self.read_value() > o
-  def __ge__(self, o): return self.read_value() >= o
-  def __and__(self, o): return self.read_value() & o
-  def __rand__(self, o): return o & self.read_value()
-  def __or__(self, o): return self.read_value() | o
-  def __ror__(self, o): return o | self.read_value()
-  def __xor__(self, o): return self.read_value() ^ o
-  def __rxor__(self, o): return o ^ self.read_value()
-  def __getitem__(self, o): return self.read_value()[o]
-  def __pow__(self, o, modulo=None): return pow(self.read_value(), o, modulo)
-  def __rpow__(self, o): return pow(o, self.read_value())
-  def __invert__(self): return ~self.read_value()
-  def __neg__(self): return -self.read_value()
-  def __abs__(self): return abs(self.read_value())
-
-  def __div__(self, o):
-    try:
-      return self.read_value().__div__(o)
-    except AttributeError:
-      # See https://docs.python.org/3/library/constants.html#NotImplemented
-      return NotImplemented
-
-  def __rdiv__(self, o):
-    try:
-      return self.read_value().__rdiv__(o)
-    except AttributeError:
-      # See https://docs.python.org/3/library/constants.html#NotImplemented
-      return NotImplemented
-
-  def __matmul__(self, o):
-    try:
-      return self.read_value().__matmul__(o)
-    except AttributeError:
-      # See https://docs.python.org/3/library/constants.html#NotImplemented
-      return NotImplemented
-
-  def __rmatmul__(self, o):
-    try:
-      return self.read_value().__rmatmul__(o)
-    except AttributeError:
-      # See https://docs.python.org/3/library/constants.html#NotImplemented
-      return NotImplemented
-
-  def __str__(self):
-    devices = self.devices
-    debug_str = ",\n".join("  %d %s: %s" % (i, devices[i], self._values[i])
-                           for i in range(len(devices)))
-    return "%s:{\n%s\n}" % (self.__class__.__name__, debug_str)
-
-  def __repr__(self):
-    devices = self.devices
-    debug_repr = ",\n".join("  %d %s: %r" % (i, devices[i], self._values[i])
-                            for i in range(len(devices)))
-    return "%s:{\n%s\n}" % (self.__class__.__name__, debug_repr)
-
   @property
   def handle(self):
     # If we're in a tpu.rewrite(), return the replicated handle.
     tpu_context = _enclosing_tpu_context()
-    if tpu_context is not None:
+    if tpu_context is None:
+      return self._get_closest().handle
+    else:
       return tpu_context.get_replicated_var_handle(
           self._handle_id, self._values)
 
-    device = distribute_lib.get_update_device()
-    if device is None:
-      return self.primary.handle
-    return self._get(device=device).handle
-
   @property
   def device(self):
     return self.handle.device
 
-  def eval(self, session=None):
-    return self.primary.eval(session)
-
-  # The arguments to update() are automatically unwrapped so the update()
-  # function would normally see regular variables, not MirroredVariables.
-  # However, the update function can still operate on wrapped MirroredVariables
-  # through object members, captured arguments, etc. This is more likely in an
-  # update_non_slot() function (like OptimizerV2._finish), which can
-  # update several non-slot variables in one call.
-  def _assign_func(self, *args, **kwargs):
-    with _enter_or_assert_strategy(self._distribute_strategy):
-      f = kwargs.pop("f")
-      if distribution_strategy_context.in_cross_replica_context():
-        if _enclosing_tpu_context() is not None:
-          return self._distribute_strategy.extended.update(
-              self, f, args=args, kwargs=kwargs)
-
-        update_device = distribute_lib.get_update_device()
-        # We are calling update on the mirrored variable in cross replica
-        # context.
-        if update_device is not None:
-          # We are calling an assign function on the mirrored variable in cross
-          # replica context.
-          v = self._get(device=update_device)
-          return f(v, *args, **kwargs)
-
-        return self._distribute_strategy.extended.update(
-            self, f, args=args, kwargs=kwargs)
-      else:
-        _assert_replica_context(self._distribute_strategy)
-        # We are calling an assign function on the mirrored variable in replica
-        # context.
-        # We reduce the value we want to assign/add/sub. More details about how
-        # we handle the different use cases can be found in the _reduce method.
-        # We call the function on each of the mirrored variables with the
-        # reduced value.
-        if self._aggregation == vs.VariableAggregation.NONE:
-          raise ValueError(_aggregation_error_msg.format(
-              variable_type="TPUMirroredVariable"))
-
-        def merge_fn(strategy, value, *other_args, **other_kwargs):
-          v = _apply_aggregation(strategy, value, self._aggregation, self)
-          return strategy.extended.update(
-              self, f, args=(v,) + other_args, kwargs=other_kwargs)
-
-        return distribution_strategy_context.get_replica_context().merge_call(
-            merge_fn, args=args, kwargs=kwargs)
-
   @contextlib.contextmanager
   def _handle_graph(self, handle):
     # Note: might have an eager tensor but not be executing eagerly when
@@ -1190,10 +1027,6 @@ class TPUMirroredVariable(variables_lib.Variable):
     with handle.graph.as_default():
       yield
 
-  @property
-  def trainable(self):
-    return self._trainable
-
   def _read_variable_op(self, parent_op=None):
     if self.trainable:
       tape.variable_accessed(self)
@@ -1208,156 +1041,64 @@ class TPUMirroredVariable(variables_lib.Variable):
   def read_value(self):
     return self._read_variable_op()
 
-  def assign_sub(self, *args, **kwargs):
-    def assign_sub_fn(var, delta, *ar, **kw):
-      del ar
-      name = kw.pop("name", None)
-      read_value = kw.pop("read_value", True)
-      with self._handle_graph(var.handle):
-        op = gen_resource_variable_ops.assign_sub_variable_op(
-            var.handle, ops.convert_to_tensor(delta, dtype=self.dtype),
-            name=name)
-      if read_value:
-        return self._read_variable_op(parent_op=op)
-      return op
+  def _assign_func(self, *args, **kwargs):
+    with _enter_or_assert_strategy(self._distribute_strategy):
+      if (distribution_strategy_context.in_cross_replica_context()
+          and (_enclosing_tpu_context() is not None)):
+        f = kwargs.pop("f")
+        return self._distribute_strategy.extended.update(
+            self, f, args=args, kwargs=kwargs)
+      else:
+        return super(TPUMirroredVariable, self)._assign_func(*args, **kwargs)
+
+  def _make_raw_assign_fn(self, raw_assign_fn):
+    def assign_fn(var, value, *args, **kwargs):
+      del args
+      name = kwargs.pop("name", None)
+      read_value = kwargs.pop("read_value", True)
+      with self._handle_graph(var.handle):
+        op = raw_assign_fn(
+            var.handle, ops.convert_to_tensor(value, dtype=self.dtype),
+            name=name)
+      return self._read_variable_op(parent_op=op) if read_value else op
+    return assign_fn
 
+  def assign_sub(self, *args, **kwargs):
+    assign_sub_fn = self._make_raw_assign_fn(
+        gen_resource_variable_ops.assign_sub_variable_op)
     return self._assign_func(f=assign_sub_fn, *args, **kwargs)
 
   def assign_add(self, *args, **kwargs):
-    def assign_add_fn(var, delta, *ar, **kw):
-      del ar
-      name = kw.pop("name", None)
-      read_value = kw.pop("read_value", True)
-      with self._handle_graph(var.handle):
-        op = gen_resource_variable_ops.assign_add_variable_op(
-            var.handle, ops.convert_to_tensor(delta, dtype=self.dtype),
-            name=name)
-      if read_value:
-        return self._read_variable_op(parent_op=op)
-      return op
-
+    assign_add_fn = self._make_raw_assign_fn(
+        gen_resource_variable_ops.assign_add_variable_op)
     return self._assign_func(f=assign_add_fn, *args, **kwargs)
 
   def assign(self, *args, **kwargs):
-    def assign_fn(var, value, *ar, **kw):
-      del ar
-      name = kw.pop("name", None)
-      read_value = kw.pop("read_value", True)
-      with self._handle_graph(var.handle):
-        op = gen_resource_variable_ops.assign_variable_op(
-            var.handle, ops.convert_to_tensor(value, dtype=self.dtype),
-            name=name)
-      if read_value:
-        return self._read_variable_op(parent_op=op)
-      return op
-
+    assign_fn = self._make_raw_assign_fn(
+        gen_resource_variable_ops.assign_variable_op)
     return self._assign_func(f=assign_fn, *args, **kwargs)
 
-  @property
-  def aggregation(self):
-    return self._aggregation
-
-  @property
-  def constraint(self):
-    return self.primary.constraint
-
-  @property
-  def initializer(self):
-    if self._initializer_op:
-      init_op = self._initializer_op
-    else:
-      init_op = control_flow_ops.group(tuple(
-          v.initializer for v in self._values))
-    return init_op
-
-  @property
-  def graph(self):
-    return self.primary.graph
-
-  @property
-  def _shared_name(self):
-    return self._common_name
-
-  @property
-  def _unique_id(self):
-    return self.primary._unique_id  # pylint: disable=protected-access
-
-  @property
-  def name(self):
-    return self.primary.name
-
-  @property
-  def dtype(self):
-    return self.primary.dtype
-
-  @property
-  def shape(self):
-    return self.primary.shape
-
-  def get_shape(self):
-    return self.primary.get_shape()
-
-  def to_proto(self, export_scope=None):
-    return self.primary.to_proto(export_scope=export_scope)
-
-  def _get_cross_replica(self):
-    device = device_util.canonicalize(device_util.current())
-    replica = self._device_map.replica_for_device(device)
-    if replica is None:
-      return self.primary
-    return self._values[replica]
-
   def _as_graph_element(self):
-    # pylint: disable=protected-access
     if _enclosing_tpu_context() is None:
-      if distribution_strategy_context.in_cross_replica_context():
-        return self.primary._as_graph_element()
-      return self._get()._as_graph_element()
-    return None
-
-  def _gather_saveables_for_checkpoint(self):
-    """Overrides Trackable method.
-
-    This allows both name-based and object-based save and restore of
-    MirroredVariables.
-
-    Returns:
-      A dictionary mapping attribute names to `SaveableObject` factories.
-    """
-    def _saveable_factory(name=self._common_name):
-      return _MirroredSaveable(self, self.primary, name)
-    return {trackable.VARIABLE_VALUE_KEY: _saveable_factory}
-
-  def _should_act_as_resource_variable(self):
-    """Pass resource_variable_ops.is_resource_variable check."""
-    pass
+      return super(TPUMirroredVariable, self)._as_graph_element()  # pylint: disable=protected-access
+    else:
+      return None
 
   # Needed to pass ResourceVariable checks.
   @property
   def op(self):
     return self.primary.op
 
-  # pylint: disable=protected-access
-  @property
-  def _save_slice_info(self):
-    return self.primary._save_slice_info
-
-  def _get_save_slice_info(self):
-    return self.primary._get_save_slice_info()
-
-  def _set_save_slice_info(self, save_slice_info):
-    return self.primary._set_save_slice_info(save_slice_info)
-  # pylint: enable=protected-access
-
-  @property
-  def _in_graph_mode(self):
-    return self.primary._in_graph_mode  # pylint: disable=protected-access
-
   def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
     """Converts a variable to a tensor."""
     # pylint: disable=protected-access
     if _enclosing_tpu_context() is None:
-      return self._get()._dense_var_to_tensor(dtype, name, as_ref)
+      return super(TPUMirroredVariable, self)._dense_var_to_tensor(
+          dtype, name, as_ref)
     # pylint: enable=protected-access
     if dtype is not None and dtype != self.dtype:
       return math_ops.cast(self.read_value(), dtype)
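The main cleanup in this hunk is `_make_raw_assign_fn`: three copy-pasted closures become one factory parameterized by the raw resource-variable op. A minimal sketch of the factory shape, with lambdas standing in for `gen_resource_variable_ops.assign_*_variable_op` (the real version also threads the name=/read_value= kwargs and the handle's graph, elided here):

def make_raw_assign_fn(raw_assign_fn):
  def assign_fn(state, value):
    return raw_assign_fn(state, value)
  return assign_fn

assign_add = make_raw_assign_fn(lambda s, v: s + v)
assign_sub = make_raw_assign_fn(lambda s, v: s - v)

assert assign_add(10, 3) == 13
assert assign_sub(10, 3) == 7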
@@ -1366,40 +1107,6 @@ class TPUMirroredVariable(variables_lib.Variable):
     else:
       return self.read_value()
 
-  def is_initialized(self, name=None):
-    """Identifies if all the component variables are initialized.
-
-    Args:
-      name: Name of the final `logical_and` op.
-
-    Returns:
-      The op that evaluates to True or False depending on if all the
-      component variables are initialized.
-    """
-    # TODO(jhseu): Do we need TPU context implementation?
-
-    result = self.primary.is_initialized()
-    # We iterate through the list of values except the last one to allow us to
-    # name the final `logical_and` op the same name that is passed by the user
-    # to the `is_initialized` op. For distributed variables, the
-    # `is_initialized` op is a `logical_and` op.
-    for v in self._values[1:-1]:
-      result = math_ops.logical_and(result, v.is_initialized())
-    result = math_ops.logical_and(result, self._values[-1].is_initialized(),
-                                  name=name)
-    return result
-
-
-# Register a conversion function which reads the value of the variable,
-# allowing instances of the class to be used as tensors.
-def _tensor_conversion_tpu_mirrored(var, dtype=None, name=None, as_ref=False):
-  return var._dense_var_to_tensor(dtype=dtype, name=name, as_ref=as_ref)  # pylint: disable=protected-access
-
-
-ops.register_tensor_conversion_function(TPUMirroredVariable,
-                                        _tensor_conversion_tpu_mirrored)
-ops.register_dense_tensor_like_type(TPUMirroredVariable)
-
 
 class _SyncOnReadSaveable(saver.BaseSaverBuilder.SaveableObject):
   """Class for defining how to restore a SyncOnReadVariable."""
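The deleted `is_initialized` (now inherited from the MirroredVariable side of the hierarchy) folds `logical_and` across the component variables, holding back the last component so the user-supplied name lands on the final op. The same fold in plain Python, with booleans in place of `math_ops.logical_and` ops:

def all_initialized(flags):
  result = flags[0]
  for f in flags[1:-1]:
    result = result and f
  # The real code passes name=name on this final logical_and.
  result = result and flags[-1]
  return result

assert all_initialized([True, True, True])
assert not all_initialized([True, False, True])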
@@ -996,7 +996,7 @@ def _var_key(var):
 
   # pylint: disable=protected-access
   # Get the distributed variable if it exists.
-  if getattr(var, "_distributed_container", None) is not None:
+  if hasattr(var, "_distributed_container"):
     var = var._distributed_container()
   if var._in_graph_mode:
     return var._shared_name
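The one-line optimizer change swaps a value test for an existence test: `hasattr` only checks that the attribute is present, while the old `getattr(...) is not None` also rejected attributes whose value is None. A quick sketch (`Var` is a stand-in for a component variable; the real attribute holds a weakref, hence the `var._distributed_container()` call above):

class Var(object):
  pass

v = Var()
v._distributed_container = None

print(hasattr(v, "_distributed_container"))                    # True  (new check)
print(getattr(v, "_distributed_container", None) is not None)  # False (old check)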