Separate TPU variables into their own file

TPU code pulls in tensorflow.python.tpu.tpu, which in turn pulls in a large
number of dependencies, so it is easy to create cyclic dependencies.

PiperOrigin-RevId: 297060171
Change-Id: I8e09e7b8de9fc50e892502a480de8fffa378a1c5

Parent: 2e81bc66c5
Commit: db992d30bb
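In short, the TPU-specific distributed-value classes (TPUVariableMixin,
TPUMirroredVariable, TPUSyncOnReadVariable) and the enclosing_tpu_context()
helper move out of values.py into a new tpu_values.py, so only that module
depends on tensorflow.python.tpu.tpu. A minimal sketch of the resulting import
split, written against a TensorFlow build that includes this change
(illustrative usage, not an excerpt from the patch):

    # Generic distributed values: no longer pull in tensorflow.python.tpu.tpu.
    from tensorflow.python.distribute import values
    mirrored_cls = values.MirroredVariable

    # TPU-specific classes and helpers now live in their own module.
    from tensorflow.python.distribute import tpu_values
    tpu_mirrored_cls = tpu_values.TPUMirroredVariable
    tpu_ctx = tpu_values.enclosing_tpu_context()  # None outside tpu.rewrite()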
@@ -2882,6 +2882,7 @@ tf_gen_op_wrapper_private_py(
     name = "resource_variable_ops_gen",
     visibility = [
         "//tensorflow/compiler/tf2xla:internal",
+        "//tensorflow/python/distribute:__pkg__",
     ],
 )
 
@@ -66,6 +66,7 @@ py_library(
         ":cross_device_utils",
         ":device_util",
         ":reduce_util",
+        ":tpu_values",
         ":values",
         "//tensorflow/python:array_ops",
         "//tensorflow/python:device_lib",
@@ -531,6 +532,7 @@ py_library(
         ":input_lib",
         ":numpy_dataset",
         ":reduce_util",
+        ":tpu_values",
         ":values",
         "//tensorflow/compiler/xla/experimental/xla_sharding",
         "//tensorflow/python:array_ops",
@@ -612,18 +614,36 @@ py_library(
     deps = [
         ":device_util",
         ":distribute_lib",
+        ":reduce_util",
         "//tensorflow/python:array_ops",
+        "//tensorflow/python:composite_tensor",
         "//tensorflow/python:control_flow_ops",
-        "//tensorflow/python:dtypes",
         "//tensorflow/python:framework_ops",
-        "//tensorflow/python:resource_variable_ops",
+        "//tensorflow/python:math_ops",
+        "//tensorflow/python:tensor_util",
+        "//tensorflow/python:type_spec",
         "//tensorflow/python:util",
+        "//tensorflow/python:variable_scope",
+        "//tensorflow/python:variables",
         "//tensorflow/python/eager:context",
-        "//tensorflow/python/tpu:tpu_lib",
+        "//tensorflow/python/eager:tape",
         "//tensorflow/python/training/saving:saveable_object",
         "//tensorflow/python/training/saving:saveable_object_util",
         "//tensorflow/python/training/tracking:base",
-        "@six_archive//:six",
+    ],
+)
+
+py_library(
+    name = "tpu_values",
+    srcs = ["tpu_values.py"],
+    deps = [
+        ":values",
+        "//tensorflow/python:framework_ops",
+        "//tensorflow/python:math_ops",
+        "//tensorflow/python:resource_variable_ops_gen",
+        "//tensorflow/python/eager:context",
+        "//tensorflow/python/eager:tape",
+        "//tensorflow/python/tpu:tpu_lib",
     ],
 )
 
@@ -27,6 +27,7 @@ from tensorflow.python.client import device_lib
 from tensorflow.python.distribute import cross_device_utils
 from tensorflow.python.distribute import device_util
 from tensorflow.python.distribute import reduce_util
+from tensorflow.python.distribute import tpu_values
 from tensorflow.python.distribute import values as value_lib
 from tensorflow.python.eager import context
 from tensorflow.python.eager import def_function
@@ -62,8 +63,8 @@ def validate_destinations(destinations):
   if not isinstance(
       destinations,
       (value_lib.DistributedValues, ops.Tensor, value_lib.AggregatingVariable,
-       six.string_types, value_lib.TPUMirroredVariable)
-  ) and not resource_variable_ops.is_resource_variable(destinations):
+       six.string_types, tpu_values.TPUMirroredVariable
+      )) and not resource_variable_ops.is_resource_variable(destinations):
     raise ValueError("destinations must be one of a `DistributedValues` object,"
                      " a tf.Variable object, or a device string.")
 
@@ -18,6 +18,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import contextlib
 import threading
 
 from tensorflow.python.framework import ops
@@ -266,6 +267,20 @@ def experimental_set_strategy(strategy):
   ops.get_default_graph()._global_distribute_strategy_scope = new_scope  # pylint: disable=protected-access
 
 
+# ------------------------------------------------------------------------------
+# Internal helpers.
+
+
+@contextlib.contextmanager
+def enter_or_assert_strategy(strategy):
+  if not has_strategy():
+    with strategy.scope():
+      yield
+  else:
+    _assert_strategy(strategy)
+    yield
+
+
 # ------------------------------------------------------------------------------
 # Defaults that are used when no tf.distribute.Strategy is explicitly created.
 # We create them lazily in a function so that we can workaround the circular
@@ -284,6 +299,17 @@ _default_replica_context_lock = threading.Lock()
 _default_replica_mode_lock = threading.Lock()
 
 
+def _assert_strategy(strategy):
+  if not has_strategy():
+    raise RuntimeError('Need to be inside "with strategy.scope()" for %s' %
+                       (strategy,))
+  current_strategy = get_strategy()
+  if current_strategy is not strategy:
+    raise RuntimeError(
+        "Mixing different tf.distribute.Strategy objects: %s is not %s" %
+        (current_strategy, strategy))
+
+
 def _get_default_strategy():
   if _defaults["strategy"] is None:
     # Avoid race condition causing two defaults to be created
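For orientation, the two module-private helpers removed from values.py further
below (_assert_strategy and _enter_or_assert_strategy) now live here as
_assert_strategy and the public enter_or_assert_strategy. A small usage sketch,
assuming a TensorFlow build that includes this change (the MirroredStrategy
setup is illustrative, not part of the patch):

    import tensorflow as tf
    from tensorflow.python.distribute import distribution_strategy_context as ds_context

    strategy = tf.distribute.MirroredStrategy()

    # No strategy is active yet, so the helper enters strategy.scope() for us;
    # the variable below is therefore created under `strategy`.
    with ds_context.enter_or_assert_strategy(strategy):
      v = tf.Variable(1.0)

    # Inside a *different* strategy's scope the helper would instead raise
    # RuntimeError("Mixing different tf.distribute.Strategy objects ..."),
    # matching the check in _assert_strategy above.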
@@ -34,6 +34,7 @@ from tensorflow.python.distribute import distribute_lib
 from tensorflow.python.distribute import input_lib
 from tensorflow.python.distribute import numpy_dataset
 from tensorflow.python.distribute import reduce_util
+from tensorflow.python.distribute import tpu_values
 from tensorflow.python.distribute import values
 from tensorflow.python.distribute.cluster_resolver import TPUClusterResolver
 from tensorflow.python.eager import context
@@ -543,7 +544,7 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
 
     self._logical_device_stack.append(logical_device_id)
     try:
-      if values._enclosing_tpu_context() is None:  # pylint: disable=protected-access
+      if tpu_values.enclosing_tpu_context() is None:
         yield
       else:
         with ops.device(tpu.core(logical_device_id)):
@@ -648,20 +649,20 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
           with context.device_policy(context.DEVICE_PLACEMENT_SILENT):
             v = next_creator(**kwargs)
 
-          assert not isinstance(v, values.TPUMirroredVariable)
+          assert not isinstance(v, tpu_values.TPUMirroredVariable)
           value_list.append(v)
       return value_list
 
     return values.create_mirrored_variable(self._container_strategy(),
                                            _real_mirrored_creator,
-                                           values.TPUMirroredVariable,
-                                           values.TPUSyncOnReadVariable,
+                                           tpu_values.TPUMirroredVariable,
+                                           tpu_values.TPUSyncOnReadVariable,
                                            **kwargs)
 
   def _reduce_to(self, reduce_op, value, destinations):
     if (isinstance(value, values.DistributedValues) or
         tensor_util.is_tensor(value)
-       ) and values._enclosing_tpu_context() is not None:  # pylint: disable=protected-access
+       ) and tpu_values.enclosing_tpu_context() is not None:
       if reduce_op == reduce_util.ReduceOp.MEAN:
         # TODO(jhseu): Revisit once we support model-parallelism.
         value *= (1. / self._num_replicas_in_sync)
@@ -701,9 +702,9 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
     return output
 
   def _update(self, var, fn, args, kwargs, group):
-    assert isinstance(var, values.TPUVariableMixin) or isinstance(
+    assert isinstance(var, tpu_values.TPUVariableMixin) or isinstance(
         var, resource_variable_ops.BaseResourceVariable)
-    if values._enclosing_tpu_context() is not None:  # pylint: disable=protected-access
+    if tpu_values.enclosing_tpu_context() is not None:
       if group:
         return fn(var, *args, **kwargs)
       else:
@@ -724,7 +725,7 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
     return values.update_regroup(self, updates, group)
 
   def read_var(self, var):
-    assert isinstance(var, values.TPUVariableMixin) or isinstance(
+    assert isinstance(var, tpu_values.TPUVariableMixin) or isinstance(
        var, resource_variable_ops.BaseResourceVariable)
     return var.read_value()
 
@@ -745,7 +746,7 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
     # since the `1` gets broadcast as an int32 but global_step is int64.
     if isinstance(tensor, (float, int)):
       return tensor
-    if values._enclosing_tpu_context() is not None:  # pylint: disable=protected-access
+    if tpu_values.enclosing_tpu_context() is not None:
       broadcast_tensor = [tensor for _ in range(self._num_replicas_in_sync)]
       result = tpu_ops.all_to_all(
           broadcast_tensor,
tensorflow/python/distribute/tpu_values.py (new file, 245 lines)
@@ -0,0 +1,245 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Various classes representing TPU distributed values.
+
+Note that the tests are in values_test.py .
+
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import contextlib
+
+from tensorflow.python.distribute import distribution_strategy_context as ds_context
+from tensorflow.python.distribute import values
+from tensorflow.python.eager import context
+from tensorflow.python.eager import tape
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import gen_resource_variable_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.tpu import tpu
+
+
+@contextlib.contextmanager
+def _maybe_enter_graph(tensor):
+  # Note: might have an eager tensor but not be executing eagerly when
+  # building functions.
+  if (context.executing_eagerly() or isinstance(tensor, ops.EagerTensor) or
+      ops.has_default_graph()):
+    yield
+  else:
+    with tensor.graph.as_default():
+      yield
+
+
+def _make_raw_assign_fn(raw_assign_fn):  # pylint: disable=missing-docstring
+
+  def assign_fn(var, value, use_locking=False, name=None, read_value=True):  # pylint: disable=missing-docstring
+    del use_locking  # Unused.
+
+    with _maybe_enter_graph(var.handle):
+      op = raw_assign_fn(
+          var.handle, ops.convert_to_tensor(value, dtype=var.dtype), name=name)
+
+      with ops.control_dependencies([op]):
+        return var._read_variable_op() if read_value else op  # pylint: disable=protected-access
+
+  return assign_fn
+
+
+class TPUVariableMixin(object):
+  """Mixin for TPU variables."""
+
+  def __init__(self, *args, **kwargs):
+    super(TPUVariableMixin, self).__init__(*args, **kwargs)
+
+    # Handle ID is needed for `get_replicated_var_handle` to cache the variables
+    # correctly since in eager mode different variables can have the same name.
+    if ops.executing_eagerly_outside_functions():
+      self._handle_id = self._common_name + "_" + str(id(self._primary))
+    else:
+      self._handle_id = self._common_name
+
+  def __getattr__(self, name):
+    if enclosing_tpu_context() is None:
+      return super(TPUVariableMixin, self).__getattr__(name)
+    else:
+      raise AttributeError(
+          "'{}' not accessible within a TPU context.".format(name))
+
+  def get(self):
+    if enclosing_tpu_context() is None:
+      return super(TPUVariableMixin, self).get()
+    else:
+      raise NotImplementedError(
+          "`TPUVariableMixin.get()` is not supported within a TPU context.")
+
+  def _get_as_operand(self):
+    return self.read_value()
+
+  def _get_closest(self):
+    if enclosing_tpu_context() is None:
+      return super(TPUVariableMixin, self)._get_closest()
+    else:
+      return self._primary
+
+  def numpy(self):
+    if context.executing_eagerly():
+      return self.read_value().numpy()
+    else:
+      raise NotImplementedError(
+          "numpy() is only available when eager execution is enabled.")
+
+  def _is_mirrored(self):
+    raise NotImplementedError(
+        "`TPUVariableMixin._is_mirrored()` must be implemented by subclasses.")
+
+  @property
+  def handle(self):
+    # If we're in a tpu.rewrite(), return the replicated handle.
+    tpu_context = enclosing_tpu_context()
+    if tpu_context is None:
+      return self._get_closest().handle
+    else:
+      return tpu_context.get_replicated_var_handle(self._handle_id,
+                                                   self._values,
+                                                   self._is_mirrored())
+
+  @property
+  def device(self):
+    return self.handle.device
+
+  def _read_variable_op(self):
+    if self.trainable:
+      tape.variable_accessed(self)
+    return gen_resource_variable_ops.read_variable_op(self.handle, self.dtype)
+
+  def read_value(self):
+    if enclosing_tpu_context() is None:
+      return super(TPUVariableMixin, self).read_value()
+    else:
+      return self._read_variable_op()
+
+  def value(self):
+    if enclosing_tpu_context() is None:
+      return super(TPUVariableMixin, self).value()
+    else:
+      return self._read_variable_op()
+
+  def _as_graph_element(self):
+    if enclosing_tpu_context() is None:
+      return super(TPUVariableMixin, self)._as_graph_element()  # pylint: disable=protected-access
+    else:
+      return None
+
+  @property
+  def op(self):
+    return values.DistributedVarOp(self._primary.op.name,
+                                   self._primary.op.graph,
+                                   self._primary.op.traceback,
+                                   self._primary.op.type)
+
+  def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
+    """Converts a variable to a tensor."""
+    # pylint: disable=protected-access
+    if enclosing_tpu_context() is None:
+      return super(TPUVariableMixin, self)._dense_var_to_tensor(
+          dtype=dtype, name=name, as_ref=as_ref)
+    # pylint: enable=protected-access
+    elif dtype is not None and dtype != self.dtype:
+      return math_ops.cast(self.read_value(), dtype)
+    else:
+      return self.handle if as_ref else self.read_value()
+
+
+def enclosing_tpu_context():
+  """Returns the TPUReplicateContext, which exists inside a tpu.rewrite()."""
+  graph = ops.get_default_graph()
+  while graph is not None:
+    # pylint: disable=protected-access
+    context_ = graph._get_control_flow_context()
+    # pylint: enable=protected-access
+    while context_ is not None:
+      if isinstance(context_, tpu.TPUReplicateContext):
+        return context_
+      context_ = context_.outer_context
+    # This may be a FuncGraph due to defuns or v2 control flow. We need to
+    # find the original graph with the XLAControlFlowContext.
+    graph = getattr(graph, "outer_graph", None)
+  return None
+
+
+class TPUMirroredVariable(TPUVariableMixin, values.MirroredVariable):
+  """Holds a map from replica to TPU variables whose values are kept in sync."""
+
+  def _assign_func(self, *args, **kwargs):
+    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
+      if (ds_context.in_cross_replica_context() and
+          (enclosing_tpu_context() is not None)):
+        f = kwargs.pop("f")
+        return self._distribute_strategy.extended.update(
+            self, f, args=args, kwargs=kwargs)
+      else:
+        return values.MirroredVariable._assign_func(self, *args, **kwargs)
+
+  def assign_sub(self, *args, **kwargs):
+    assign_sub_fn = _make_raw_assign_fn(
+        gen_resource_variable_ops.assign_sub_variable_op)
+    return self._assign_func(f=assign_sub_fn, *args, **kwargs)
+
+  def assign_add(self, *args, **kwargs):
+    assign_add_fn = _make_raw_assign_fn(
+        gen_resource_variable_ops.assign_add_variable_op)
+    return self._assign_func(f=assign_add_fn, *args, **kwargs)
+
+  def assign(self, *args, **kwargs):
+    assign_fn = _make_raw_assign_fn(
+        gen_resource_variable_ops.assign_variable_op)
+    return self._assign_func(f=assign_fn, *args, **kwargs)
+
+  def _is_mirrored(self):
+    return True
+
+
+class TPUSyncOnReadVariable(TPUVariableMixin, values.SyncOnReadVariable):
+  """Holds a map from replica to variables whose values are reduced on save."""
+
+  def assign_sub(self, *args, **kwargs):
+    if enclosing_tpu_context() is None:
+      return values.SyncOnReadVariable.assign_sub(self, *args, **kwargs)
+    else:
+      return _make_raw_assign_fn(
+          gen_resource_variable_ops.assign_sub_variable_op)(self, *args,
+                                                            **kwargs)
+
+  def assign_add(self, *args, **kwargs):
+    if enclosing_tpu_context() is None:
+      return values.SyncOnReadVariable.assign_add(self, *args, **kwargs)
+    else:
+      return _make_raw_assign_fn(
+          gen_resource_variable_ops.assign_add_variable_op)(self, *args,
+                                                            **kwargs)
+
+  def assign(self, *args, **kwargs):
+    if enclosing_tpu_context() is None:
+      return values.SyncOnReadVariable.assign(self, *args, **kwargs)
+    else:
+      return _make_raw_assign_fn(gen_resource_variable_ops.assign_variable_op)(
+          self, *args, **kwargs)
+
+  def _is_mirrored(self):
+    return False
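The dispatch pattern above is uniform: every override checks
enclosing_tpu_context() and either defers to the regular values.* behaviour
(outside tpu.rewrite()/tpu.replicate()) or talks to the replicated handle
directly (inside it). A self-contained, plain-Python sketch of that
mixin-plus-super() dispatch, using stand-in names rather than TensorFlow APIs:

    _current_tpu_context = None  # stands in for tpu_values.enclosing_tpu_context()


    class _BaseDistributedVariable(object):  # stands in for values.MirroredVariable

      def read_value(self):
        return "per-replica read"


    class _TPUMixin(object):  # stands in for TPUVariableMixin

      def read_value(self):
        # Outside a TPU context, defer to the normal distributed behaviour;
        # inside one, read straight through the replicated handle instead.
        if _current_tpu_context is None:
          return super(_TPUMixin, self).read_value()
        return "replicated-handle read"


    class _TPUVariable(_TPUMixin, _BaseDistributedVariable):  # like TPUMirroredVariable
      pass


    print(_TPUVariable().read_value())  # -> "per-replica read"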
@@ -19,12 +19,11 @@ from __future__ import division
 from __future__ import print_function
 
 import collections
-import contextlib
 import weakref
 
 from tensorflow.python.distribute import device_util
 from tensorflow.python.distribute import distribute_lib
-from tensorflow.python.distribute import distribution_strategy_context
+from tensorflow.python.distribute import distribution_strategy_context as ds_context
 from tensorflow.python.distribute import reduce_util
 from tensorflow.python.eager import context
 from tensorflow.python.eager import tape
@@ -34,11 +33,9 @@ from tensorflow.python.framework import tensor_util
 from tensorflow.python.framework import type_spec
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import gen_resource_variable_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import variable_scope as vs
 from tensorflow.python.ops import variables as variables_lib
-from tensorflow.python.tpu import tpu
 from tensorflow.python.training.saving import saveable_object
 from tensorflow.python.training.saving import saveable_object_util
 from tensorflow.python.training.tracking import base as trackable
@@ -48,7 +45,7 @@ from tensorflow.python.util.tf_export import tf_export
 
 def _get_current_replica_id_as_int():
   """Returns the current replica ID as an integer, or `None`."""
-  replica_context = distribution_strategy_context.get_replica_context()
+  replica_context = ds_context.get_replica_context()
   if replica_context:
     replica_id = replica_context.replica_id_in_sync_group
     if not isinstance(replica_id, int):
@@ -362,7 +359,7 @@ class PerReplicaSpec(type_spec.TypeSpec):
     return self._value_specs
 
   def _to_components(self, value):
-    replica_context = distribution_strategy_context.get_replica_context()
+    replica_context = ds_context.get_replica_context()
     if replica_context is not None and replica_context.num_replicas_in_sync > 1:
       raise ValueError(
           "Flattening a PerReplica to components is not supported in replica "
@@ -405,27 +402,6 @@ def _assign_sub_on_device(device, variable, tensor):
   return variable.assign_sub(tensor)
 
 
-def _assert_strategy(strategy):
-  if not distribution_strategy_context.has_strategy():
-    raise RuntimeError('Need to be inside "with strategy.scope()" for %s' %
-                       (strategy,))
-  current_strategy = distribution_strategy_context.get_strategy()
-  if current_strategy is not strategy:
-    raise RuntimeError(
-        "Mixing different tf.distribute.Strategy objects: %s is not %s" %
-        (current_strategy, strategy))
-
-
-@contextlib.contextmanager
-def _enter_or_assert_strategy(strategy):
-  if not distribution_strategy_context.has_strategy():
-    with strategy.scope():
-      yield
-  else:
-    _assert_strategy(strategy)
-    yield
-
-
 DistributedVarOp = collections.namedtuple(
     "DistributedVarOp", ["name", "graph", "traceback", "type"])
 
@@ -578,7 +554,7 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable):
     # We want cross-replica code that does some var.op.X calls
     # to work (even if the current device isn't in self._devices), but
     # other uses of var.op in a cross-replica context to fail.
-    if distribution_strategy_context.in_cross_replica_context():
+    if ds_context.in_cross_replica_context():
       return DistributedVarOp(self._primary.op.name, self._primary.op.graph,
                               self._primary.op.traceback, self._primary.op.type)
     return self._get().op
@@ -588,7 +564,7 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable):
     return self._primary._in_graph_mode  # pylint: disable=protected-access
 
   def read_value(self):
-    with _enter_or_assert_strategy(self._distribute_strategy):
+    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
       return array_ops.identity(self._get())
 
   def value(self):
@@ -602,135 +578,6 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable):
 ops.register_dense_tensor_like_type(DistributedVariable)
 
 
-@contextlib.contextmanager
-def _maybe_enter_graph(tensor):
-  # Note: might have an eager tensor but not be executing eagerly when
-  # building functions.
-  if (context.executing_eagerly() or isinstance(tensor, ops.EagerTensor) or
-      ops.has_default_graph()):
-    yield
-  else:
-    with tensor.graph.as_default():
-      yield
-
-
-def _make_raw_assign_fn(raw_assign_fn):  # pylint: disable=missing-docstring
-
-  def assign_fn(var, value, use_locking=False, name=None, read_value=True):  # pylint: disable=missing-docstring
-    del use_locking  # Unused.
-
-    with _maybe_enter_graph(var.handle):
-      op = raw_assign_fn(
-          var.handle, ops.convert_to_tensor(value, dtype=var.dtype), name=name)
-
-      with ops.control_dependencies([op]):
-        return var._read_variable_op() if read_value else op  # pylint: disable=protected-access
-
-  return assign_fn
-
-
-class TPUVariableMixin(object):
-  """Mixin for TPU variables."""
-
-  def __init__(self, *args, **kwargs):
-    super(TPUVariableMixin, self).__init__(*args, **kwargs)
-
-    # Handle ID is needed for `get_replicated_var_handle` to cache the variables
-    # correctly since in eager mode different variables can have the same name.
-    if ops.executing_eagerly_outside_functions():
-      self._handle_id = self._common_name + "_" + str(id(self._primary))
-    else:
-      self._handle_id = self._common_name
-
-  def __getattr__(self, name):
-    if _enclosing_tpu_context() is None:
-      return super(TPUVariableMixin, self).__getattr__(name)
-    else:
-      raise AttributeError(
-          "'{}' not accessible within a TPU context.".format(name))
-
-  def get(self):
-    if _enclosing_tpu_context() is None:
-      return super(TPUVariableMixin, self).get()
-    else:
-      raise NotImplementedError(
-          "`TPUVariableMixin.get()` is not supported within a TPU context.")
-
-  def _get_as_operand(self):
-    return self.read_value()
-
-  def _get_closest(self):
-    if _enclosing_tpu_context() is None:
-      return super(TPUVariableMixin, self)._get_closest()
-    else:
-      return self._primary
-
-  def numpy(self):
-    if context.executing_eagerly():
-      return self.read_value().numpy()
-    else:
-      raise NotImplementedError(
-          "numpy() is only available when eager execution is enabled.")
-
-  def _is_mirrored(self):
-    raise NotImplementedError(
-        "`TPUVariableMixin._is_mirrored()` must be implemented by subclasses.")
-
-  @property
-  def handle(self):
-    # If we're in a tpu.rewrite(), return the replicated handle.
-    tpu_context = _enclosing_tpu_context()
-    if tpu_context is None:
-      return self._get_closest().handle
-    else:
-      return tpu_context.get_replicated_var_handle(
-          self._handle_id, self._values, self._is_mirrored())
-
-  @property
-  def device(self):
-    return self.handle.device
-
-  def _read_variable_op(self):
-    if self.trainable:
-      tape.variable_accessed(self)
-    return gen_resource_variable_ops.read_variable_op(self.handle, self.dtype)
-
-  def read_value(self):
-    if _enclosing_tpu_context() is None:
-      return super(TPUVariableMixin, self).read_value()
-    else:
-      return self._read_variable_op()
-
-  def value(self):
-    if _enclosing_tpu_context() is None:
-      return super(TPUVariableMixin, self).value()
-    else:
-      return self._read_variable_op()
-
-  def _as_graph_element(self):
-    if _enclosing_tpu_context() is None:
-      return super(TPUVariableMixin, self)._as_graph_element()  # pylint: disable=protected-access
-    else:
-      return None
-
-  @property
-  def op(self):
-    return DistributedVarOp(self._primary.op.name, self._primary.op.graph,
-                            self._primary.op.traceback, self._primary.op.type)
-
-  def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
-    """Converts a variable to a tensor."""
-    # pylint: disable=protected-access
-    if _enclosing_tpu_context() is None:
-      return super(TPUVariableMixin, self)._dense_var_to_tensor(
-          dtype=dtype, name=name, as_ref=as_ref)
-    # pylint: enable=protected-access
-    elif dtype is not None and dtype != self.dtype:
-      return math_ops.cast(self.read_value(), dtype)
-    else:
-      return self.handle if as_ref else self.read_value()
-
-
 def _validate_colocate_extended(v, extended):
   variable_strategy = v._distribute_strategy  # pylint: disable=protected-access
   if variable_strategy.extended is not extended:
@@ -888,9 +735,9 @@ class MirroredVariable(DistributedVariable, Mirrored):
   # update_non_slot() function (like OptimizerV2._finish), which can
   # update several non-slot variables in one call.
   def _assign_func(self, *args, **kwargs):
-    with _enter_or_assert_strategy(self._distribute_strategy):
+    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
       f = kwargs.pop("f")
-      if distribution_strategy_context.in_cross_replica_context():
+      if ds_context.in_cross_replica_context():
         update_replica_id = distribute_lib.get_update_replica_id()
         if update_replica_id is not None:
           # We are calling an assign function on the mirrored variable in an
@@ -933,7 +780,7 @@ class MirroredVariable(DistributedVariable, Mirrored):
         return strategy.extended.update(
             self, f, args=(v,) + other_args, kwargs=other_kwargs)
 
-      return distribution_strategy_context.get_replica_context().merge_call(
+      return ds_context.get_replica_context().merge_call(
          merge_fn, args=args, kwargs=kwargs)
 
   def assign_sub(self, *args, **kwargs):
@@ -1003,60 +850,11 @@ ops.register_tensor_conversion_function(Mirrored,
                                         _tensor_conversion_mirrored_val)
 
 
-def _enclosing_tpu_context():
-  """Returns the TPUReplicateContext, which exists inside a tpu.rewrite()."""
-  graph = ops.get_default_graph()
-  while graph is not None:
-    # pylint: disable=protected-access
-    context_ = graph._get_control_flow_context()
-    # pylint: enable=protected-access
-    while context_ is not None:
-      if isinstance(context_, tpu.TPUReplicateContext):
-        return context_
-      context_ = context_.outer_context
-    # This may be a FuncGraph due to defuns or v2 control flow. We need to
-    # find the original graph with the XLAControlFlowContext.
-    graph = getattr(graph, "outer_graph", None)
-  return None
-
-
 def is_distributed_variable(v):
   """Determine if a variable is ds variable or TPU mirrored variable."""
   return isinstance(v, DistributedVariable)
 
 
-class TPUMirroredVariable(TPUVariableMixin, MirroredVariable):
-  """Holds a map from replica to TPU variables whose values are kept in sync."""
-
-  def _assign_func(self, *args, **kwargs):
-    with _enter_or_assert_strategy(self._distribute_strategy):
-      if (distribution_strategy_context.in_cross_replica_context() and
-          (_enclosing_tpu_context() is not None)):
-        f = kwargs.pop("f")
-        return self._distribute_strategy.extended.update(
-            self, f, args=args, kwargs=kwargs)
-      else:
-        return MirroredVariable._assign_func(self, *args, **kwargs)
-
-  def assign_sub(self, *args, **kwargs):
-    assign_sub_fn = _make_raw_assign_fn(
-        gen_resource_variable_ops.assign_sub_variable_op)
-    return self._assign_func(f=assign_sub_fn, *args, **kwargs)
-
-  def assign_add(self, *args, **kwargs):
-    assign_add_fn = _make_raw_assign_fn(
-        gen_resource_variable_ops.assign_add_variable_op)
-    return self._assign_func(f=assign_add_fn, *args, **kwargs)
-
-  def assign(self, *args, **kwargs):
-    assign_fn = _make_raw_assign_fn(
-        gen_resource_variable_ops.assign_variable_op)
-    return self._assign_func(f=assign_fn, *args, **kwargs)
-
-  def _is_mirrored(self):
-    return True
-
-
 class _SyncOnReadSaveable(saveable_object.SaveableObject):
   """Class for defining how to restore a SyncOnReadVariable."""
 
@@ -1094,7 +892,7 @@ class _SyncOnReadSaveable(saveable_object.SaveableObject):
 
 
 def _assert_replica_context(strategy):
-  replica_context = distribution_strategy_context.get_replica_context()
+  replica_context = ds_context.get_replica_context()
   if not replica_context:
     raise RuntimeError(
         "Replica-local variables may only be assigned in a replica context.")
@@ -1111,8 +909,8 @@ class SyncOnReadVariable(DistributedVariable):
     self._aggregation = aggregation
 
   def assign_sub(self, *args, **kwargs):
-    with _enter_or_assert_strategy(self._distribute_strategy):
-      if distribution_strategy_context.in_cross_replica_context():
+    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
+      if ds_context.in_cross_replica_context():
         if self._aggregation == vs.VariableAggregation.SUM:
           raise ValueError(
               "SyncOnReadVariable does not support `assign_sub` in "
@@ -1126,8 +924,8 @@ class SyncOnReadVariable(DistributedVariable):
       return self._get().assign_sub(*args, **kwargs)
 
   def assign_add(self, *args, **kwargs):
-    with _enter_or_assert_strategy(self._distribute_strategy):
-      if distribution_strategy_context.in_cross_replica_context():
+    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
+      if ds_context.in_cross_replica_context():
         if self._aggregation == vs.VariableAggregation.SUM:
           raise ValueError(
              "SyncOnReadVariable does not support `assign_add` in "
@@ -1141,8 +939,8 @@ class SyncOnReadVariable(DistributedVariable):
       return self._get().assign_add(*args, **kwargs)
 
   def assign(self, *args, **kwargs):
-    with _enter_or_assert_strategy(self._distribute_strategy):
-      if distribution_strategy_context.in_cross_replica_context():
+    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
+      if ds_context.in_cross_replica_context():
         # To preserve the sum across save and restore, we have to divide the
         # total across all devices when restoring a variable that was summed
         # when saving.
@@ -1155,8 +953,8 @@ class SyncOnReadVariable(DistributedVariable):
       return self._get().assign(*args, **kwargs)
 
   def value(self):
-    with _enter_or_assert_strategy(self._distribute_strategy):
-      if distribution_strategy_context.in_cross_replica_context():
+    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
+      if ds_context.in_cross_replica_context():
         return self._get_cross_replica()
       else:
         # _get_closest() returns a Variable.
@@ -1177,7 +975,7 @@ class SyncOnReadVariable(DistributedVariable):
     if self._aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA:
       return self._primary
 
-    with _enter_or_assert_strategy(self._distribute_strategy):
+    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
      return self._distribute_strategy.reduce(
           reduce_util.ReduceOp.from_variable_aggregation(self.aggregation),
           self,
@@ -1185,8 +983,8 @@ class SyncOnReadVariable(DistributedVariable):
 
   def _as_graph_element(self):
     # pylint: disable=protected-access
-    with _enter_or_assert_strategy(self._distribute_strategy):
-      if distribution_strategy_context.in_cross_replica_context():
+    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
+      if ds_context.in_cross_replica_context():
        return ops.convert_to_tensor(self._get_cross_replica())
     return self._get()._as_graph_element()
 
@@ -1207,7 +1005,7 @@ class SyncOnReadVariable(DistributedVariable):
 
   def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
     """Converts a variable to a tensor."""
-    with _enter_or_assert_strategy(self._distribute_strategy):
+    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
       return ops.convert_to_tensor(
           self._get(), dtype=dtype, name=name, as_ref=as_ref)
 
@@ -1222,36 +1020,6 @@ ops.register_tensor_conversion_function(SyncOnReadVariable,
                                         _tensor_conversion_sync_on_read)
 
 
-class TPUSyncOnReadVariable(TPUVariableMixin, SyncOnReadVariable):
-  """Holds a map from replica to variables whose values are reduced on save."""
-
-  def assign_sub(self, *args, **kwargs):
-    if _enclosing_tpu_context() is None:
-      return SyncOnReadVariable.assign_sub(self, *args, **kwargs)
-    else:
-      return _make_raw_assign_fn(
-          gen_resource_variable_ops.assign_sub_variable_op)(self, *args,
-                                                            **kwargs)
-
-  def assign_add(self, *args, **kwargs):
-    if _enclosing_tpu_context() is None:
-      return SyncOnReadVariable.assign_add(self, *args, **kwargs)
-    else:
-      return _make_raw_assign_fn(
-          gen_resource_variable_ops.assign_add_variable_op)(self, *args,
-                                                            **kwargs)
-
-  def assign(self, *args, **kwargs):
-    if _enclosing_tpu_context() is None:
-      return SyncOnReadVariable.assign(self, *args, **kwargs)
-    else:
-      return _make_raw_assign_fn(gen_resource_variable_ops.assign_variable_op)(
-          self, *args, **kwargs)
-
-  def _is_mirrored(self):
-    return False
-
-
 def regroup(values, wrap_class=PerReplica):
   """Makes a nest per-replica into a nest of PerReplica/Mirrored values."""
   v0 = values[0]
@@ -1444,9 +1212,9 @@ class AggregatingVariable(variables_lib.Variable):
     return getattr(self._v, name)
 
   def _assign_func(self, *args, **kwargs):
-    with _enter_or_assert_strategy(self._distribute_strategy):
+    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
       f = kwargs.pop("f")
-      if distribution_strategy_context.in_cross_replica_context():
+      if ds_context.in_cross_replica_context():
        if distribute_lib.get_update_replica_id() is not None:
           # We are calling an assign function in an update context.
           return f(self._v, *args, **kwargs)
@@ -1456,7 +1224,7 @@ class AggregatingVariable(variables_lib.Variable):
         return self._distribute_strategy.extended.update(
             self, f, args=args, kwargs=kwargs)
       else:
-        replica_context = distribution_strategy_context.get_replica_context()
+        replica_context = ds_context.get_replica_context()
         assert replica_context
         # We are calling an assign function in replica context.
         # We reduce the value we want to assign/add/sub. More details about how
@@ -29,6 +29,7 @@ from tensorflow.python.distribute import distribute_lib
 from tensorflow.python.distribute import distribution_strategy_context
 from tensorflow.python.distribute import strategy_combinations
 from tensorflow.python.distribute import tpu_strategy
+from tensorflow.python.distribute import tpu_values
 from tensorflow.python.distribute import values
 from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver
 from tensorflow.python.eager import context
@@ -824,7 +825,7 @@ def _make_replica_local(method, strategy=None):
             name=n, initializer=init, use_resource=True))
 
   if (strategy is not None) and isinstance(strategy, _TPU_STRATEGIES):
-    var_cls = values.TPUSyncOnReadVariable
+    var_cls = tpu_values.TPUSyncOnReadVariable
   else:
     var_cls = values.SyncOnReadVariable
   replica_local = var_cls(strategy, v, method)
@@ -28,6 +28,7 @@ tf_py_test(
         "//tensorflow/python:client_testlib",
         "//tensorflow/python:variables",
         "//tensorflow/python/compat:v2_compat",
+        "//tensorflow/python/distribute:tpu_values",
         "@absl_py//absl/testing:parameterized",
     ],
 )
@@ -26,6 +26,7 @@ from absl.testing import parameterized
 import six
 
 from tensorflow.python import tf2
+from tensorflow.python.distribute import tpu_values
 from tensorflow.python.distribute import values as distributed_values
 from tensorflow.python.eager import context
 from tensorflow.python.eager import def_function
@@ -249,10 +250,8 @@ class VariableTrackingTest(test_util.TensorFlowTestCase):
   def test_supports_distributed_variables(self):
     mirrored = distributed_values.MirroredVariable(
         None, [variables.Variable(1.)], variables.VariableAggregation.SUM)
-    tpu = distributed_values.TPUMirroredVariable(
-        strategy=None,
-        values=[variables.Variable(42.)],
-        aggregation=None)
+    tpu = tpu_values.TPUMirroredVariable(
+        strategy=None, values=[variables.Variable(42.)], aggregation=None)
     aggregating = distributed_values.AggregatingVariable(
         strategy=None, v=variables.Variable(1.), aggregation=None)
 