Separate TPU variables into their own file

TPU code pulls in tensorflow.python.tpu.tpu, which in turn pulls in a large
number of dependencies, making it easy to create cyclic dependencies.

PiperOrigin-RevId: 297060171
Change-Id: I8e09e7b8de9fc50e892502a480de8fffa378a1c5
Authored by Ran Chen on 2020-02-25 00:37:05 -08:00; committed by TensorFlower Gardener
parent 2e81bc66c5
commit db992d30bb
10 changed files with 339 additions and 276 deletions
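
For readers skimming the diff, the net effect on call sites can be summarized with a small sketch. This is illustrative only and not part of the commit: the module paths and class names are taken from the diff below, and the "before" form reflects how values.py exposed the TPU classes prior to this change.

```python
# Illustrative sketch of the import change implied by this commit.
# Not part of the diff; module paths and class names come from the hunks below.

# Before: TPU variable classes lived in values.py, so any import of values
# also pulled in tensorflow.python.tpu.tpu and its large dependency closure.
#   from tensorflow.python.distribute import values
#   var_cls = values.TPUMirroredVariable

# After: TPU-specific classes live in their own module, so only TPU code pays
# for the tensorflow.python.tpu.tpu import and the cycle risk goes away.
from tensorflow.python.distribute import values
from tensorflow.python.distribute import tpu_values

var_cls = tpu_values.TPUMirroredVariable       # TPU-specific mirrored variable
sync_cls = tpu_values.TPUSyncOnReadVariable    # TPU sync-on-read variable
mirrored_cls = values.MirroredVariable         # non-TPU classes stay in values.py
```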

View File

@ -2882,6 +2882,7 @@ tf_gen_op_wrapper_private_py(
name = "resource_variable_ops_gen",
visibility = [
"//tensorflow/compiler/tf2xla:internal",
"//tensorflow/python/distribute:__pkg__",
],
)

View File

@ -66,6 +66,7 @@ py_library(
":cross_device_utils",
":device_util",
":reduce_util",
":tpu_values",
":values",
"//tensorflow/python:array_ops",
"//tensorflow/python:device_lib",
@ -531,6 +532,7 @@ py_library(
":input_lib",
":numpy_dataset",
":reduce_util",
":tpu_values",
":values",
"//tensorflow/compiler/xla/experimental/xla_sharding",
"//tensorflow/python:array_ops",
@ -612,18 +614,36 @@ py_library(
deps = [
":device_util",
":distribute_lib",
":reduce_util",
"//tensorflow/python:array_ops",
"//tensorflow/python:composite_tensor",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:resource_variable_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:tensor_util",
"//tensorflow/python:type_spec",
"//tensorflow/python:util",
"//tensorflow/python:variable_scope",
"//tensorflow/python:variables",
"//tensorflow/python/eager:context",
"//tensorflow/python/tpu:tpu_lib",
"//tensorflow/python/eager:tape",
"//tensorflow/python/training/saving:saveable_object",
"//tensorflow/python/training/saving:saveable_object_util",
"//tensorflow/python/training/tracking:base",
"@six_archive//:six",
],
)
py_library(
name = "tpu_values",
srcs = ["tpu_values.py"],
deps = [
":values",
"//tensorflow/python:framework_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:resource_variable_ops_gen",
"//tensorflow/python/eager:context",
"//tensorflow/python/eager:tape",
"//tensorflow/python/tpu:tpu_lib",
],
)

View File

@ -27,6 +27,7 @@ from tensorflow.python.client import device_lib
from tensorflow.python.distribute import cross_device_utils
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import tpu_values
from tensorflow.python.distribute import values as value_lib
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
@ -62,8 +63,8 @@ def validate_destinations(destinations):
if not isinstance(
destinations,
(value_lib.DistributedValues, ops.Tensor, value_lib.AggregatingVariable,
six.string_types, value_lib.TPUMirroredVariable)
) and not resource_variable_ops.is_resource_variable(destinations):
six.string_types, tpu_values.TPUMirroredVariable
)) and not resource_variable_ops.is_resource_variable(destinations):
raise ValueError("destinations must be one of a `DistributedValues` object,"
" a tf.Variable object, or a device string.")

View File

@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import contextlib
import threading
from tensorflow.python.framework import ops
@ -266,6 +267,20 @@ def experimental_set_strategy(strategy):
ops.get_default_graph()._global_distribute_strategy_scope = new_scope # pylint: disable=protected-access
# ------------------------------------------------------------------------------
# Internal helpers.
@contextlib.contextmanager
def enter_or_assert_strategy(strategy):
if not has_strategy():
with strategy.scope():
yield
else:
_assert_strategy(strategy)
yield
# ------------------------------------------------------------------------------
# Defaults that are used when no tf.distribute.Strategy is explicitly created.
# We create them lazily in a function so that we can workaround the circular
@ -284,6 +299,17 @@ _default_replica_context_lock = threading.Lock()
_default_replica_mode_lock = threading.Lock()
def _assert_strategy(strategy):
if not has_strategy():
raise RuntimeError('Need to be inside "with strategy.scope()" for %s' %
(strategy,))
current_strategy = get_strategy()
if current_strategy is not strategy:
raise RuntimeError(
"Mixing different tf.distribute.Strategy objects: %s is not %s" %
(current_strategy, strategy))
def _get_default_strategy():
if _defaults["strategy"] is None:
# Avoid race condition causing two defaults to be created

View File

@ -34,6 +34,7 @@ from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import input_lib
from tensorflow.python.distribute import numpy_dataset
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import tpu_values
from tensorflow.python.distribute import values
from tensorflow.python.distribute.cluster_resolver import TPUClusterResolver
from tensorflow.python.eager import context
@ -543,7 +544,7 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
self._logical_device_stack.append(logical_device_id)
try:
if values._enclosing_tpu_context() is None: # pylint: disable=protected-access
if tpu_values.enclosing_tpu_context() is None:
yield
else:
with ops.device(tpu.core(logical_device_id)):
@ -648,20 +649,20 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
with context.device_policy(context.DEVICE_PLACEMENT_SILENT):
v = next_creator(**kwargs)
assert not isinstance(v, values.TPUMirroredVariable)
assert not isinstance(v, tpu_values.TPUMirroredVariable)
value_list.append(v)
return value_list
return values.create_mirrored_variable(self._container_strategy(),
_real_mirrored_creator,
values.TPUMirroredVariable,
values.TPUSyncOnReadVariable,
tpu_values.TPUMirroredVariable,
tpu_values.TPUSyncOnReadVariable,
**kwargs)
def _reduce_to(self, reduce_op, value, destinations):
if (isinstance(value, values.DistributedValues) or
tensor_util.is_tensor(value)
) and values._enclosing_tpu_context() is not None: # pylint: disable=protected-access
) and tpu_values.enclosing_tpu_context() is not None:
if reduce_op == reduce_util.ReduceOp.MEAN:
# TODO(jhseu): Revisit once we support model-parallelism.
value *= (1. / self._num_replicas_in_sync)
@ -701,9 +702,9 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
return output
def _update(self, var, fn, args, kwargs, group):
assert isinstance(var, values.TPUVariableMixin) or isinstance(
assert isinstance(var, tpu_values.TPUVariableMixin) or isinstance(
var, resource_variable_ops.BaseResourceVariable)
if values._enclosing_tpu_context() is not None: # pylint: disable=protected-access
if tpu_values.enclosing_tpu_context() is not None:
if group:
return fn(var, *args, **kwargs)
else:
@ -724,7 +725,7 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
return values.update_regroup(self, updates, group)
def read_var(self, var):
assert isinstance(var, values.TPUVariableMixin) or isinstance(
assert isinstance(var, tpu_values.TPUVariableMixin) or isinstance(
var, resource_variable_ops.BaseResourceVariable)
return var.read_value()
@ -745,7 +746,7 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
# since the `1` gets broadcast as an int32 but global_step is int64.
if isinstance(tensor, (float, int)):
return tensor
if values._enclosing_tpu_context() is not None: # pylint: disable=protected-access
if tpu_values.enclosing_tpu_context() is not None:
broadcast_tensor = [tensor for _ in range(self._num_replicas_in_sync)]
result = tpu_ops.all_to_all(
broadcast_tensor,

View File

@ -0,0 +1,245 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Various classes representing TPU distributed values.
Note that the tests are in values_test.py.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import contextlib
from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.distribute import values
from tensorflow.python.eager import context
from tensorflow.python.eager import tape
from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.tpu import tpu
@contextlib.contextmanager
def _maybe_enter_graph(tensor):
# Note: might have an eager tensor but not be executing eagerly when
# building functions.
if (context.executing_eagerly() or isinstance(tensor, ops.EagerTensor) or
ops.has_default_graph()):
yield
else:
with tensor.graph.as_default():
yield
def _make_raw_assign_fn(raw_assign_fn): # pylint: disable=missing-docstring
def assign_fn(var, value, use_locking=False, name=None, read_value=True): # pylint: disable=missing-docstring
del use_locking # Unused.
with _maybe_enter_graph(var.handle):
op = raw_assign_fn(
var.handle, ops.convert_to_tensor(value, dtype=var.dtype), name=name)
with ops.control_dependencies([op]):
return var._read_variable_op() if read_value else op # pylint: disable=protected-access
return assign_fn
class TPUVariableMixin(object):
"""Mixin for TPU variables."""
def __init__(self, *args, **kwargs):
super(TPUVariableMixin, self).__init__(*args, **kwargs)
# Handle ID is needed for `get_replicated_var_handle` to cache the variables
# correctly since in eager mode different variables can have the same name.
if ops.executing_eagerly_outside_functions():
self._handle_id = self._common_name + "_" + str(id(self._primary))
else:
self._handle_id = self._common_name
def __getattr__(self, name):
if enclosing_tpu_context() is None:
return super(TPUVariableMixin, self).__getattr__(name)
else:
raise AttributeError(
"'{}' not accessible within a TPU context.".format(name))
def get(self):
if enclosing_tpu_context() is None:
return super(TPUVariableMixin, self).get()
else:
raise NotImplementedError(
"`TPUVariableMixin.get()` is not supported within a TPU context.")
def _get_as_operand(self):
return self.read_value()
def _get_closest(self):
if enclosing_tpu_context() is None:
return super(TPUVariableMixin, self)._get_closest()
else:
return self._primary
def numpy(self):
if context.executing_eagerly():
return self.read_value().numpy()
else:
raise NotImplementedError(
"numpy() is only available when eager execution is enabled.")
def _is_mirrored(self):
raise NotImplementedError(
"`TPUVariableMixin._is_mirrored()` must be implemented by subclasses.")
@property
def handle(self):
# If we're in a tpu.rewrite(), return the replicated handle.
tpu_context = enclosing_tpu_context()
if tpu_context is None:
return self._get_closest().handle
else:
return tpu_context.get_replicated_var_handle(self._handle_id,
self._values,
self._is_mirrored())
@property
def device(self):
return self.handle.device
def _read_variable_op(self):
if self.trainable:
tape.variable_accessed(self)
return gen_resource_variable_ops.read_variable_op(self.handle, self.dtype)
def read_value(self):
if enclosing_tpu_context() is None:
return super(TPUVariableMixin, self).read_value()
else:
return self._read_variable_op()
def value(self):
if enclosing_tpu_context() is None:
return super(TPUVariableMixin, self).value()
else:
return self._read_variable_op()
def _as_graph_element(self):
if enclosing_tpu_context() is None:
return super(TPUVariableMixin, self)._as_graph_element() # pylint: disable=protected-access
else:
return None
@property
def op(self):
return values.DistributedVarOp(self._primary.op.name,
self._primary.op.graph,
self._primary.op.traceback,
self._primary.op.type)
def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
"""Converts a variable to a tensor."""
# pylint: disable=protected-access
if enclosing_tpu_context() is None:
return super(TPUVariableMixin, self)._dense_var_to_tensor(
dtype=dtype, name=name, as_ref=as_ref)
# pylint: enable=protected-access
elif dtype is not None and dtype != self.dtype:
return math_ops.cast(self.read_value(), dtype)
else:
return self.handle if as_ref else self.read_value()
def enclosing_tpu_context():
"""Returns the TPUReplicateContext, which exists inside a tpu.rewrite()."""
graph = ops.get_default_graph()
while graph is not None:
# pylint: disable=protected-access
context_ = graph._get_control_flow_context()
# pylint: enable=protected-access
while context_ is not None:
if isinstance(context_, tpu.TPUReplicateContext):
return context_
context_ = context_.outer_context
# This may be a FuncGraph due to defuns or v2 control flow. We need to
# find the original graph with the XLAControlFlowContext.
graph = getattr(graph, "outer_graph", None)
return None
class TPUMirroredVariable(TPUVariableMixin, values.MirroredVariable):
"""Holds a map from replica to TPU variables whose values are kept in sync."""
def _assign_func(self, *args, **kwargs):
with ds_context.enter_or_assert_strategy(self._distribute_strategy):
if (ds_context.in_cross_replica_context() and
(enclosing_tpu_context() is not None)):
f = kwargs.pop("f")
return self._distribute_strategy.extended.update(
self, f, args=args, kwargs=kwargs)
else:
return values.MirroredVariable._assign_func(self, *args, **kwargs)
def assign_sub(self, *args, **kwargs):
assign_sub_fn = _make_raw_assign_fn(
gen_resource_variable_ops.assign_sub_variable_op)
return self._assign_func(f=assign_sub_fn, *args, **kwargs)
def assign_add(self, *args, **kwargs):
assign_add_fn = _make_raw_assign_fn(
gen_resource_variable_ops.assign_add_variable_op)
return self._assign_func(f=assign_add_fn, *args, **kwargs)
def assign(self, *args, **kwargs):
assign_fn = _make_raw_assign_fn(
gen_resource_variable_ops.assign_variable_op)
return self._assign_func(f=assign_fn, *args, **kwargs)
def _is_mirrored(self):
return True
class TPUSyncOnReadVariable(TPUVariableMixin, values.SyncOnReadVariable):
"""Holds a map from replica to variables whose values are reduced on save."""
def assign_sub(self, *args, **kwargs):
if enclosing_tpu_context() is None:
return values.SyncOnReadVariable.assign_sub(self, *args, **kwargs)
else:
return _make_raw_assign_fn(
gen_resource_variable_ops.assign_sub_variable_op)(self, *args,
**kwargs)
def assign_add(self, *args, **kwargs):
if enclosing_tpu_context() is None:
return values.SyncOnReadVariable.assign_add(self, *args, **kwargs)
else:
return _make_raw_assign_fn(
gen_resource_variable_ops.assign_add_variable_op)(self, *args,
**kwargs)
def assign(self, *args, **kwargs):
if enclosing_tpu_context() is None:
return values.SyncOnReadVariable.assign(self, *args, **kwargs)
else:
return _make_raw_assign_fn(gen_resource_variable_ops.assign_variable_op)(
self, *args, **kwargs)
def _is_mirrored(self):
return False

View File

@ -19,12 +19,11 @@ from __future__ import division
from __future__ import print_function
import collections
import contextlib
import weakref
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.distribute import reduce_util
from tensorflow.python.eager import context
from tensorflow.python.eager import tape
@ -34,11 +33,9 @@ from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.tpu import tpu
from tensorflow.python.training.saving import saveable_object
from tensorflow.python.training.saving import saveable_object_util
from tensorflow.python.training.tracking import base as trackable
@ -48,7 +45,7 @@ from tensorflow.python.util.tf_export import tf_export
def _get_current_replica_id_as_int():
"""Returns the current replica ID as an integer, or `None`."""
replica_context = distribution_strategy_context.get_replica_context()
replica_context = ds_context.get_replica_context()
if replica_context:
replica_id = replica_context.replica_id_in_sync_group
if not isinstance(replica_id, int):
@ -362,7 +359,7 @@ class PerReplicaSpec(type_spec.TypeSpec):
return self._value_specs
def _to_components(self, value):
replica_context = distribution_strategy_context.get_replica_context()
replica_context = ds_context.get_replica_context()
if replica_context is not None and replica_context.num_replicas_in_sync > 1:
raise ValueError(
"Flattening a PerReplica to components is not supported in replica "
@ -405,27 +402,6 @@ def _assign_sub_on_device(device, variable, tensor):
return variable.assign_sub(tensor)
def _assert_strategy(strategy):
if not distribution_strategy_context.has_strategy():
raise RuntimeError('Need to be inside "with strategy.scope()" for %s' %
(strategy,))
current_strategy = distribution_strategy_context.get_strategy()
if current_strategy is not strategy:
raise RuntimeError(
"Mixing different tf.distribute.Strategy objects: %s is not %s" %
(current_strategy, strategy))
@contextlib.contextmanager
def _enter_or_assert_strategy(strategy):
if not distribution_strategy_context.has_strategy():
with strategy.scope():
yield
else:
_assert_strategy(strategy)
yield
DistributedVarOp = collections.namedtuple(
"DistributedVarOp", ["name", "graph", "traceback", "type"])
@ -578,7 +554,7 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable):
# We want cross-replica code that does some var.op.X calls
# to work (even if the current device isn't in self._devices), but
# other uses of var.op in a cross-replica context to fail.
if distribution_strategy_context.in_cross_replica_context():
if ds_context.in_cross_replica_context():
return DistributedVarOp(self._primary.op.name, self._primary.op.graph,
self._primary.op.traceback, self._primary.op.type)
return self._get().op
@ -588,7 +564,7 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable):
return self._primary._in_graph_mode # pylint: disable=protected-access
def read_value(self):
with _enter_or_assert_strategy(self._distribute_strategy):
with ds_context.enter_or_assert_strategy(self._distribute_strategy):
return array_ops.identity(self._get())
def value(self):
@ -602,135 +578,6 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable):
ops.register_dense_tensor_like_type(DistributedVariable)
@contextlib.contextmanager
def _maybe_enter_graph(tensor):
# Note: might have an eager tensor but not be executing eagerly when
# building functions.
if (context.executing_eagerly() or isinstance(tensor, ops.EagerTensor) or
ops.has_default_graph()):
yield
else:
with tensor.graph.as_default():
yield
def _make_raw_assign_fn(raw_assign_fn): # pylint: disable=missing-docstring
def assign_fn(var, value, use_locking=False, name=None, read_value=True): # pylint: disable=missing-docstring
del use_locking # Unused.
with _maybe_enter_graph(var.handle):
op = raw_assign_fn(
var.handle, ops.convert_to_tensor(value, dtype=var.dtype), name=name)
with ops.control_dependencies([op]):
return var._read_variable_op() if read_value else op # pylint: disable=protected-access
return assign_fn
class TPUVariableMixin(object):
"""Mixin for TPU variables."""
def __init__(self, *args, **kwargs):
super(TPUVariableMixin, self).__init__(*args, **kwargs)
# Handle ID is needed for `get_replicated_var_handle` to cache the variables
# correctly since in eager mode different variables can have the same name.
if ops.executing_eagerly_outside_functions():
self._handle_id = self._common_name + "_" + str(id(self._primary))
else:
self._handle_id = self._common_name
def __getattr__(self, name):
if _enclosing_tpu_context() is None:
return super(TPUVariableMixin, self).__getattr__(name)
else:
raise AttributeError(
"'{}' not accessible within a TPU context.".format(name))
def get(self):
if _enclosing_tpu_context() is None:
return super(TPUVariableMixin, self).get()
else:
raise NotImplementedError(
"`TPUVariableMixin.get()` is not supported within a TPU context.")
def _get_as_operand(self):
return self.read_value()
def _get_closest(self):
if _enclosing_tpu_context() is None:
return super(TPUVariableMixin, self)._get_closest()
else:
return self._primary
def numpy(self):
if context.executing_eagerly():
return self.read_value().numpy()
else:
raise NotImplementedError(
"numpy() is only available when eager execution is enabled.")
def _is_mirrored(self):
raise NotImplementedError(
"`TPUVariableMixin._is_mirrored()` must be implemented by subclasses.")
@property
def handle(self):
# If we're in a tpu.rewrite(), return the replicated handle.
tpu_context = _enclosing_tpu_context()
if tpu_context is None:
return self._get_closest().handle
else:
return tpu_context.get_replicated_var_handle(
self._handle_id, self._values, self._is_mirrored())
@property
def device(self):
return self.handle.device
def _read_variable_op(self):
if self.trainable:
tape.variable_accessed(self)
return gen_resource_variable_ops.read_variable_op(self.handle, self.dtype)
def read_value(self):
if _enclosing_tpu_context() is None:
return super(TPUVariableMixin, self).read_value()
else:
return self._read_variable_op()
def value(self):
if _enclosing_tpu_context() is None:
return super(TPUVariableMixin, self).value()
else:
return self._read_variable_op()
def _as_graph_element(self):
if _enclosing_tpu_context() is None:
return super(TPUVariableMixin, self)._as_graph_element() # pylint: disable=protected-access
else:
return None
@property
def op(self):
return DistributedVarOp(self._primary.op.name, self._primary.op.graph,
self._primary.op.traceback, self._primary.op.type)
def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
"""Converts a variable to a tensor."""
# pylint: disable=protected-access
if _enclosing_tpu_context() is None:
return super(TPUVariableMixin, self)._dense_var_to_tensor(
dtype=dtype, name=name, as_ref=as_ref)
# pylint: enable=protected-access
elif dtype is not None and dtype != self.dtype:
return math_ops.cast(self.read_value(), dtype)
else:
return self.handle if as_ref else self.read_value()
def _validate_colocate_extended(v, extended):
variable_strategy = v._distribute_strategy # pylint: disable=protected-access
if variable_strategy.extended is not extended:
@ -888,9 +735,9 @@ class MirroredVariable(DistributedVariable, Mirrored):
# update_non_slot() function (like OptimizerV2._finish), which can
# update several non-slot variables in one call.
def _assign_func(self, *args, **kwargs):
with _enter_or_assert_strategy(self._distribute_strategy):
with ds_context.enter_or_assert_strategy(self._distribute_strategy):
f = kwargs.pop("f")
if distribution_strategy_context.in_cross_replica_context():
if ds_context.in_cross_replica_context():
update_replica_id = distribute_lib.get_update_replica_id()
if update_replica_id is not None:
# We are calling an assign function on the mirrored variable in an
@ -933,7 +780,7 @@ class MirroredVariable(DistributedVariable, Mirrored):
return strategy.extended.update(
self, f, args=(v,) + other_args, kwargs=other_kwargs)
return distribution_strategy_context.get_replica_context().merge_call(
return ds_context.get_replica_context().merge_call(
merge_fn, args=args, kwargs=kwargs)
def assign_sub(self, *args, **kwargs):
@ -1003,60 +850,11 @@ ops.register_tensor_conversion_function(Mirrored,
_tensor_conversion_mirrored_val)
def _enclosing_tpu_context():
"""Returns the TPUReplicateContext, which exists inside a tpu.rewrite()."""
graph = ops.get_default_graph()
while graph is not None:
# pylint: disable=protected-access
context_ = graph._get_control_flow_context()
# pylint: enable=protected-access
while context_ is not None:
if isinstance(context_, tpu.TPUReplicateContext):
return context_
context_ = context_.outer_context
# This may be a FuncGraph due to defuns or v2 control flow. We need to
# find the original graph with the XLAControlFlowContext.
graph = getattr(graph, "outer_graph", None)
return None
def is_distributed_variable(v):
"""Determine if a variable is ds variable or TPU mirrored variable."""
return isinstance(v, DistributedVariable)
class TPUMirroredVariable(TPUVariableMixin, MirroredVariable):
"""Holds a map from replica to TPU variables whose values are kept in sync."""
def _assign_func(self, *args, **kwargs):
with _enter_or_assert_strategy(self._distribute_strategy):
if (distribution_strategy_context.in_cross_replica_context() and
(_enclosing_tpu_context() is not None)):
f = kwargs.pop("f")
return self._distribute_strategy.extended.update(
self, f, args=args, kwargs=kwargs)
else:
return MirroredVariable._assign_func(self, *args, **kwargs)
def assign_sub(self, *args, **kwargs):
assign_sub_fn = _make_raw_assign_fn(
gen_resource_variable_ops.assign_sub_variable_op)
return self._assign_func(f=assign_sub_fn, *args, **kwargs)
def assign_add(self, *args, **kwargs):
assign_add_fn = _make_raw_assign_fn(
gen_resource_variable_ops.assign_add_variable_op)
return self._assign_func(f=assign_add_fn, *args, **kwargs)
def assign(self, *args, **kwargs):
assign_fn = _make_raw_assign_fn(
gen_resource_variable_ops.assign_variable_op)
return self._assign_func(f=assign_fn, *args, **kwargs)
def _is_mirrored(self):
return True
class _SyncOnReadSaveable(saveable_object.SaveableObject):
"""Class for defining how to restore a SyncOnReadVariable."""
@ -1094,7 +892,7 @@ class _SyncOnReadSaveable(saveable_object.SaveableObject):
def _assert_replica_context(strategy):
replica_context = distribution_strategy_context.get_replica_context()
replica_context = ds_context.get_replica_context()
if not replica_context:
raise RuntimeError(
"Replica-local variables may only be assigned in a replica context.")
@ -1111,8 +909,8 @@ class SyncOnReadVariable(DistributedVariable):
self._aggregation = aggregation
def assign_sub(self, *args, **kwargs):
with _enter_or_assert_strategy(self._distribute_strategy):
if distribution_strategy_context.in_cross_replica_context():
with ds_context.enter_or_assert_strategy(self._distribute_strategy):
if ds_context.in_cross_replica_context():
if self._aggregation == vs.VariableAggregation.SUM:
raise ValueError(
"SyncOnReadVariable does not support `assign_sub` in "
@ -1126,8 +924,8 @@ class SyncOnReadVariable(DistributedVariable):
return self._get().assign_sub(*args, **kwargs)
def assign_add(self, *args, **kwargs):
with _enter_or_assert_strategy(self._distribute_strategy):
if distribution_strategy_context.in_cross_replica_context():
with ds_context.enter_or_assert_strategy(self._distribute_strategy):
if ds_context.in_cross_replica_context():
if self._aggregation == vs.VariableAggregation.SUM:
raise ValueError(
"SyncOnReadVariable does not support `assign_add` in "
@ -1141,8 +939,8 @@ class SyncOnReadVariable(DistributedVariable):
return self._get().assign_add(*args, **kwargs)
def assign(self, *args, **kwargs):
with _enter_or_assert_strategy(self._distribute_strategy):
if distribution_strategy_context.in_cross_replica_context():
with ds_context.enter_or_assert_strategy(self._distribute_strategy):
if ds_context.in_cross_replica_context():
# To preserve the sum across save and restore, we have to divide the
# total across all devices when restoring a variable that was summed
# when saving.
@ -1155,8 +953,8 @@ class SyncOnReadVariable(DistributedVariable):
return self._get().assign(*args, **kwargs)
def value(self):
with _enter_or_assert_strategy(self._distribute_strategy):
if distribution_strategy_context.in_cross_replica_context():
with ds_context.enter_or_assert_strategy(self._distribute_strategy):
if ds_context.in_cross_replica_context():
return self._get_cross_replica()
else:
# _get_closest() returns a Variable.
@ -1177,7 +975,7 @@ class SyncOnReadVariable(DistributedVariable):
if self._aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA:
return self._primary
with _enter_or_assert_strategy(self._distribute_strategy):
with ds_context.enter_or_assert_strategy(self._distribute_strategy):
return self._distribute_strategy.reduce(
reduce_util.ReduceOp.from_variable_aggregation(self.aggregation),
self,
@ -1185,8 +983,8 @@ class SyncOnReadVariable(DistributedVariable):
def _as_graph_element(self):
# pylint: disable=protected-access
with _enter_or_assert_strategy(self._distribute_strategy):
if distribution_strategy_context.in_cross_replica_context():
with ds_context.enter_or_assert_strategy(self._distribute_strategy):
if ds_context.in_cross_replica_context():
return ops.convert_to_tensor(self._get_cross_replica())
return self._get()._as_graph_element()
@ -1207,7 +1005,7 @@ class SyncOnReadVariable(DistributedVariable):
def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
"""Converts a variable to a tensor."""
with _enter_or_assert_strategy(self._distribute_strategy):
with ds_context.enter_or_assert_strategy(self._distribute_strategy):
return ops.convert_to_tensor(
self._get(), dtype=dtype, name=name, as_ref=as_ref)
@ -1222,36 +1020,6 @@ ops.register_tensor_conversion_function(SyncOnReadVariable,
_tensor_conversion_sync_on_read)
class TPUSyncOnReadVariable(TPUVariableMixin, SyncOnReadVariable):
"""Holds a map from replica to variables whose values are reduced on save."""
def assign_sub(self, *args, **kwargs):
if _enclosing_tpu_context() is None:
return SyncOnReadVariable.assign_sub(self, *args, **kwargs)
else:
return _make_raw_assign_fn(
gen_resource_variable_ops.assign_sub_variable_op)(self, *args,
**kwargs)
def assign_add(self, *args, **kwargs):
if _enclosing_tpu_context() is None:
return SyncOnReadVariable.assign_add(self, *args, **kwargs)
else:
return _make_raw_assign_fn(
gen_resource_variable_ops.assign_add_variable_op)(self, *args,
**kwargs)
def assign(self, *args, **kwargs):
if _enclosing_tpu_context() is None:
return SyncOnReadVariable.assign(self, *args, **kwargs)
else:
return _make_raw_assign_fn(gen_resource_variable_ops.assign_variable_op)(
self, *args, **kwargs)
def _is_mirrored(self):
return False
def regroup(values, wrap_class=PerReplica):
"""Makes a nest per-replica into a nest of PerReplica/Mirrored values."""
v0 = values[0]
@ -1444,9 +1212,9 @@ class AggregatingVariable(variables_lib.Variable):
return getattr(self._v, name)
def _assign_func(self, *args, **kwargs):
with _enter_or_assert_strategy(self._distribute_strategy):
with ds_context.enter_or_assert_strategy(self._distribute_strategy):
f = kwargs.pop("f")
if distribution_strategy_context.in_cross_replica_context():
if ds_context.in_cross_replica_context():
if distribute_lib.get_update_replica_id() is not None:
# We are calling an assign function in an update context.
return f(self._v, *args, **kwargs)
@ -1456,7 +1224,7 @@ class AggregatingVariable(variables_lib.Variable):
return self._distribute_strategy.extended.update(
self, f, args=args, kwargs=kwargs)
else:
replica_context = distribution_strategy_context.get_replica_context()
replica_context = ds_context.get_replica_context()
assert replica_context
# We are calling an assign function in replica context.
# We reduce the value we want to assign/add/sub. More details about how

View File

@ -29,6 +29,7 @@ from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.distribute import tpu_strategy
from tensorflow.python.distribute import tpu_values
from tensorflow.python.distribute import values
from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver
from tensorflow.python.eager import context
@ -824,7 +825,7 @@ def _make_replica_local(method, strategy=None):
name=n, initializer=init, use_resource=True))
if (strategy is not None) and isinstance(strategy, _TPU_STRATEGIES):
var_cls = values.TPUSyncOnReadVariable
var_cls = tpu_values.TPUSyncOnReadVariable
else:
var_cls = values.SyncOnReadVariable
replica_local = var_cls(strategy, v, method)

View File

@ -28,6 +28,7 @@ tf_py_test(
"//tensorflow/python:client_testlib",
"//tensorflow/python:variables",
"//tensorflow/python/compat:v2_compat",
"//tensorflow/python/distribute:tpu_values",
"@absl_py//absl/testing:parameterized",
],
)

View File

@ -26,6 +26,7 @@ from absl.testing import parameterized
import six
from tensorflow.python import tf2
from tensorflow.python.distribute import tpu_values
from tensorflow.python.distribute import values as distributed_values
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
@ -249,10 +250,8 @@ class VariableTrackingTest(test_util.TensorFlowTestCase):
def test_supports_distributed_variables(self):
mirrored = distributed_values.MirroredVariable(
None, [variables.Variable(1.)], variables.VariableAggregation.SUM)
tpu = distributed_values.TPUMirroredVariable(
strategy=None,
values=[variables.Variable(42.)],
aggregation=None)
tpu = tpu_values.TPUMirroredVariable(
strategy=None, values=[variables.Variable(42.)], aggregation=None)
aggregating = distributed_values.AggregatingVariable(
strategy=None, v=variables.Variable(1.), aggregation=None)