Separate TPU variables into its own file
TPU code pulls in tensorflow.python.tpu.tpu, which in turn pulls in a large number of dependencies, so it is easy to create cyclic dependencies. Moving the TPU-specific distributed-value classes into their own module keeps those dependencies out of values.py.

PiperOrigin-RevId: 297060171
Change-Id: I8e09e7b8de9fc50e892502a480de8fffa378a1c5
parent 2e81bc66c5
commit db992d30bb
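A minimal sketch (not part of the commit) of the dependency direction this change establishes: values.py stops importing tensorflow.python.tpu.tpu, the new tensorflow/python/distribute/tpu_values.py is the only distribute module that still does, and call sites switch from the private values._enclosing_tpu_context() to the public tpu_values.enclosing_tpu_context(). The helper name in_tpu_rewrite below is hypothetical; only the module path and the function it calls come from the diff.

# Illustrative sketch of the post-change layering; `in_tpu_rewrite` is a
# hypothetical helper, not something added by this commit.
from tensorflow.python.distribute import tpu_values


def in_tpu_rewrite():
  """Returns True when called inside a tpu.rewrite()/TPUReplicateContext."""
  # tpu_values is the only distribute module importing tensorflow.python.tpu,
  # so importing values.py alone no longer drags in the TPU stack.
  return tpu_values.enclosing_tpu_context() is not None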
tensorflow/python
@@ -2882,6 +2882,7 @@ tf_gen_op_wrapper_private_py(
    name = "resource_variable_ops_gen",
    visibility = [
        "//tensorflow/compiler/tf2xla:internal",
        "//tensorflow/python/distribute:__pkg__",
    ],
)

@@ -66,6 +66,7 @@ py_library(
        ":cross_device_utils",
        ":device_util",
        ":reduce_util",
        ":tpu_values",
        ":values",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:device_lib",

@@ -531,6 +532,7 @@ py_library(
        ":input_lib",
        ":numpy_dataset",
        ":reduce_util",
        ":tpu_values",
        ":values",
        "//tensorflow/compiler/xla/experimental/xla_sharding",
        "//tensorflow/python:array_ops",

@@ -612,18 +614,36 @@ py_library(
    deps = [
        ":device_util",
        ":distribute_lib",
        ":reduce_util",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:composite_tensor",
        "//tensorflow/python:control_flow_ops",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:resource_variable_ops",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:tensor_util",
        "//tensorflow/python:type_spec",
        "//tensorflow/python:util",
        "//tensorflow/python:variable_scope",
        "//tensorflow/python:variables",
        "//tensorflow/python/eager:context",
        "//tensorflow/python/tpu:tpu_lib",
        "//tensorflow/python/eager:tape",
        "//tensorflow/python/training/saving:saveable_object",
        "//tensorflow/python/training/saving:saveable_object_util",
        "//tensorflow/python/training/tracking:base",
        "@six_archive//:six",
    ],
)

py_library(
    name = "tpu_values",
    srcs = ["tpu_values.py"],
    deps = [
        ":values",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:resource_variable_ops_gen",
        "//tensorflow/python/eager:context",
        "//tensorflow/python/eager:tape",
        "//tensorflow/python/tpu:tpu_lib",
    ],
)
@@ -27,6 +27,7 @@ from tensorflow.python.client import device_lib
from tensorflow.python.distribute import cross_device_utils
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import reduce_util
+from tensorflow.python.distribute import tpu_values
from tensorflow.python.distribute import values as value_lib
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function

@@ -62,8 +63,8 @@ def validate_destinations(destinations):
  if not isinstance(
      destinations,
      (value_lib.DistributedValues, ops.Tensor, value_lib.AggregatingVariable,
-       six.string_types, value_lib.TPUMirroredVariable)
-  ) and not resource_variable_ops.is_resource_variable(destinations):
+       six.string_types, tpu_values.TPUMirroredVariable
+      )) and not resource_variable_ops.is_resource_variable(destinations):
    raise ValueError("destinations must be one of a `DistributedValues` object,"
                     " a tf.Variable object, or a device string.")
@@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

+import contextlib
import threading

from tensorflow.python.framework import ops

@@ -266,6 +267,20 @@ def experimental_set_strategy(strategy):
  ops.get_default_graph()._global_distribute_strategy_scope = new_scope  # pylint: disable=protected-access


+# ------------------------------------------------------------------------------
+# Internal helpers.
+
+
+@contextlib.contextmanager
+def enter_or_assert_strategy(strategy):
+  if not has_strategy():
+    with strategy.scope():
+      yield
+  else:
+    _assert_strategy(strategy)
+    yield
+
+
# ------------------------------------------------------------------------------
# Defaults that are used when no tf.distribute.Strategy is explicitly created.
# We create them lazily in a function so that we can workaround the circular

@@ -284,6 +299,17 @@ _default_replica_context_lock = threading.Lock()
_default_replica_mode_lock = threading.Lock()


+def _assert_strategy(strategy):
+  if not has_strategy():
+    raise RuntimeError('Need to be inside "with strategy.scope()" for %s' %
+                       (strategy,))
+  current_strategy = get_strategy()
+  if current_strategy is not strategy:
+    raise RuntimeError(
+        "Mixing different tf.distribute.Strategy objects: %s is not %s" %
+        (current_strategy, strategy))
+
+
def _get_default_strategy():
  if _defaults["strategy"] is None:
    # Avoid race condition causing two defaults to be created
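enter_or_assert_strategy above is the newly public replacement for the private _enter_or_assert_strategy helper that previously lived in values.py. Below is a minimal usage sketch, assuming a TensorFlow build where this internal module path exists; read_in_strategy is a hypothetical helper, not part of the diff.

# Hypothetical usage sketch; only the import path and enter_or_assert_strategy
# itself come from this change.
from tensorflow.python.distribute import distribution_strategy_context as ds_context


def read_in_strategy(var, strategy):
  # Enters strategy.scope() if no strategy is currently active; otherwise
  # asserts that the active strategy is the same object before proceeding.
  with ds_context.enter_or_assert_strategy(strategy):
    return var.read_value()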
@@ -34,6 +34,7 @@ from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import input_lib
from tensorflow.python.distribute import numpy_dataset
from tensorflow.python.distribute import reduce_util
+from tensorflow.python.distribute import tpu_values
from tensorflow.python.distribute import values
from tensorflow.python.distribute.cluster_resolver import TPUClusterResolver
from tensorflow.python.eager import context

@@ -543,7 +544,7 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):

    self._logical_device_stack.append(logical_device_id)
    try:
-      if values._enclosing_tpu_context() is None:  # pylint: disable=protected-access
+      if tpu_values.enclosing_tpu_context() is None:
        yield
      else:
        with ops.device(tpu.core(logical_device_id)):

@@ -648,20 +649,20 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
          with context.device_policy(context.DEVICE_PLACEMENT_SILENT):
            v = next_creator(**kwargs)

-          assert not isinstance(v, values.TPUMirroredVariable)
+          assert not isinstance(v, tpu_values.TPUMirroredVariable)
          value_list.append(v)
      return value_list

    return values.create_mirrored_variable(self._container_strategy(),
                                           _real_mirrored_creator,
-                                           values.TPUMirroredVariable,
-                                           values.TPUSyncOnReadVariable,
+                                           tpu_values.TPUMirroredVariable,
+                                           tpu_values.TPUSyncOnReadVariable,
                                           **kwargs)

  def _reduce_to(self, reduce_op, value, destinations):
    if (isinstance(value, values.DistributedValues) or
        tensor_util.is_tensor(value)
-       ) and values._enclosing_tpu_context() is not None:  # pylint: disable=protected-access
+       ) and tpu_values.enclosing_tpu_context() is not None:
      if reduce_op == reduce_util.ReduceOp.MEAN:
        # TODO(jhseu): Revisit once we support model-parallelism.
        value *= (1. / self._num_replicas_in_sync)

@@ -701,9 +702,9 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
    return output

  def _update(self, var, fn, args, kwargs, group):
-    assert isinstance(var, values.TPUVariableMixin) or isinstance(
+    assert isinstance(var, tpu_values.TPUVariableMixin) or isinstance(
        var, resource_variable_ops.BaseResourceVariable)
-    if values._enclosing_tpu_context() is not None:  # pylint: disable=protected-access
+    if tpu_values.enclosing_tpu_context() is not None:
      if group:
        return fn(var, *args, **kwargs)
      else:

@@ -724,7 +725,7 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
    return values.update_regroup(self, updates, group)

  def read_var(self, var):
-    assert isinstance(var, values.TPUVariableMixin) or isinstance(
+    assert isinstance(var, tpu_values.TPUVariableMixin) or isinstance(
        var, resource_variable_ops.BaseResourceVariable)
    return var.read_value()

@@ -745,7 +746,7 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
    # since the `1` gets broadcast as an int32 but global_step is int64.
    if isinstance(tensor, (float, int)):
      return tensor
-    if values._enclosing_tpu_context() is not None:  # pylint: disable=protected-access
+    if tpu_values.enclosing_tpu_context() is not None:
      broadcast_tensor = [tensor for _ in range(self._num_replicas_in_sync)]
      result = tpu_ops.all_to_all(
          broadcast_tensor,
tensorflow/python/distribute/tpu_values.py (new file, 245 lines)
@@ -0,0 +1,245 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Various classes representing TPU distributed values.

Note that the tests are in values_test.py .

"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import contextlib

from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.distribute import values
from tensorflow.python.eager import context
from tensorflow.python.eager import tape
from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.tpu import tpu


@contextlib.contextmanager
def _maybe_enter_graph(tensor):
  # Note: might have an eager tensor but not be executing eagerly when
  # building functions.
  if (context.executing_eagerly() or isinstance(tensor, ops.EagerTensor) or
      ops.has_default_graph()):
    yield
  else:
    with tensor.graph.as_default():
      yield


def _make_raw_assign_fn(raw_assign_fn):  # pylint: disable=missing-docstring

  def assign_fn(var, value, use_locking=False, name=None, read_value=True):  # pylint: disable=missing-docstring
    del use_locking  # Unused.

    with _maybe_enter_graph(var.handle):
      op = raw_assign_fn(
          var.handle, ops.convert_to_tensor(value, dtype=var.dtype), name=name)

      with ops.control_dependencies([op]):
        return var._read_variable_op() if read_value else op  # pylint: disable=protected-access

  return assign_fn


class TPUVariableMixin(object):
  """Mixin for TPU variables."""

  def __init__(self, *args, **kwargs):
    super(TPUVariableMixin, self).__init__(*args, **kwargs)

    # Handle ID is needed for `get_replicated_var_handle` to cache the variables
    # correctly since in eager mode different variables can have the same name.
    if ops.executing_eagerly_outside_functions():
      self._handle_id = self._common_name + "_" + str(id(self._primary))
    else:
      self._handle_id = self._common_name

  def __getattr__(self, name):
    if enclosing_tpu_context() is None:
      return super(TPUVariableMixin, self).__getattr__(name)
    else:
      raise AttributeError(
          "'{}' not accessible within a TPU context.".format(name))

  def get(self):
    if enclosing_tpu_context() is None:
      return super(TPUVariableMixin, self).get()
    else:
      raise NotImplementedError(
          "`TPUVariableMixin.get()` is not supported within a TPU context.")

  def _get_as_operand(self):
    return self.read_value()

  def _get_closest(self):
    if enclosing_tpu_context() is None:
      return super(TPUVariableMixin, self)._get_closest()
    else:
      return self._primary

  def numpy(self):
    if context.executing_eagerly():
      return self.read_value().numpy()
    else:
      raise NotImplementedError(
          "numpy() is only available when eager execution is enabled.")

  def _is_mirrored(self):
    raise NotImplementedError(
        "`TPUVariableMixin._is_mirrored()` must be implemented by subclasses.")

  @property
  def handle(self):
    # If we're in a tpu.rewrite(), return the replicated handle.
    tpu_context = enclosing_tpu_context()
    if tpu_context is None:
      return self._get_closest().handle
    else:
      return tpu_context.get_replicated_var_handle(self._handle_id,
                                                   self._values,
                                                   self._is_mirrored())

  @property
  def device(self):
    return self.handle.device

  def _read_variable_op(self):
    if self.trainable:
      tape.variable_accessed(self)
    return gen_resource_variable_ops.read_variable_op(self.handle, self.dtype)

  def read_value(self):
    if enclosing_tpu_context() is None:
      return super(TPUVariableMixin, self).read_value()
    else:
      return self._read_variable_op()

  def value(self):
    if enclosing_tpu_context() is None:
      return super(TPUVariableMixin, self).value()
    else:
      return self._read_variable_op()

  def _as_graph_element(self):
    if enclosing_tpu_context() is None:
      return super(TPUVariableMixin, self)._as_graph_element()  # pylint: disable=protected-access
    else:
      return None

  @property
  def op(self):
    return values.DistributedVarOp(self._primary.op.name,
                                   self._primary.op.graph,
                                   self._primary.op.traceback,
                                   self._primary.op.type)

  def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
    """Converts a variable to a tensor."""
    # pylint: disable=protected-access
    if enclosing_tpu_context() is None:
      return super(TPUVariableMixin, self)._dense_var_to_tensor(
          dtype=dtype, name=name, as_ref=as_ref)
    # pylint: enable=protected-access
    elif dtype is not None and dtype != self.dtype:
      return math_ops.cast(self.read_value(), dtype)
    else:
      return self.handle if as_ref else self.read_value()


def enclosing_tpu_context():
  """Returns the TPUReplicateContext, which exists inside a tpu.rewrite()."""
  graph = ops.get_default_graph()
  while graph is not None:
    # pylint: disable=protected-access
    context_ = graph._get_control_flow_context()
    # pylint: enable=protected-access
    while context_ is not None:
      if isinstance(context_, tpu.TPUReplicateContext):
        return context_
      context_ = context_.outer_context
    # This may be a FuncGraph due to defuns or v2 control flow. We need to
    # find the original graph with the XLAControlFlowContext.
    graph = getattr(graph, "outer_graph", None)
  return None
class TPUMirroredVariable(TPUVariableMixin, values.MirroredVariable):
  """Holds a map from replica to TPU variables whose values are kept in sync."""

  def _assign_func(self, *args, **kwargs):
    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
      if (ds_context.in_cross_replica_context() and
          (enclosing_tpu_context() is not None)):
        f = kwargs.pop("f")
        return self._distribute_strategy.extended.update(
            self, f, args=args, kwargs=kwargs)
      else:
        return values.MirroredVariable._assign_func(self, *args, **kwargs)

  def assign_sub(self, *args, **kwargs):
    assign_sub_fn = _make_raw_assign_fn(
        gen_resource_variable_ops.assign_sub_variable_op)
    return self._assign_func(f=assign_sub_fn, *args, **kwargs)

  def assign_add(self, *args, **kwargs):
    assign_add_fn = _make_raw_assign_fn(
        gen_resource_variable_ops.assign_add_variable_op)
    return self._assign_func(f=assign_add_fn, *args, **kwargs)

  def assign(self, *args, **kwargs):
    assign_fn = _make_raw_assign_fn(
        gen_resource_variable_ops.assign_variable_op)
    return self._assign_func(f=assign_fn, *args, **kwargs)

  def _is_mirrored(self):
    return True


class TPUSyncOnReadVariable(TPUVariableMixin, values.SyncOnReadVariable):
  """Holds a map from replica to variables whose values are reduced on save."""

  def assign_sub(self, *args, **kwargs):
    if enclosing_tpu_context() is None:
      return values.SyncOnReadVariable.assign_sub(self, *args, **kwargs)
    else:
      return _make_raw_assign_fn(
          gen_resource_variable_ops.assign_sub_variable_op)(self, *args,
                                                            **kwargs)

  def assign_add(self, *args, **kwargs):
    if enclosing_tpu_context() is None:
      return values.SyncOnReadVariable.assign_add(self, *args, **kwargs)
    else:
      return _make_raw_assign_fn(
          gen_resource_variable_ops.assign_add_variable_op)(self, *args,
                                                            **kwargs)

  def assign(self, *args, **kwargs):
    if enclosing_tpu_context() is None:
      return values.SyncOnReadVariable.assign(self, *args, **kwargs)
    else:
      return _make_raw_assign_fn(gen_resource_variable_ops.assign_variable_op)(
          self, *args, **kwargs)

  def _is_mirrored(self):
    return False
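The TPU variable classes above rely on cooperative multiple inheritance: TPUVariableMixin is listed before the values.* base class so its overrides win, and each override falls back through super() to the non-TPU implementation whenever enclosing_tpu_context() returns None. The self-contained sketch below illustrates that dispatch pattern with stand-in classes; it is plain Python, not TensorFlow API.

# Stand-in classes illustrating the mixin dispatch used by TPUVariableMixin.
class BaseVariable(object):

  def read_value(self):
    return "per-replica read"


class TPUMixin(object):

  def __init__(self, *args, **kwargs):
    super(TPUMixin, self).__init__(*args, **kwargs)
    self.in_tpu_context = False  # stand-in for enclosing_tpu_context() is not None

  def read_value(self):
    if not self.in_tpu_context:
      # Outside a TPU context, defer to the regular implementation.
      return super(TPUMixin, self).read_value()
    return "replicated-handle read"


class TPUVariable(TPUMixin, BaseVariable):
  pass


v = TPUVariable()
print(v.read_value())   # "per-replica read": dispatches to BaseVariable
v.in_tpu_context = True
print(v.read_value())   # "replicated-handle read": mixin override takes effect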
@@ -19,12 +19,11 @@ from __future__ import division
from __future__ import print_function

import collections
-import contextlib
import weakref

from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_lib
-from tensorflow.python.distribute import distribution_strategy_context
+from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.distribute import reduce_util
from tensorflow.python.eager import context
from tensorflow.python.eager import tape

@@ -34,11 +33,9 @@ from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops import variables as variables_lib
-from tensorflow.python.tpu import tpu
from tensorflow.python.training.saving import saveable_object
from tensorflow.python.training.saving import saveable_object_util
from tensorflow.python.training.tracking import base as trackable

@@ -48,7 +45,7 @@ from tensorflow.python.util.tf_export import tf_export

def _get_current_replica_id_as_int():
  """Returns the current replica ID as an integer, or `None`."""
-  replica_context = distribution_strategy_context.get_replica_context()
+  replica_context = ds_context.get_replica_context()
  if replica_context:
    replica_id = replica_context.replica_id_in_sync_group
    if not isinstance(replica_id, int):

@@ -362,7 +359,7 @@ class PerReplicaSpec(type_spec.TypeSpec):
    return self._value_specs

  def _to_components(self, value):
-    replica_context = distribution_strategy_context.get_replica_context()
+    replica_context = ds_context.get_replica_context()
    if replica_context is not None and replica_context.num_replicas_in_sync > 1:
      raise ValueError(
          "Flattening a PerReplica to components is not supported in replica "

@@ -405,27 +402,6 @@ def _assign_sub_on_device(device, variable, tensor):
  return variable.assign_sub(tensor)


-def _assert_strategy(strategy):
-  if not distribution_strategy_context.has_strategy():
-    raise RuntimeError('Need to be inside "with strategy.scope()" for %s' %
-                       (strategy,))
-  current_strategy = distribution_strategy_context.get_strategy()
-  if current_strategy is not strategy:
-    raise RuntimeError(
-        "Mixing different tf.distribute.Strategy objects: %s is not %s" %
-        (current_strategy, strategy))
-
-
-@contextlib.contextmanager
-def _enter_or_assert_strategy(strategy):
-  if not distribution_strategy_context.has_strategy():
-    with strategy.scope():
-      yield
-  else:
-    _assert_strategy(strategy)
-    yield
-
-
DistributedVarOp = collections.namedtuple(
    "DistributedVarOp", ["name", "graph", "traceback", "type"])
@@ -578,7 +554,7 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable):
    # We want cross-replica code that does some var.op.X calls
    # to work (even if the current device isn't in self._devices), but
    # other uses of var.op in a cross-replica context to fail.
-    if distribution_strategy_context.in_cross_replica_context():
+    if ds_context.in_cross_replica_context():
      return DistributedVarOp(self._primary.op.name, self._primary.op.graph,
                              self._primary.op.traceback, self._primary.op.type)
    return self._get().op

@@ -588,7 +564,7 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable):
    return self._primary._in_graph_mode  # pylint: disable=protected-access

  def read_value(self):
-    with _enter_or_assert_strategy(self._distribute_strategy):
+    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
      return array_ops.identity(self._get())

  def value(self):

@@ -602,135 +578,6 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable):
ops.register_dense_tensor_like_type(DistributedVariable)


-@contextlib.contextmanager
-def _maybe_enter_graph(tensor):
-  # Note: might have an eager tensor but not be executing eagerly when
-  # building functions.
-  if (context.executing_eagerly() or isinstance(tensor, ops.EagerTensor) or
-      ops.has_default_graph()):
-    yield
-  else:
-    with tensor.graph.as_default():
-      yield
-
-
-def _make_raw_assign_fn(raw_assign_fn):  # pylint: disable=missing-docstring
-
-  def assign_fn(var, value, use_locking=False, name=None, read_value=True):  # pylint: disable=missing-docstring
-    del use_locking  # Unused.
-
-    with _maybe_enter_graph(var.handle):
-      op = raw_assign_fn(
-          var.handle, ops.convert_to_tensor(value, dtype=var.dtype), name=name)
-
-      with ops.control_dependencies([op]):
-        return var._read_variable_op() if read_value else op  # pylint: disable=protected-access
-
-  return assign_fn
-
-
-class TPUVariableMixin(object):
-  """Mixin for TPU variables."""
-
-  def __init__(self, *args, **kwargs):
-    super(TPUVariableMixin, self).__init__(*args, **kwargs)
-
-    # Handle ID is needed for `get_replicated_var_handle` to cache the variables
-    # correctly since in eager mode different variables can have the same name.
-    if ops.executing_eagerly_outside_functions():
-      self._handle_id = self._common_name + "_" + str(id(self._primary))
-    else:
-      self._handle_id = self._common_name
-
-  def __getattr__(self, name):
-    if _enclosing_tpu_context() is None:
-      return super(TPUVariableMixin, self).__getattr__(name)
-    else:
-      raise AttributeError(
-          "'{}' not accessible within a TPU context.".format(name))
-
-  def get(self):
-    if _enclosing_tpu_context() is None:
-      return super(TPUVariableMixin, self).get()
-    else:
-      raise NotImplementedError(
-          "`TPUVariableMixin.get()` is not supported within a TPU context.")
-
-  def _get_as_operand(self):
-    return self.read_value()
-
-  def _get_closest(self):
-    if _enclosing_tpu_context() is None:
-      return super(TPUVariableMixin, self)._get_closest()
-    else:
-      return self._primary
-
-  def numpy(self):
-    if context.executing_eagerly():
-      return self.read_value().numpy()
-    else:
-      raise NotImplementedError(
-          "numpy() is only available when eager execution is enabled.")
-
-  def _is_mirrored(self):
-    raise NotImplementedError(
-        "`TPUVariableMixin._is_mirrored()` must be implemented by subclasses.")
-
-  @property
-  def handle(self):
-    # If we're in a tpu.rewrite(), return the replicated handle.
-    tpu_context = _enclosing_tpu_context()
-    if tpu_context is None:
-      return self._get_closest().handle
-    else:
-      return tpu_context.get_replicated_var_handle(
-          self._handle_id, self._values, self._is_mirrored())
-
-  @property
-  def device(self):
-    return self.handle.device
-
-  def _read_variable_op(self):
-    if self.trainable:
-      tape.variable_accessed(self)
-    return gen_resource_variable_ops.read_variable_op(self.handle, self.dtype)
-
-  def read_value(self):
-    if _enclosing_tpu_context() is None:
-      return super(TPUVariableMixin, self).read_value()
-    else:
-      return self._read_variable_op()
-
-  def value(self):
-    if _enclosing_tpu_context() is None:
-      return super(TPUVariableMixin, self).value()
-    else:
-      return self._read_variable_op()
-
-  def _as_graph_element(self):
-    if _enclosing_tpu_context() is None:
-      return super(TPUVariableMixin, self)._as_graph_element()  # pylint: disable=protected-access
-    else:
-      return None
-
-  @property
-  def op(self):
-    return DistributedVarOp(self._primary.op.name, self._primary.op.graph,
-                            self._primary.op.traceback, self._primary.op.type)
-
-  def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
-    """Converts a variable to a tensor."""
-    # pylint: disable=protected-access
-    if _enclosing_tpu_context() is None:
-      return super(TPUVariableMixin, self)._dense_var_to_tensor(
-          dtype=dtype, name=name, as_ref=as_ref)
-    # pylint: enable=protected-access
-    elif dtype is not None and dtype != self.dtype:
-      return math_ops.cast(self.read_value(), dtype)
-    else:
-      return self.handle if as_ref else self.read_value()
-
-
def _validate_colocate_extended(v, extended):
  variable_strategy = v._distribute_strategy  # pylint: disable=protected-access
  if variable_strategy.extended is not extended:
@@ -888,9 +735,9 @@ class MirroredVariable(DistributedVariable, Mirrored):
  # update_non_slot() function (like OptimizerV2._finish), which can
  # update several non-slot variables in one call.
  def _assign_func(self, *args, **kwargs):
-    with _enter_or_assert_strategy(self._distribute_strategy):
+    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
      f = kwargs.pop("f")
-      if distribution_strategy_context.in_cross_replica_context():
+      if ds_context.in_cross_replica_context():
        update_replica_id = distribute_lib.get_update_replica_id()
        if update_replica_id is not None:
          # We are calling an assign function on the mirrored variable in an

@@ -933,7 +780,7 @@ class MirroredVariable(DistributedVariable, Mirrored):
        return strategy.extended.update(
            self, f, args=(v,) + other_args, kwargs=other_kwargs)

-      return distribution_strategy_context.get_replica_context().merge_call(
+      return ds_context.get_replica_context().merge_call(
          merge_fn, args=args, kwargs=kwargs)

  def assign_sub(self, *args, **kwargs):

@@ -1003,60 +850,11 @@ ops.register_tensor_conversion_function(Mirrored,
                                        _tensor_conversion_mirrored_val)


-def _enclosing_tpu_context():
-  """Returns the TPUReplicateContext, which exists inside a tpu.rewrite()."""
-  graph = ops.get_default_graph()
-  while graph is not None:
-    # pylint: disable=protected-access
-    context_ = graph._get_control_flow_context()
-    # pylint: enable=protected-access
-    while context_ is not None:
-      if isinstance(context_, tpu.TPUReplicateContext):
-        return context_
-      context_ = context_.outer_context
-    # This may be a FuncGraph due to defuns or v2 control flow. We need to
-    # find the original graph with the XLAControlFlowContext.
-    graph = getattr(graph, "outer_graph", None)
-  return None
-
-
def is_distributed_variable(v):
  """Determine if a variable is ds variable or TPU mirrored variable."""
  return isinstance(v, DistributedVariable)


-class TPUMirroredVariable(TPUVariableMixin, MirroredVariable):
-  """Holds a map from replica to TPU variables whose values are kept in sync."""
-
-  def _assign_func(self, *args, **kwargs):
-    with _enter_or_assert_strategy(self._distribute_strategy):
-      if (distribution_strategy_context.in_cross_replica_context() and
-          (_enclosing_tpu_context() is not None)):
-        f = kwargs.pop("f")
-        return self._distribute_strategy.extended.update(
-            self, f, args=args, kwargs=kwargs)
-      else:
-        return MirroredVariable._assign_func(self, *args, **kwargs)
-
-  def assign_sub(self, *args, **kwargs):
-    assign_sub_fn = _make_raw_assign_fn(
-        gen_resource_variable_ops.assign_sub_variable_op)
-    return self._assign_func(f=assign_sub_fn, *args, **kwargs)
-
-  def assign_add(self, *args, **kwargs):
-    assign_add_fn = _make_raw_assign_fn(
-        gen_resource_variable_ops.assign_add_variable_op)
-    return self._assign_func(f=assign_add_fn, *args, **kwargs)
-
-  def assign(self, *args, **kwargs):
-    assign_fn = _make_raw_assign_fn(
-        gen_resource_variable_ops.assign_variable_op)
-    return self._assign_func(f=assign_fn, *args, **kwargs)
-
-  def _is_mirrored(self):
-    return True
-
-
class _SyncOnReadSaveable(saveable_object.SaveableObject):
  """Class for defining how to restore a SyncOnReadVariable."""

@@ -1094,7 +892,7 @@ class _SyncOnReadSaveable(saveable_object.SaveableObject):


def _assert_replica_context(strategy):
-  replica_context = distribution_strategy_context.get_replica_context()
+  replica_context = ds_context.get_replica_context()
  if not replica_context:
    raise RuntimeError(
        "Replica-local variables may only be assigned in a replica context.")

@@ -1111,8 +909,8 @@ class SyncOnReadVariable(DistributedVariable):
    self._aggregation = aggregation

  def assign_sub(self, *args, **kwargs):
-    with _enter_or_assert_strategy(self._distribute_strategy):
-      if distribution_strategy_context.in_cross_replica_context():
+    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
+      if ds_context.in_cross_replica_context():
        if self._aggregation == vs.VariableAggregation.SUM:
          raise ValueError(
              "SyncOnReadVariable does not support `assign_sub` in "

@@ -1126,8 +924,8 @@ class SyncOnReadVariable(DistributedVariable):
        return self._get().assign_sub(*args, **kwargs)

  def assign_add(self, *args, **kwargs):
-    with _enter_or_assert_strategy(self._distribute_strategy):
-      if distribution_strategy_context.in_cross_replica_context():
+    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
+      if ds_context.in_cross_replica_context():
        if self._aggregation == vs.VariableAggregation.SUM:
          raise ValueError(
              "SyncOnReadVariable does not support `assign_add` in "

@@ -1141,8 +939,8 @@ class SyncOnReadVariable(DistributedVariable):
        return self._get().assign_add(*args, **kwargs)

  def assign(self, *args, **kwargs):
-    with _enter_or_assert_strategy(self._distribute_strategy):
-      if distribution_strategy_context.in_cross_replica_context():
+    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
+      if ds_context.in_cross_replica_context():
        # To preserve the sum across save and restore, we have to divide the
        # total across all devices when restoring a variable that was summed
        # when saving.

@@ -1155,8 +953,8 @@ class SyncOnReadVariable(DistributedVariable):
        return self._get().assign(*args, **kwargs)

  def value(self):
-    with _enter_or_assert_strategy(self._distribute_strategy):
-      if distribution_strategy_context.in_cross_replica_context():
+    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
+      if ds_context.in_cross_replica_context():
        return self._get_cross_replica()
      else:
        # _get_closest() returns a Variable.

@@ -1177,7 +975,7 @@ class SyncOnReadVariable(DistributedVariable):
    if self._aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA:
      return self._primary

-    with _enter_or_assert_strategy(self._distribute_strategy):
+    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
      return self._distribute_strategy.reduce(
          reduce_util.ReduceOp.from_variable_aggregation(self.aggregation),
          self,

@@ -1185,8 +983,8 @@ class SyncOnReadVariable(DistributedVariable):

  def _as_graph_element(self):
    # pylint: disable=protected-access
-    with _enter_or_assert_strategy(self._distribute_strategy):
-      if distribution_strategy_context.in_cross_replica_context():
+    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
+      if ds_context.in_cross_replica_context():
        return ops.convert_to_tensor(self._get_cross_replica())
    return self._get()._as_graph_element()

@@ -1207,7 +1005,7 @@ class SyncOnReadVariable(DistributedVariable):

  def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
    """Converts a variable to a tensor."""
-    with _enter_or_assert_strategy(self._distribute_strategy):
+    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
      return ops.convert_to_tensor(
          self._get(), dtype=dtype, name=name, as_ref=as_ref)

@@ -1222,36 +1020,6 @@ ops.register_tensor_conversion_function(SyncOnReadVariable,
                                        _tensor_conversion_sync_on_read)


-class TPUSyncOnReadVariable(TPUVariableMixin, SyncOnReadVariable):
-  """Holds a map from replica to variables whose values are reduced on save."""
-
-  def assign_sub(self, *args, **kwargs):
-    if _enclosing_tpu_context() is None:
-      return SyncOnReadVariable.assign_sub(self, *args, **kwargs)
-    else:
-      return _make_raw_assign_fn(
-          gen_resource_variable_ops.assign_sub_variable_op)(self, *args,
-                                                            **kwargs)
-
-  def assign_add(self, *args, **kwargs):
-    if _enclosing_tpu_context() is None:
-      return SyncOnReadVariable.assign_add(self, *args, **kwargs)
-    else:
-      return _make_raw_assign_fn(
-          gen_resource_variable_ops.assign_add_variable_op)(self, *args,
-                                                            **kwargs)
-
-  def assign(self, *args, **kwargs):
-    if _enclosing_tpu_context() is None:
-      return SyncOnReadVariable.assign(self, *args, **kwargs)
-    else:
-      return _make_raw_assign_fn(gen_resource_variable_ops.assign_variable_op)(
-          self, *args, **kwargs)
-
-  def _is_mirrored(self):
-    return False
-
-
def regroup(values, wrap_class=PerReplica):
  """Makes a nest per-replica into a nest of PerReplica/Mirrored values."""
  v0 = values[0]

@@ -1444,9 +1212,9 @@ class AggregatingVariable(variables_lib.Variable):
    return getattr(self._v, name)

  def _assign_func(self, *args, **kwargs):
-    with _enter_or_assert_strategy(self._distribute_strategy):
+    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
      f = kwargs.pop("f")
-      if distribution_strategy_context.in_cross_replica_context():
+      if ds_context.in_cross_replica_context():
        if distribute_lib.get_update_replica_id() is not None:
          # We are calling an assign function in an update context.
          return f(self._v, *args, **kwargs)

@@ -1456,7 +1224,7 @@ class AggregatingVariable(variables_lib.Variable):
        return self._distribute_strategy.extended.update(
            self, f, args=args, kwargs=kwargs)
      else:
-        replica_context = distribution_strategy_context.get_replica_context()
+        replica_context = ds_context.get_replica_context()
        assert replica_context
        # We are calling an assign function in replica context.
        # We reduce the value we want to assign/add/sub. More details about how
@@ -29,6 +29,7 @@ from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.distribute import tpu_strategy
+from tensorflow.python.distribute import tpu_values
from tensorflow.python.distribute import values
from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver
from tensorflow.python.eager import context

@@ -824,7 +825,7 @@ def _make_replica_local(method, strategy=None):
          name=n, initializer=init, use_resource=True))

  if (strategy is not None) and isinstance(strategy, _TPU_STRATEGIES):
-    var_cls = values.TPUSyncOnReadVariable
+    var_cls = tpu_values.TPUSyncOnReadVariable
  else:
    var_cls = values.SyncOnReadVariable
  replica_local = var_cls(strategy, v, method)
@@ -28,6 +28,7 @@ tf_py_test(
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:variables",
        "//tensorflow/python/compat:v2_compat",
        "//tensorflow/python/distribute:tpu_values",
        "@absl_py//absl/testing:parameterized",
    ],
)
@@ -26,6 +26,7 @@ from absl.testing import parameterized
import six

from tensorflow.python import tf2
+from tensorflow.python.distribute import tpu_values
from tensorflow.python.distribute import values as distributed_values
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function

@@ -249,10 +250,8 @@ class VariableTrackingTest(test_util.TensorFlowTestCase):
  def test_supports_distributed_variables(self):
    mirrored = distributed_values.MirroredVariable(
        None, [variables.Variable(1.)], variables.VariableAggregation.SUM)
-    tpu = distributed_values.TPUMirroredVariable(
-        strategy=None,
-        values=[variables.Variable(42.)],
-        aggregation=None)
+    tpu = tpu_values.TPUMirroredVariable(
+        strategy=None, values=[variables.Variable(42.)], aggregation=None)
    aggregating = distributed_values.AggregatingVariable(
        strategy=None, v=variables.Variable(1.), aggregation=None)