Move enclosing_tpu_context to a separate util file

This is part of the variable refactor work to avoid dependency cycles.

PiperOrigin-RevId: 355414142
Change-Id: I36651a7be6462c198aae477923bc2ef0f7e7d0fb
Author:    Ran Chen (2021-02-03 09:32:15 -08:00)
Committer: TensorFlower Gardener
Parent:    0d7fe54134
Commit:    428ce93ee4

4 changed files with 125 additions and 87 deletions
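
For orientation, a sketch of the dependency shape this change establishes, reconstructed from the diffs below (illustrative only; arrows mean "depends on"):

# Before: enclosing_tpu_context() was defined in tpu_values, so any caller
# had to depend on the full TPU variable classes:
#
#     tpu_strategy --> tpu_values --> //tensorflow/python/tpu:tpu_py
#
# After: the helper lives in tpu_util, a small leaf target whose only deps
# are //tensorflow/python:framework_ops and //tensorflow/python/tpu:tpu_py,
# so future users can reach it without pulling in the variable classes or
# creating an import cycle with them:
#
#     tpu_strategy --> tpu_util <-- tpu_values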

tensorflow/python/distribute/BUILD

@@ -299,6 +299,15 @@ py_library(
     ],
 )
 
+py_library(
+    name = "tpu_util",
+    srcs = ["tpu_util.py"],
+    deps = [
+        "//tensorflow/python:framework_ops",
+        "//tensorflow/python/tpu:tpu_py",
+    ],
+)
+
 py_library(
     name = "mirrored_strategy",
     srcs = ["mirrored_strategy.py"],
@@ -579,6 +588,7 @@ py_library(
         ":input_lib",
         ":numpy_dataset",
         ":reduce_util",
+        ":tpu_util",
        ":tpu_values",
         ":values",
         "//tensorflow/compiler/xla/experimental/xla_sharding",
@@ -809,6 +819,7 @@ py_library(
     srcs_version = "PY3",
     deps = [
         ":packed_distributed_variable",
+        ":tpu_util",
         ":values",
         ":values_util",
         "//tensorflow/python:framework_ops",
@@ -816,7 +827,6 @@ py_library(
         "//tensorflow/python:resource_variable_ops_gen",
         "//tensorflow/python/eager:context",
         "//tensorflow/python/eager:tape",
-        "//tensorflow/python/tpu:tpu_py",
     ],
 )

tensorflow/python/distribute/tpu_strategy.py

@@ -38,6 +38,7 @@ from tensorflow.python.distribute import distribute_utils
 from tensorflow.python.distribute import input_lib
 from tensorflow.python.distribute import numpy_dataset
 from tensorflow.python.distribute import reduce_util
+from tensorflow.python.distribute import tpu_util
 from tensorflow.python.distribute import tpu_values
 from tensorflow.python.distribute import values
 from tensorflow.python.distribute.cluster_resolver import TPUClusterResolver
@@ -1099,7 +1100,7 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
     self._logical_device_stack.append(logical_device_id)
     try:
-      if tpu_values.enclosing_tpu_context() is None:
+      if tpu_util.enclosing_tpu_context() is None:
         yield
       else:
         with ops.device(tpu.core(logical_device_id)):
@@ -1213,7 +1214,7 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
   def _reduce_to(self, reduce_op, value, destinations, options):
     if (isinstance(value, values.DistributedValues) or
         tensor_util.is_tf_type(value)
-       ) and tpu_values.enclosing_tpu_context() is not None:
+       ) and tpu_util.enclosing_tpu_context() is not None:
       if reduce_op == reduce_util.ReduceOp.MEAN:
         # TODO(jhseu): Revisit once we support model-parallelism.
         value *= (1. / self._num_replicas_in_sync)
@@ -1260,7 +1261,7 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
   def _update(self, var, fn, args, kwargs, group):
     assert isinstance(var, tpu_values.TPUVariableMixin) or isinstance(
         var, resource_variable_ops.BaseResourceVariable)
-    if tpu_values.enclosing_tpu_context() is not None:
+    if tpu_util.enclosing_tpu_context() is not None:
       if group:
         return fn(var, *args, **kwargs)
       else:
@@ -1317,7 +1318,7 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
     # since the `1` gets broadcast as an int32 but global_step is int64.
     if isinstance(tensor, (float, int)):
       return tensor
-    if tpu_values.enclosing_tpu_context() is not None:
+    if tpu_util.enclosing_tpu_context() is not None:
       broadcast_tensor = [tensor for _ in range(self._num_replicas_in_sync)]
       result = tpu_ops.all_to_all(
           broadcast_tensor,
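
The four renamed call sites above share one gate: whether tracing is currently inside a tpu.rewrite(). A condensed sketch of that gate, paraphrasing _update above (simplified; the group handling and the per-replica dispatch path are elided):

def _update(self, var, fn, args, kwargs, group):
  if tpu_util.enclosing_tpu_context() is not None:
    # Inside tpu.rewrite(): the variable presents a single replicated
    # handle, so the update runs once and XLA replicates it across cores.
    return fn(var, *args, **kwargs)
  # Outside: fall back to updating each per-replica component variable.
  ...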

tensorflow/python/distribute/tpu_util.py (new file)

@@ -0,0 +1,35 @@
+# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utility functions for TPU."""
+
+from tensorflow.python.framework import ops
+from tensorflow.python.tpu import tpu
+
+
+def enclosing_tpu_context():
+  """Returns the TPUReplicateContext, which exists inside a tpu.rewrite()."""
+  graph = ops.get_default_graph()
+  while graph is not None:
+    # pylint: disable=protected-access
+    context_ = graph._get_control_flow_context()
+    # pylint: enable=protected-access
+    while context_ is not None:
+      if isinstance(context_, tpu.TPUReplicateContext):
+        return context_
+      context_ = context_.outer_context
+    # This may be a FuncGraph due to defuns or v2 control flow. We need to
+    # find the original graph with the XLAControlFlowContext.
+    graph = getattr(graph, "outer_graph", None)
+  return None
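
A minimal usage sketch of the relocated helper (not part of the commit): outside a tpu.rewrite() there is no TPUReplicateContext to find, while inside one the computation is traced under such a context, so the helper returns it.

from tensorflow.python.distribute import tpu_util
from tensorflow.python.framework import ops

with ops.Graph().as_default():
  # No control flow context has been entered, so the walk finds nothing.
  assert tpu_util.enclosing_tpu_context() is None

# Inside tpu.rewrite(computation), the traced body sees the context
# (actually running it requires TPU kernels):
#
#   def computation():
#     assert tpu_util.enclosing_tpu_context() is not None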

tensorflow/python/distribute/tpu_values.py

@@ -25,6 +25,7 @@ from __future__ import print_function
 import contextlib
 
 from tensorflow.python.distribute import packed_distributed_variable as packed
+from tensorflow.python.distribute import tpu_util
 from tensorflow.python.distribute import values
 from tensorflow.python.distribute import values_util
 from tensorflow.python.eager import context
@@ -33,7 +34,6 @@ from tensorflow.python.framework import ops
 from tensorflow.python.ops import gen_resource_variable_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import variable_scope
-from tensorflow.python.tpu import tpu
 
 
 @contextlib.contextmanager
@@ -66,9 +66,7 @@ def _make_raw_assign_fn(raw_assign_fn):  # pylint: disable=missing-docstring
     handle = var.handle
     with _maybe_enter_graph(handle), _maybe_on_device(var):
       op = raw_assign_fn(
-          handle,
-          ops.convert_to_tensor(value, dtype=var.dtype),
-          name=name)
+          handle, ops.convert_to_tensor(value, dtype=var.dtype), name=name)
       with ops.control_dependencies([op]):
         return var._read_variable_op() if read_value else op  # pylint: disable=protected-access
@@ -89,14 +87,14 @@ class TPUVariableMixin(object):
       self._handle_id = self._common_name
 
   def __getattr__(self, name):
-    if enclosing_tpu_context() is None:
+    if tpu_util.enclosing_tpu_context() is None:
       return super(TPUVariableMixin, self).__getattr__(name)
     else:
       raise AttributeError(
           "'{}' not accessible within a TPU context.".format(name))
 
   def get(self):
-    if enclosing_tpu_context() is None:
+    if tpu_util.enclosing_tpu_context() is None:
       return super(TPUVariableMixin, self).get()
     else:
       raise NotImplementedError(
@@ -113,7 +111,7 @@
   def handle(self):
     """The handle by which this variable can be accessed."""
     # If we're in a tpu.rewrite(), return the replicated handle.
-    tpu_context = enclosing_tpu_context()
+    tpu_context = tpu_util.enclosing_tpu_context()
     if tpu_context is None or context.executing_eagerly():
       var = self._get_on_device_or_primary()
       if isinstance(var, packed.PackedVarAndDevice):
@@ -148,19 +146,19 @@
     return gen_resource_variable_ops.read_variable_op(handle, self.dtype)
 
   def read_value(self):
-    if enclosing_tpu_context() is None:
+    if tpu_util.enclosing_tpu_context() is None:
       return super(TPUVariableMixin, self).read_value()
     else:
       return self._read_variable_op()
 
   def value(self):
-    if enclosing_tpu_context() is None:
+    if tpu_util.enclosing_tpu_context() is None:
       return super(TPUVariableMixin, self).value()
     else:
       return self._read_variable_op()
 
   def _as_graph_element(self):
-    if enclosing_tpu_context() is None:
+    if tpu_util.enclosing_tpu_context() is None:
       return super(TPUVariableMixin, self)._as_graph_element()  # pylint: disable=protected-access
     else:
       return None
@@ -177,7 +175,7 @@
   def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
     """Converts a variable to a tensor."""
     # pylint: disable=protected-access
-    if enclosing_tpu_context() is None:
+    if tpu_util.enclosing_tpu_context() is None:
       return super(TPUVariableMixin, self)._dense_var_to_tensor(
           dtype=dtype, name=name, as_ref=as_ref)
     # pylint: enable=protected-access
@@ -187,23 +185,6 @@
     return self.handle if as_ref else self.read_value()
 
 
-def enclosing_tpu_context():
-  """Returns the TPUReplicateContext, which exists inside a tpu.rewrite()."""
-  graph = ops.get_default_graph()
-  while graph is not None:
-    # pylint: disable=protected-access
-    context_ = graph._get_control_flow_context()
-    # pylint: enable=protected-access
-    while context_ is not None:
-      if isinstance(context_, tpu.TPUReplicateContext):
-        return context_
-      context_ = context_.outer_context
-    # This may be a FuncGraph due to defuns or v2 control flow. We need to
-    # find the original graph with the XLAControlFlowContext.
-    graph = getattr(graph, "outer_graph", None)
-  return None
-
-
 class TPUDistributedVariable(TPUVariableMixin, values.DistributedVariable):
   """DistributedVariable subclass for TPUStrategy."""
@@ -274,9 +255,8 @@
 class TPUMirroredVariable(TPUVariableMixin, values.MirroredVariable):
   """Holds a map from replica to TPU variables whose values are kept in sync."""
 
-  def assign_sub(self, value, use_locking=False, name=None,
-                 read_value=True):
-    if (enclosing_tpu_context() and
+  def assign_sub(self, value, use_locking=False, name=None, read_value=True):
+    if (tpu_util.enclosing_tpu_context() and
         self.aggregation == variable_scope.VariableAggregation.NONE):
       return _make_raw_assign_fn(
           gen_resource_variable_ops.assign_sub_variable_op)(
@@ -285,12 +265,11 @@ class TPUMirroredVariable(TPUVariableMixin, values.MirroredVariable):
               use_locking=use_locking,
               name=name,
               read_value=read_value)
-    return assign_sub(self, value, use_locking=use_locking, name=name,
-                      read_value=read_value)
+    return assign_sub(
+        self, value, use_locking=use_locking, name=name, read_value=read_value)
 
-  def assign_add(self, value, use_locking=False, name=None,
-                 read_value=True):
-    if (enclosing_tpu_context() and
+  def assign_add(self, value, use_locking=False, name=None, read_value=True):
+    if (tpu_util.enclosing_tpu_context() and
         self.aggregation == variable_scope.VariableAggregation.NONE):
       return _make_raw_assign_fn(
           gen_resource_variable_ops.assign_add_variable_op)(
@@ -299,21 +278,20 @@ class TPUMirroredVariable(TPUVariableMixin, values.MirroredVariable):
               use_locking=use_locking,
               name=name,
               read_value=read_value)
-    return assign_add(self, value, use_locking=use_locking, name=name,
-                      read_value=read_value)
+    return assign_add(
+        self, value, use_locking=use_locking, name=name, read_value=read_value)
 
   def assign(self, value, use_locking=False, name=None, read_value=True):
-    if (enclosing_tpu_context() and
+    if (tpu_util.enclosing_tpu_context() and
         self.aggregation == variable_scope.VariableAggregation.NONE):
-      return _make_raw_assign_fn(
-          gen_resource_variable_ops.assign_variable_op)(
-              self,
-              value=value,
-              use_locking=use_locking,
-              name=name,
-              read_value=read_value)
-    return assign(self, value, use_locking=use_locking, name=name,
-                  read_value=read_value)
+      return _make_raw_assign_fn(gen_resource_variable_ops.assign_variable_op)(
+          self,
+          value=value,
+          use_locking=use_locking,
+          name=name,
+          read_value=read_value)
+    return assign(
+        self, value, use_locking=use_locking, name=name, read_value=read_value)
 
   def scatter_sub(self, *args, **kwargs):
     if values_util.is_saving_non_distributed():
@@ -358,7 +336,7 @@ class TPUSyncOnReadVariable(TPUVariableMixin, values.SyncOnReadVariable):
   """Holds a map from replica to variables whose values are reduced on save."""
 
   def assign_sub(self, *args, **kwargs):
-    if enclosing_tpu_context() is None:
+    if tpu_util.enclosing_tpu_context() is None:
       return values.SyncOnReadVariable.assign_sub(self, *args, **kwargs)
     else:
       return _make_raw_assign_fn(
@@ -366,7 +344,7 @@ class TPUSyncOnReadVariable(TPUVariableMixin, values.SyncOnReadVariable):
                                                             **kwargs)
 
   def assign_add(self, *args, **kwargs):
-    if enclosing_tpu_context() is None:
+    if tpu_util.enclosing_tpu_context() is None:
       return values.SyncOnReadVariable.assign_add(self, *args, **kwargs)
     else:
       return _make_raw_assign_fn(
@@ -374,7 +352,7 @@ class TPUSyncOnReadVariable(TPUVariableMixin, values.SyncOnReadVariable):
                                                             **kwargs)
 
   def assign(self, *args, **kwargs):
-    if enclosing_tpu_context() is None:
+    if tpu_util.enclosing_tpu_context() is None:
       return values.SyncOnReadVariable.assign(self, *args, **kwargs)
     else:
       return _make_raw_assign_fn(gen_resource_variable_ops.assign_variable_op)(
@@ -408,8 +386,7 @@ def assign_add(var, value, use_locking=False, name=None, read_value=True):
 def assign(var, value, use_locking=False, name=None, read_value=True):
-  assign_fn = _make_raw_assign_fn(
-      gen_resource_variable_ops.assign_variable_op)
+  assign_fn = _make_raw_assign_fn(gen_resource_variable_ops.assign_variable_op)
   return var._update(  # pylint: disable=protected-access
       update_fn=assign_fn,
       value=value,
@@ -427,9 +404,13 @@ class TPUAutoPolicy(values.AutoPolicy):
   scope.
   """
 
-  def assign_sub(self, var, value, use_locking=False, name=None,
-                 read_value=True):
-    if enclosing_tpu_context():
+  def assign_sub(self,
+                 var,
+                 value,
+                 use_locking=False,
+                 name=None,
+                 read_value=True):
+    if tpu_util.enclosing_tpu_context():
       return _make_raw_assign_fn(
           gen_resource_variable_ops.assign_sub_variable_op)(
               var,
@@ -437,12 +418,16 @@ class TPUAutoPolicy(values.AutoPolicy):
               use_locking=use_locking,
               name=name,
               read_value=read_value)
-    return assign_sub(var, value, use_locking=use_locking, name=name,
-                      read_value=read_value)
+    return assign_sub(
+        var, value, use_locking=use_locking, name=name, read_value=read_value)
 
-  def assign_add(self, var, value, use_locking=False, name=None,
-                 read_value=True):
-    if enclosing_tpu_context():
+  def assign_add(self,
+                 var,
+                 value,
+                 use_locking=False,
+                 name=None,
+                 read_value=True):
+    if tpu_util.enclosing_tpu_context():
       return _make_raw_assign_fn(
           gen_resource_variable_ops.assign_add_variable_op)(
               var,
@@ -450,20 +435,19 @@ class TPUAutoPolicy(values.AutoPolicy):
               use_locking=use_locking,
               name=name,
               read_value=read_value)
-    return assign_add(var, value, use_locking=use_locking, name=name,
-                      read_value=read_value)
+    return assign_add(
+        var, value, use_locking=use_locking, name=name, read_value=read_value)
 
   def assign(self, var, value, use_locking=False, name=None, read_value=True):
-    if enclosing_tpu_context():
-      return _make_raw_assign_fn(
-          gen_resource_variable_ops.assign_variable_op)(
-              var,
-              value=value,
-              use_locking=use_locking,
-              name=name,
-              read_value=read_value)
-    return assign(var, value, use_locking=use_locking, name=name,
-                  read_value=read_value)
+    if tpu_util.enclosing_tpu_context():
+      return _make_raw_assign_fn(gen_resource_variable_ops.assign_variable_op)(
+          var,
+          value=value,
+          use_locking=use_locking,
+          name=name,
+          read_value=read_value)
+    return assign(
+        var, value, use_locking=use_locking, name=name, read_value=read_value)
 
   def scatter_sub(self, *args, **kwargs):
     raise NotImplementedError
@@ -504,19 +488,27 @@ class TPUOnWritePolicy(values.OnWritePolicy):
   values such as `NONE`, `SUM`, `MEAN` or `ONLY_FIRST_REPLICA`.
   """
 
-  def assign_sub(self, var, value, use_locking=False, name=None,
-                 read_value=True):
-    return assign_sub(var, value, use_locking=use_locking, name=name,
-                      read_value=read_value)
+  def assign_sub(self,
+                 var,
+                 value,
+                 use_locking=False,
+                 name=None,
+                 read_value=True):
+    return assign_sub(
+        var, value, use_locking=use_locking, name=name, read_value=read_value)
 
-  def assign_add(self, var, value, use_locking=False, name=None,
-                 read_value=True):
-    return assign_add(var, value, use_locking=use_locking, name=name,
-                      read_value=read_value)
+  def assign_add(self,
+                 var,
+                 value,
+                 use_locking=False,
+                 name=None,
+                 read_value=True):
+    return assign_add(
+        var, value, use_locking=use_locking, name=name, read_value=read_value)
 
   def assign(self, var, value, use_locking=False, name=None, read_value=True):
-    return assign(var, value, use_locking=use_locking, name=name,
-                  read_value=read_value)
+    return assign(
+        var, value, use_locking=use_locking, name=name, read_value=read_value)
 
   def scatter_sub(self, *args, **kwargs):
     raise NotImplementedError
@@ -554,7 +546,7 @@ class TPUOnReadPolicy(values.OnReadPolicy):
   """
 
   def assign_sub(self, var, *args, **kwargs):
-    if enclosing_tpu_context() is None:
+    if tpu_util.enclosing_tpu_context() is None:
       return super(TPUOnReadPolicy, self).assign_sub(var, *args, **kwargs)
     else:
       return _make_raw_assign_fn(
@@ -562,7 +554,7 @@ class TPUOnReadPolicy(values.OnReadPolicy):
                                                             **kwargs)
 
   def assign_add(self, var, *args, **kwargs):
-    if enclosing_tpu_context() is None:
+    if tpu_util.enclosing_tpu_context() is None:
       return super(TPUOnReadPolicy, self).assign_add(var, *args, **kwargs)
     else:
       return _make_raw_assign_fn(
@@ -570,7 +562,7 @@ class TPUOnReadPolicy(values.OnReadPolicy):
                                                             **kwargs)
 
   def assign(self, var, *args, **kwargs):
-    if enclosing_tpu_context() is None:
+    if tpu_util.enclosing_tpu_context() is None:
       return super(TPUOnReadPolicy, self).assign(var, *args, **kwargs)
     else:
       return _make_raw_assign_fn(gen_resource_variable_ops.assign_variable_op)(
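
Taken together, the assign methods on TPUMirroredVariable, TPUAutoPolicy, TPUSyncOnReadVariable, and TPUOnReadPolicy repeat one dispatch rule. A hypothetical condensation of that rule (not in the commit; _assign_via_context, raw_op, and fallback are illustrative names):

def _assign_via_context(raw_op, fallback, var, *args, **kwargs):
  # Inside a TPUReplicateContext, write through the raw resource op on the
  # replicated handle; outside, defer to the distributed fallback path.
  if tpu_util.enclosing_tpu_context() is not None:
    return _make_raw_assign_fn(raw_op)(var, *args, **kwargs)
  return fallback(var, *args, **kwargs)

For example, _assign_via_context(gen_resource_variable_ops.assign_variable_op, values.SyncOnReadVariable.assign, var, value) would mirror what TPUSyncOnReadVariable.assign does above.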