Make TPUMirroredVariable a subclass of MirroredVariable.

PiperOrigin-RevId: 252370482
Author: Chris Jones
Date: 2019-06-10 02:34:36 -07:00
Committed by: TensorFlower Gardener
Parent: f070a4b01c
Commit: 8da25cbfe7
3 changed files with 119 additions and 410 deletions
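
The gist of the change: TPUMirroredVariable stops re-implementing the mirrored-variable machinery on top of variables_lib.Variable and inherits it from MirroredVariable instead. A minimal, illustrative schematic of the class relationship before and after (the stub bases are stand-ins added only to make the snippet self-contained; the real classes are in the values.py diff below):

class Variable(object): pass                    # stand-in for variables_lib.Variable
class DistributedVariable(Variable): pass       # stand-in for the real base
class Mirrored(object): pass

class MirroredVariable(DistributedVariable, Mirrored): pass

# Before: TPUMirroredVariable subclassed Variable directly and copied the logic.
# After this commit it simply inherits the shared behaviour.
class TPUMirroredVariable(MirroredVariable): pass

print(issubclass(TPUMirroredVariable, MirroredVariable))  # True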

File: tensorflow/python/distribute/tpu_strategy.py

@@ -485,16 +485,16 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
       return cross_device_ops_lib.reduce_non_distributed_value(
           reduce_op, self._device_map, value, destinations)

-    devices = cross_device_ops_lib.get_devices_from(destinations)
-    if len(devices) != 1:
-      raise ValueError("Multiple devices are not supported for TPUStrategy")
-
+    # TODO(cjfj): Detect when it is possible to use `cross_replica_sum`.
     # Always performs the reduction on the TPU host.
     with ops.device(self._host_device):
       output = math_ops.add_n(value.values)
       if reduce_op == reduce_util.ReduceOp.MEAN:
         output *= (1. / len(value.values))

-    # If necessary, copy to requested destination.
-    dest_canonical = device_util.canonicalize(devices[0])
-    host_canonical = device_util.canonicalize(self._host_device)
+    devices = cross_device_ops_lib.get_devices_from(destinations)
+    if len(devices) == 1:
+      # If necessary, copy to requested destination.
+      dest_canonical = device_util.canonicalize(devices[0])
+      host_canonical = device_util.canonicalize(self._host_device)
@@ -502,6 +502,8 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
-    if dest_canonical != host_canonical:
-      with ops.device(dest_canonical):
-        output = array_ops.identity(output)
+      if dest_canonical != host_canonical:
+        with ops.device(dest_canonical):
+          output = array_ops.identity(output)
+    else:
+      output = cross_device_ops_lib.simple_broadcast(output, destinations)

     return output
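
The behavioural change in this hunk: multiple reduce destinations no longer raise, the sum/mean still happens once on the TPU host, and the host result is then broadcast. A rough, pure-Python sketch of that control flow (plain floats and device strings stand in for the real tensors, math_ops.add_n, array_ops.identity and simple_broadcast):

def reduce_then_place(values, reduce_op, destinations, host_device):
  # Always reduce on the host first.
  output = sum(values)
  if reduce_op == "mean":
    output /= len(values)
  if len(destinations) == 1:
    # Single destination: keep the host result, copying only if the
    # destination differs from the host (the identity step above).
    return {destinations[0]: output}
  # Multiple destinations: previously a ValueError, now a broadcast of the
  # host result to every destination.
  return {device: output for device in destinations}

print(reduce_then_place([1.0, 2.0, 3.0], "mean", ["/gpu:0", "/gpu:1"], "/cpu:0"))
# {'/gpu:0': 2.0, '/gpu:1': 2.0}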

File: tensorflow/python/distribute/values.py

@@ -388,63 +388,73 @@ class DistributedDelegate(DistributedValues):
     # __getattr__ and @property. See b/120402273.
     return getattr(self.get(), name)

+  def _get_as_operand(self):
+    """Returns the value for operations for the current device.
+
+    Some implementations, e.g. `TPUMirroredVariable`, are not able to return the
+    value type within a replica context. They can, however, return a value that
+    can be used by the operations below.
+    """
+    return self.get()
+
   # pylint: disable=multiple-statements
-  def __add__(self, o): return self.get() + o
-  def __radd__(self, o): return o + self.get()
-  def __sub__(self, o): return self.get() - o
-  def __rsub__(self, o): return o - self.get()
-  def __mul__(self, o): return self.get() * o
-  def __rmul__(self, o): return o * self.get()
-  def __truediv__(self, o): return self.get() / o
-  def __rtruediv__(self, o): return o / self.get()
+  def __add__(self, o): return self._get_as_operand() + o
+  def __radd__(self, o): return o + self._get_as_operand()
+  def __sub__(self, o): return self._get_as_operand() - o
+  def __rsub__(self, o): return o - self._get_as_operand()
+  def __mul__(self, o): return self._get_as_operand() * o
+  def __rmul__(self, o): return o * self._get_as_operand()
+  def __truediv__(self, o): return self._get_as_operand() / o
+  def __rtruediv__(self, o): return o / self._get_as_operand()

   def __floordiv__(self, o):
-    return self.get() // o
+    return self._get_as_operand() // o

-  def __rfloordiv__(self, o): return o // self.get()
-  def __mod__(self, o): return self.get() % o
-  def __rmod__(self, o): return o % self.get()
-  def __lt__(self, o): return self.get() < o
-  def __le__(self, o): return self.get() <= o
-  def __gt__(self, o): return self.get() > o
-  def __ge__(self, o): return self.get() >= o
-  def __and__(self, o): return self.get() & o
-  def __rand__(self, o): return o & self.get()
-  def __or__(self, o): return self.get() | o
-  def __ror__(self, o): return o | self.get()
-  def __xor__(self, o): return self.get() ^ o
-  def __rxor__(self, o): return o ^ self.get()
-  def __getitem__(self, o): return self.get()[o]
-  def __pow__(self, o, modulo=None): return pow(self.get(), o, modulo)
-  def __rpow__(self, o): return pow(o, self.get())
-  def __invert__(self): return ~self.get()
-  def __neg__(self): return -self.get()
-  def __abs__(self): return abs(self.get())
+  def __rfloordiv__(self, o): return o // self._get_as_operand()
+  def __mod__(self, o): return self._get_as_operand() % o
+  def __rmod__(self, o): return o % self._get_as_operand()
+  def __lt__(self, o): return self._get_as_operand() < o
+  def __le__(self, o): return self._get_as_operand() <= o
+  def __gt__(self, o): return self._get_as_operand() > o
+  def __ge__(self, o): return self._get_as_operand() >= o
+  def __and__(self, o): return self._get_as_operand() & o
+  def __rand__(self, o): return o & self._get_as_operand()
+  def __or__(self, o): return self._get_as_operand() | o
+  def __ror__(self, o): return o | self._get_as_operand()
+  def __xor__(self, o): return self._get_as_operand() ^ o
+  def __rxor__(self, o): return o ^ self._get_as_operand()
+  def __getitem__(self, o): return self._get_as_operand()[o]
+  def __pow__(self, o, modulo=None):
+    return pow(self._get_as_operand(), o, modulo)
+  def __rpow__(self, o): return pow(o, self._get_as_operand())
+  def __invert__(self): return ~self._get_as_operand()
+  def __neg__(self): return -self._get_as_operand()
+  def __abs__(self): return abs(self._get_as_operand())

   def __div__(self, o):
     try:
-      return self.get().__div__(o)
+      return self._get_as_operand().__div__(o)
     except AttributeError:
       # See https://docs.python.org/3/library/constants.html#NotImplemented
       return NotImplemented

   def __rdiv__(self, o):
     try:
-      return self.get().__rdiv__(o)
+      return self._get_as_operand().__rdiv__(o)
     except AttributeError:
       # See https://docs.python.org/3/library/constants.html#NotImplemented
       return NotImplemented

   def __matmul__(self, o):
     try:
-      return self.get().__matmul__(o)
+      return self._get_as_operand().__matmul__(o)
     except AttributeError:
       # See https://docs.python.org/3/library/constants.html#NotImplemented
       return NotImplemented

   def __rmatmul__(self, o):
     try:
-      return self.get().__rmatmul__(o)
+      return self._get_as_operand().__rmatmul__(o)
     except AttributeError:
       # See https://docs.python.org/3/library/constants.html#NotImplemented
       return NotImplemented
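
The point of the new _get_as_operand hook: every operator overload now funnels through a single overridable accessor, so a subclass that cannot hand out its raw per-replica value can substitute something readable. A tiny standalone sketch of that pattern (toy classes, not the real DistributedDelegate):

class Delegate(object):
  """Routes operator overloads through one overridable accessor."""

  def __init__(self, value):
    self._value = value

  def get(self):
    return self._value

  def _get_as_operand(self):
    # Default behaviour: operators see whatever get() returns.
    return self.get()

  def __add__(self, o):
    return self._get_as_operand() + o


class TPULikeDelegate(Delegate):

  def _get_as_operand(self):
    # A TPUMirroredVariable-style subclass returns a readable value
    # (read_value() in the real code) instead of the raw component.
    return float(self.get())


print(Delegate(2) + 3)         # 5
print(TPULikeDelegate(2) + 3)  # 5.0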
@@ -920,15 +930,19 @@ class MirroredVariable(DistributedVariable, Mirrored):
       return _MirroredSaveable(self, self.primary, name)
     return {trackable.VARIABLE_VALUE_KEY: _saveable_factory}

+  def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
+    """Converts a variable to a tensor."""
+    # Try to avoid assignments to and other mutations of MirroredVariable
+    # state except through a DistributionStrategy.extended.update() call.
+    assert not as_ref
+    return ops.internal_convert_to_tensor(
+        self.get(), dtype=dtype, name=name, as_ref=as_ref)
+

 # Register a conversion function which reads the value of the variable,
 # allowing instances of the class to be used as tensors.
 def _tensor_conversion_mirrored(var, dtype=None, name=None, as_ref=False):
-  # Try to avoid assignments to and other mutations of MirroredVariable
-  # state except through a DistributionStrategy.extended.update() call.
-  assert not as_ref
-  return ops.internal_convert_to_tensor(
-      var.get(), dtype=dtype, name=name, as_ref=as_ref)
+  return var._dense_var_to_tensor(dtype=dtype, name=name, as_ref=as_ref)  # pylint: disable=protected-access


 ops.register_tensor_conversion_function(MirroredVariable,
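
For context, the registration being rewired here is TensorFlow's public tensor-conversion hook: once a conversion function is registered for a class, instances of it can be fed directly to ops. A hedged, standalone example of that mechanism (the Wrapped class is purely illustrative, not part of this diff):

import tensorflow as tf

class Wrapped(object):
  """A toy value holder made usable as a tensor via a conversion function."""

  def __init__(self, value):
    self._value = value

  def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
    assert not as_ref
    return tf.convert_to_tensor(self._value, dtype=dtype, name=name)

def _convert(value, dtype=None, name=None, as_ref=False):
  return value._dense_var_to_tensor(dtype=dtype, name=name, as_ref=as_ref)

tf.register_tensor_conversion_function(Wrapped, _convert)

print(tf.add(Wrapped(2.0), 3.0))  # tf.Tensor(5.0, shape=(), dtype=float32)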
@@ -947,33 +961,17 @@ def _enclosing_tpu_context():

 def is_distributed_variable(v):
   """Determine if a variable is ds variable or TPU mirrored variable."""
-  return (isinstance(v, DistributedVariable)
-          or isinstance(v, TPUMirroredVariable))
+  return isinstance(v, DistributedVariable)


-# TODO(jhseu): Deduplicate code. We copy code because we don't want to
-# inherit from DistributedDelegate. DistributedDelegate will not work in a
-# tpu.replicate() because it assumes that you're in a device context where you
-# can operate on a single version of the variable, but a tpu.replicate()
-# operates on all variables and is replicated during a rewrite pass.
-class TPUMirroredVariable(variables_lib.Variable):
+class TPUMirroredVariable(MirroredVariable):
   """Holds a map from device to TPU variables whose values are kept in sync."""

   def __init__(
       self, strategy, device_map, values, aggregation, logical_device=None):
-    assert isinstance(device_map, DeviceMap)
-    self._distribute_strategy = strategy
-    self._device_map = device_map
-    self._values = tuple(values)
-    if logical_device is None:
-      logical_device = device_map.logical_device_from_values(self._values)
-    self._logical_device = logical_device
-
-    # Use a weakref to make it easy to map from the contained values
-    # to the container without introducing a reference cycle.
-    for v in self._values:
-      v._mirrored_container = weakref.ref(self)  # pylint: disable=protected-access
-
-    self._common_name = self.primary.name.split(":")[0]
+    super(TPUMirroredVariable, self).__init__(
+        strategy=strategy, device_map=device_map, values=values,
+        aggregation=aggregation, logical_device=logical_device)

     # Handle id is needed for get_replicated_var_handle to cache the variables
     # correctly since in eager mode different variables can have the same name.
@@ -982,28 +980,21 @@ class TPUMirroredVariable(variables_lib.Variable):
     else:
       self._handle_id = self._common_name

-    self._aggregation = aggregation
-    # Needed for GradientTape
-    self._trainable = self.primary.trainable
-
-    # Typically like `DistributedVariable`, a `TPUMirroredVariable`'s
-    # initializer is composed of the initializers of the components variables.
-    # However, in some cases, such as when restoring from a checkpoint, we may
-    # set the _initializer_op property on the entire `TPUMirroredVariable`.
-    self._initializer_op = None
-
-  def _get(self, device=None):
-    """Returns the value for the current device or raises a ValueError."""
-    if device is None:
-      replica_context = distribution_strategy_context.get_replica_context()
-      if replica_context:
-        return self._device_map.select_for_current_replica(
-            self._values, replica_context)
-      else:
-        device = distribute_lib.get_update_device()
-        if device is None:
-          return self._get_cross_replica()
-    device = device_util.canonicalize(device)
-    return self._device_map.select_for_device(self._values, device)
+  def get(self, device=None):
+    if (_enclosing_tpu_context() is None) or (device is not None):
+      return super(TPUMirroredVariable, self).get(device=device)
+    else:
+      raise NotImplementedError(
+          "`TPUMirroredVariable.get()` is not supported within a TPU context.")
+
+  def _get_as_operand(self):
+    return self.read_value()
+
+  def _get_closest(self):
+    if _enclosing_tpu_context() is None:
+      return super(TPUMirroredVariable, self)._get_closest()
+    else:
+      return self.primary

   def numpy(self):
     if context.executing_eagerly():
@@ -1011,174 +1002,20 @@ class TPUMirroredVariable(variables_lib.Variable):
       raise NotImplementedError(
           "numpy() is only available when eager execution is enabled.")

-  def initialized_value(self):
-    return self.primary.initialized_value()
-
-  @property
-  def initial_value(self):
-    return self.primary.initial_value
-
-  @property
-  def primary(self):
-    """Returns a representative component."""
-    return self._values[0]
-
-  @property
-  def devices(self):
-    return self._device_map.logical_to_actual_devices(self._logical_device)
-
-  @property
-  def logical_device(self):
-    return self._logical_device
-
-  @property
-  def device_map(self):
-    return self._device_map
-
-  # TODO(josh11b): Replace experimental_local_results with this?
-  @property
-  def values(self):
-    return self._values
-
-  @property
-  def distribute_strategy(self):
-    return self._distribute_strategy
-
-  # pylint: disable=multiple-statements
-  def __add__(self, o): return self.read_value() + o
-  def __radd__(self, o): return o + self.read_value()
-  def __sub__(self, o): return self.read_value() - o
-  def __rsub__(self, o): return o - self.read_value()
-  def __mul__(self, o): return self.read_value() * o
-  def __rmul__(self, o): return o * self.read_value()
-  def __truediv__(self, o): return self.read_value() / o
-  def __rtruediv__(self, o): return o / self.read_value()
-  def __floordiv__(self, o): return self.read_value() // o
-  def __rfloordiv__(self, o): return o // self.read_value()
-  def __mod__(self, o): return self.read_value() % o
-  def __rmod__(self, o): return o % self.read_value()
-  def __lt__(self, o): return self.read_value() < o
-  def __le__(self, o): return self.read_value() <= o
-  def __gt__(self, o): return self.read_value() > o
-  def __ge__(self, o): return self.read_value() >= o
-  def __and__(self, o): return self.read_value() & o
-  def __rand__(self, o): return o & self.read_value()
-  def __or__(self, o): return self.read_value() | o
-  def __ror__(self, o): return o | self.read_value()
-  def __xor__(self, o): return self.read_value() ^ o
-  def __rxor__(self, o): return o ^ self.read_value()
-  def __getitem__(self, o): return self.read_value()[o]
-  def __pow__(self, o, modulo=None): return pow(self.read_value(), o, modulo)
-  def __rpow__(self, o): return pow(o, self.read_value())
-  def __invert__(self): return ~self.read_value()
-  def __neg__(self): return -self.read_value()
-  def __abs__(self): return abs(self.read_value())
-
-  def __div__(self, o):
-    try:
-      return self.read_value().__div__(o)
-    except AttributeError:
-      # See https://docs.python.org/3/library/constants.html#NotImplemented
-      return NotImplemented
-
-  def __rdiv__(self, o):
-    try:
-      return self.read_value().__rdiv__(o)
-    except AttributeError:
-      # See https://docs.python.org/3/library/constants.html#NotImplemented
-      return NotImplemented
-
-  def __matmul__(self, o):
-    try:
-      return self.read_value().__matmul__(o)
-    except AttributeError:
-      # See https://docs.python.org/3/library/constants.html#NotImplemented
-      return NotImplemented
-
-  def __rmatmul__(self, o):
-    try:
-      return self.read_value().__rmatmul__(o)
-    except AttributeError:
-      # See https://docs.python.org/3/library/constants.html#NotImplemented
-      return NotImplemented
-
-  def __str__(self):
-    devices = self.devices
-    debug_str = ",\n".join("  %d %s: %s" % (i, devices[i], self._values[i])
-                           for i in range(len(devices)))
-    return "%s:{\n%s\n}" % (self.__class__.__name__, debug_str)
-
-  def __repr__(self):
-    devices = self.devices
-    debug_repr = ",\n".join("  %d %s: %r" % (i, devices[i], self._values[i])
-                            for i in range(len(devices)))
-    return "%s:{\n%s\n}" % (self.__class__.__name__, debug_repr)
-
   @property
   def handle(self):
     # If we're in a tpu.rewrite(), return the replicated handle.
     tpu_context = _enclosing_tpu_context()
-    if tpu_context is not None:
+    if tpu_context is None:
+      return self._get_closest().handle
+    else:
       return tpu_context.get_replicated_var_handle(
           self._handle_id, self._values)

-    device = distribute_lib.get_update_device()
-    if device is None:
-      return self.primary.handle
-    return self._get(device=device).handle
-
   @property
   def device(self):
     return self.handle.device

-  def eval(self, session=None):
-    return self.primary.eval(session)
-
-  # The arguments to update() are automatically unwrapped so the update()
-  # function would normally see regular variables, not MirroredVariables.
-  # However, the update function can still operate on wrapped MirroredVariables
-  # through object members, captured arguments, etc. This is more likely in an
-  # update_non_slot() function (like OptimizerV2._finish), which can
-  # update several non-slot variables in one call.
-  def _assign_func(self, *args, **kwargs):
-    with _enter_or_assert_strategy(self._distribute_strategy):
-      f = kwargs.pop("f")
-      if distribution_strategy_context.in_cross_replica_context():
-        if _enclosing_tpu_context() is not None:
-          return self._distribute_strategy.extended.update(
-              self, f, args=args, kwargs=kwargs)
-
-        update_device = distribute_lib.get_update_device()
-        # We are calling update on the mirrored variable in cross replica
-        # context.
-        if update_device is not None:
-          # We are calling an assign function on the mirrored variable in cross
-          # replica context.
-          v = self._get(device=update_device)
-          return f(v, *args, **kwargs)
-
-        return self._distribute_strategy.extended.update(
-            self, f, args=args, kwargs=kwargs)
-      else:
-        _assert_replica_context(self._distribute_strategy)
-        # We are calling an assign function on the mirrored variable in replica
-        # context.
-        # We reduce the value we want to assign/add/sub. More details about how
-        # we handle the different use cases can be found in the _reduce method.
-        # We call the function on each of the mirrored variables with the
-        # reduced value.
-        if self._aggregation == vs.VariableAggregation.NONE:
-          raise ValueError(_aggregation_error_msg.format(
-              variable_type="TPUMirroredVariable"))
-
-        def merge_fn(strategy, value, *other_args, **other_kwargs):
-          v = _apply_aggregation(strategy, value, self._aggregation, self)
-          return strategy.extended.update(
-              self, f, args=(v,) + other_args, kwargs=other_kwargs)
-
-        return distribution_strategy_context.get_replica_context().merge_call(
-            merge_fn, args=args, kwargs=kwargs)
-
   @contextlib.contextmanager
   def _handle_graph(self, handle):
     # Note: might have an eager tensor but not be executing eagerly when
@@ -1190,10 +1027,6 @@ class TPUMirroredVariable(variables_lib.Variable):
       with handle.graph.as_default():
         yield

-  @property
-  def trainable(self):
-    return self._trainable
-
   def _read_variable_op(self, parent_op=None):
     if self.trainable:
       tape.variable_accessed(self)
@@ -1208,156 +1041,64 @@ class TPUMirroredVariable(variables_lib.Variable):
   def read_value(self):
     return self._read_variable_op()

-  def assign_sub(self, *args, **kwargs):
-    def assign_sub_fn(var, delta, *ar, **kw):
-      del ar
-      name = kw.pop("name", None)
-      read_value = kw.pop("read_value", True)
-      with self._handle_graph(var.handle):
-        op = gen_resource_variable_ops.assign_sub_variable_op(
-            var.handle, ops.convert_to_tensor(delta, dtype=self.dtype),
-            name=name)
-      if read_value:
-        return self._read_variable_op(parent_op=op)
-      return op
+  def _assign_func(self, *args, **kwargs):
+    with _enter_or_assert_strategy(self._distribute_strategy):
+      if (distribution_strategy_context.in_cross_replica_context()
+          and (_enclosing_tpu_context() is not None)):
+        f = kwargs.pop("f")
+        return self._distribute_strategy.extended.update(
+            self, f, args=args, kwargs=kwargs)
+      else:
+        return super(TPUMirroredVariable, self)._assign_func(*args, **kwargs)
+
+  def _make_raw_assign_fn(self, raw_assign_fn):
+    def assign_fn(var, value, *args, **kwargs):
+      del args
+      name = kwargs.pop("name", None)
+      read_value = kwargs.pop("read_value", True)
+      with self._handle_graph(var.handle):
+        op = raw_assign_fn(
+            var.handle, ops.convert_to_tensor(value, dtype=self.dtype),
+            name=name)
+      return self._read_variable_op(parent_op=op) if read_value else op
+    return assign_fn

+  def assign_sub(self, *args, **kwargs):
+    assign_sub_fn = self._make_raw_assign_fn(
+        gen_resource_variable_ops.assign_sub_variable_op)
     return self._assign_func(f=assign_sub_fn, *args, **kwargs)

   def assign_add(self, *args, **kwargs):
-    def assign_add_fn(var, delta, *ar, **kw):
-      del ar
-      name = kw.pop("name", None)
-      read_value = kw.pop("read_value", True)
-      with self._handle_graph(var.handle):
-        op = gen_resource_variable_ops.assign_add_variable_op(
-            var.handle, ops.convert_to_tensor(delta, dtype=self.dtype),
-            name=name)
-      if read_value:
-        return self._read_variable_op(parent_op=op)
-      return op
-
+    assign_add_fn = self._make_raw_assign_fn(
+        gen_resource_variable_ops.assign_add_variable_op)
     return self._assign_func(f=assign_add_fn, *args, **kwargs)

   def assign(self, *args, **kwargs):
-    def assign_fn(var, value, *ar, **kw):
-      del ar
-      name = kw.pop("name", None)
-      read_value = kw.pop("read_value", True)
-      with self._handle_graph(var.handle):
-        op = gen_resource_variable_ops.assign_variable_op(
-            var.handle, ops.convert_to_tensor(value, dtype=self.dtype),
-            name=name)
-      if read_value:
-        return self._read_variable_op(parent_op=op)
-      return op
-
+    assign_fn = self._make_raw_assign_fn(
+        gen_resource_variable_ops.assign_variable_op)
     return self._assign_func(f=assign_fn, *args, **kwargs)

-  @property
-  def aggregation(self):
-    return self._aggregation
-
   @property
   def constraint(self):
     return self.primary.constraint

-  @property
-  def initializer(self):
-    if self._initializer_op:
-      init_op = self._initializer_op
-    else:
-      init_op = control_flow_ops.group(tuple(
-          v.initializer for v in self._values))
-    return init_op
-
-  @property
-  def graph(self):
-    return self.primary.graph
-
-  @property
-  def _shared_name(self):
-    return self._common_name
-
-  @property
-  def _unique_id(self):
-    return self.primary._unique_id  # pylint: disable=protected-access
-
-  @property
-  def name(self):
-    return self.primary.name
-
-  @property
-  def dtype(self):
-    return self.primary.dtype
-
-  @property
-  def shape(self):
-    return self.primary.shape
-
-  def get_shape(self):
-    return self.primary.get_shape()
-
-  def to_proto(self, export_scope=None):
-    return self.primary.to_proto(export_scope=export_scope)
-
-  def _get_cross_replica(self):
-    device = device_util.canonicalize(device_util.current())
-    replica = self._device_map.replica_for_device(device)
-    if replica is None:
-      return self.primary
-    return self._values[replica]
-
   def _as_graph_element(self):
-    # pylint: disable=protected-access
     if _enclosing_tpu_context() is None:
-      if distribution_strategy_context.in_cross_replica_context():
-        return self.primary._as_graph_element()
-      return self._get()._as_graph_element()
-    return None
-
-  def _gather_saveables_for_checkpoint(self):
-    """Overrides Trackable method.
-
-    This allows both name-based and object-based save and restore of
-    MirroredVariables.
-
-    Returns:
-      A dictionary mapping attribute names to `SaveableObject` factories.
-    """
-    def _saveable_factory(name=self._common_name):
-      return _MirroredSaveable(self, self.primary, name)
-    return {trackable.VARIABLE_VALUE_KEY: _saveable_factory}
-
-  def _should_act_as_resource_variable(self):
-    """Pass resource_variable_ops.is_resource_variable check."""
-    pass
+      return super(TPUMirroredVariable, self)._as_graph_element()  # pylint: disable=protected-access
+    else:
+      return None

   # Needed to pass ResourceVariable checks.
   @property
   def op(self):
     return self.primary.op

-  # pylint: disable=protected-access
-  @property
-  def _save_slice_info(self):
-    return self.primary._save_slice_info
-
-  def _get_save_slice_info(self):
-    return self.primary._get_save_slice_info()
-
-  def _set_save_slice_info(self, save_slice_info):
-    return self.primary._set_save_slice_info(save_slice_info)
-  # pylint: enable=protected-access
-
-  @property
-  def _in_graph_mode(self):
-    return self.primary._in_graph_mode  # pylint: disable=protected-access
-
   def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
     """Converts a variable to a tensor."""
     # pylint: disable=protected-access
     if _enclosing_tpu_context() is None:
-      return self._get()._dense_var_to_tensor(dtype, name, as_ref)
+      return super(TPUMirroredVariable, self)._dense_var_to_tensor(
+          dtype, name, as_ref)
     # pylint: enable=protected-access
     if dtype is not None and dtype != self.dtype:
       return math_ops.cast(self.read_value(), dtype)
@@ -1366,40 +1107,6 @@ class TPUMirroredVariable(variables_lib.Variable):
     else:
       return self.read_value()

-  def is_initialized(self, name=None):
-    """Identifies if all the component variables are initialized.
-
-    Args:
-      name: Name of the final `logical_and` op.
-
-    Returns:
-      The op that evaluates to True or False depending on if all the
-      component variables are initialized.
-    """
-    # TODO(jhseu): Do we need TPU context implementation?
-
-    result = self.primary.is_initialized()
-    # We iterate through the list of values except the last one to allow us to
-    # name the final `logical_and` op the same name that is passed by the user
-    # to the `is_initialized` op. For distributed variables, the
-    # `is_initialized` op is a `logical_and` op.
-    for v in self._values[1:-1]:
-      result = math_ops.logical_and(result, v.is_initialized())
-    result = math_ops.logical_and(result, self._values[-1].is_initialized(),
-                                  name=name)
-    return result
-
-
-# Register a conversion function which reads the value of the variable,
-# allowing instances of the class to be used as tensors.
-def _tensor_conversion_tpu_mirrored(var, dtype=None, name=None, as_ref=False):
-  return var._dense_var_to_tensor(dtype=dtype, name=name, as_ref=as_ref)  # pylint: disable=protected-access
-
-
-ops.register_tensor_conversion_function(TPUMirroredVariable,
-                                        _tensor_conversion_tpu_mirrored)
-ops.register_dense_tensor_like_type(TPUMirroredVariable)
-

 class _SyncOnReadSaveable(saver.BaseSaverBuilder.SaveableObject):
   """Class for defining how to restore a SyncOnReadVariable."""

File: tensorflow/python/keras/optimizer_v2/optimizer_v2.py

@@ -996,7 +996,7 @@ def _var_key(var):
   # pylint: disable=protected-access
   # Get the distributed variable if it exists.
-  if getattr(var, "_distributed_container", None) is not None:
+  if hasattr(var, "_distributed_container"):
     var = var._distributed_container()
   if var._in_graph_mode:
     return var._shared_name
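
The optimizer-side change is small but load-bearing: slot lookups key off the distributed container, and the guard is relaxed from "attribute exists and is non-None" to "attribute exists". A self-contained illustration of the lookup with mock objects (not the real optimizer code):

import weakref

class Container(object):
  _in_graph_mode = False
  _unique_id = "mirrored_var_0"

class Component(object):
  _in_graph_mode = False
  _unique_id = "component_var_3"

def var_key(var):
  # Same shape as _var_key after this change: resolve to the container
  # whenever the component advertises one.
  if hasattr(var, "_distributed_container"):
    var = var._distributed_container()
  return var._shared_name if var._in_graph_mode else var._unique_id

container = Container()
component = Component()
component._distributed_container = weakref.ref(container)  # weakref back to the container

plain = Component()
print(var_key(component))  # mirrored_var_0 -- resolved to the container
print(var_key(plain))      # component_var_3 -- no container attribute, keyed directly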