From 428ce93ee46089896c0fcfc9460bd0691e57a0e9 Mon Sep 17 00:00:00 2001
From: Ran Chen
Date: Wed, 3 Feb 2021 09:32:15 -0800
Subject: [PATCH] Move enclosing_tpu_context to a separate util file

This is part of the variable refactor work to avoid dependency cycles.

PiperOrigin-RevId: 355414142
Change-Id: I36651a7be6462c198aae477923bc2ef0f7e7d0fb
---
 tensorflow/python/distribute/BUILD           |  12 +-
 tensorflow/python/distribute/tpu_strategy.py |   9 +-
 tensorflow/python/distribute/tpu_util.py     |  35 +++++
 tensorflow/python/distribute/tpu_values.py   | 156 +++++++++----------
 4 files changed, 125 insertions(+), 87 deletions(-)
 create mode 100644 tensorflow/python/distribute/tpu_util.py

diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD
index b0d14e1bab3..e797d708032 100644
--- a/tensorflow/python/distribute/BUILD
+++ b/tensorflow/python/distribute/BUILD
@@ -299,6 +299,15 @@ py_library(
     ],
 )
 
+py_library(
+    name = "tpu_util",
+    srcs = ["tpu_util.py"],
+    deps = [
+        "//tensorflow/python:framework_ops",
+        "//tensorflow/python/tpu:tpu_py",
+    ],
+)
+
 py_library(
     name = "mirrored_strategy",
     srcs = ["mirrored_strategy.py"],
@@ -579,6 +588,7 @@ py_library(
         ":input_lib",
         ":numpy_dataset",
         ":reduce_util",
+        ":tpu_util",
        ":tpu_values",
         ":values",
         "//tensorflow/compiler/xla/experimental/xla_sharding",
@@ -809,6 +819,7 @@ py_library(
     srcs_version = "PY3",
     deps = [
         ":packed_distributed_variable",
+        ":tpu_util",
        ":values",
         ":values_util",
         "//tensorflow/python:framework_ops",
@@ -816,7 +827,6 @@ py_library(
         "//tensorflow/python:resource_variable_ops_gen",
         "//tensorflow/python/eager:context",
         "//tensorflow/python/eager:tape",
-        "//tensorflow/python/tpu:tpu_py",
     ],
 )
 
diff --git a/tensorflow/python/distribute/tpu_strategy.py b/tensorflow/python/distribute/tpu_strategy.py
index 4d6b0a204e2..b611b2a8cd8 100644
--- a/tensorflow/python/distribute/tpu_strategy.py
+++ b/tensorflow/python/distribute/tpu_strategy.py
@@ -38,6 +38,7 @@ from tensorflow.python.distribute import distribute_utils
 from tensorflow.python.distribute import input_lib
 from tensorflow.python.distribute import numpy_dataset
 from tensorflow.python.distribute import reduce_util
+from tensorflow.python.distribute import tpu_util
 from tensorflow.python.distribute import tpu_values
 from tensorflow.python.distribute import values
 from tensorflow.python.distribute.cluster_resolver import TPUClusterResolver
@@ -1099,7 +1100,7 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
 
     self._logical_device_stack.append(logical_device_id)
     try:
-      if tpu_values.enclosing_tpu_context() is None:
+      if tpu_util.enclosing_tpu_context() is None:
         yield
       else:
         with ops.device(tpu.core(logical_device_id)):
@@ -1213,7 +1214,7 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
   def _reduce_to(self, reduce_op, value, destinations, options):
     if (isinstance(value, values.DistributedValues) or
         tensor_util.is_tf_type(value)
-       ) and tpu_values.enclosing_tpu_context() is not None:
+       ) and tpu_util.enclosing_tpu_context() is not None:
       if reduce_op == reduce_util.ReduceOp.MEAN:
         # TODO(jhseu): Revisit once we support model-parallelism.
         value *= (1. / self._num_replicas_in_sync)
@@ -1260,7 +1261,7 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
   def _update(self, var, fn, args, kwargs, group):
     assert isinstance(var, tpu_values.TPUVariableMixin) or isinstance(
         var, resource_variable_ops.BaseResourceVariable)
-    if tpu_values.enclosing_tpu_context() is not None:
+    if tpu_util.enclosing_tpu_context() is not None:
       if group:
         return fn(var, *args, **kwargs)
       else:
@@ -1317,7 +1318,7 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
     # since the `1` gets broadcast as an int32 but global_step is int64.
     if isinstance(tensor, (float, int)):
       return tensor
-    if tpu_values.enclosing_tpu_context() is not None:
+    if tpu_util.enclosing_tpu_context() is not None:
       broadcast_tensor = [tensor for _ in range(self._num_replicas_in_sync)]
       result = tpu_ops.all_to_all(
           broadcast_tensor,
diff --git a/tensorflow/python/distribute/tpu_util.py b/tensorflow/python/distribute/tpu_util.py
new file mode 100644
index 00000000000..4a8e2d35bab
--- /dev/null
+++ b/tensorflow/python/distribute/tpu_util.py
@@ -0,0 +1,35 @@
+# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utility functions for TPU."""
+
+from tensorflow.python.framework import ops
+from tensorflow.python.tpu import tpu
+
+
+def enclosing_tpu_context():
+  """Returns the TPUReplicateContext, which exists inside a tpu.rewrite()."""
+  graph = ops.get_default_graph()
+  while graph is not None:
+    # pylint: disable=protected-access
+    context_ = graph._get_control_flow_context()
+    # pylint: enable=protected-access
+    while context_ is not None:
+      if isinstance(context_, tpu.TPUReplicateContext):
+        return context_
+      context_ = context_.outer_context
+    # This may be a FuncGraph due to defuns or v2 control flow. We need to
+    # find the original graph with the XLAControlFlowContext.
+    graph = getattr(graph, "outer_graph", None)
+  return None
diff --git a/tensorflow/python/distribute/tpu_values.py b/tensorflow/python/distribute/tpu_values.py
index dbe1e1ff363..34ed82251e9 100644
--- a/tensorflow/python/distribute/tpu_values.py
+++ b/tensorflow/python/distribute/tpu_values.py
@@ -25,6 +25,7 @@ from __future__ import print_function
 import contextlib
 
 from tensorflow.python.distribute import packed_distributed_variable as packed
+from tensorflow.python.distribute import tpu_util
 from tensorflow.python.distribute import values
 from tensorflow.python.distribute import values_util
 from tensorflow.python.eager import context
@@ -33,7 +34,6 @@ from tensorflow.python.framework import ops
 from tensorflow.python.ops import gen_resource_variable_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import variable_scope
-from tensorflow.python.tpu import tpu
 
 
 @contextlib.contextmanager
@@ -66,9 +66,7 @@ def _make_raw_assign_fn(raw_assign_fn):  # pylint: disable=missing-docstring
     handle = var.handle
     with _maybe_enter_graph(handle), _maybe_on_device(var):
       op = raw_assign_fn(
-          handle,
-          ops.convert_to_tensor(value, dtype=var.dtype),
-          name=name)
+          handle, ops.convert_to_tensor(value, dtype=var.dtype), name=name)
       with ops.control_dependencies([op]):
         return var._read_variable_op() if read_value else op  # pylint: disable=protected-access
 
@@ -89,14 +87,14 @@ class TPUVariableMixin(object):
       self._handle_id = self._common_name
 
   def __getattr__(self, name):
-    if enclosing_tpu_context() is None:
+    if tpu_util.enclosing_tpu_context() is None:
       return super(TPUVariableMixin, self).__getattr__(name)
     else:
       raise AttributeError(
           "'{}' not accessible within a TPU context.".format(name))
 
   def get(self):
-    if enclosing_tpu_context() is None:
+    if tpu_util.enclosing_tpu_context() is None:
       return super(TPUVariableMixin, self).get()
     else:
       raise NotImplementedError(
@@ -113,7 +111,7 @@ class TPUVariableMixin(object):
   def handle(self):
     """The handle by which this variable can be accessed."""
     # If we're in a tpu.rewrite(), return the replicated handle.
-    tpu_context = enclosing_tpu_context()
+    tpu_context = tpu_util.enclosing_tpu_context()
     if tpu_context is None or context.executing_eagerly():
       var = self._get_on_device_or_primary()
       if isinstance(var, packed.PackedVarAndDevice):
@@ -148,19 +146,19 @@ class TPUVariableMixin(object):
       return gen_resource_variable_ops.read_variable_op(handle, self.dtype)
 
   def read_value(self):
-    if enclosing_tpu_context() is None:
+    if tpu_util.enclosing_tpu_context() is None:
       return super(TPUVariableMixin, self).read_value()
     else:
       return self._read_variable_op()
 
   def value(self):
-    if enclosing_tpu_context() is None:
+    if tpu_util.enclosing_tpu_context() is None:
       return super(TPUVariableMixin, self).value()
     else:
       return self._read_variable_op()
 
   def _as_graph_element(self):
-    if enclosing_tpu_context() is None:
+    if tpu_util.enclosing_tpu_context() is None:
       return super(TPUVariableMixin, self)._as_graph_element()  # pylint: disable=protected-access
     else:
       return None
@@ -177,7 +175,7 @@ class TPUVariableMixin(object):
   def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
     """Converts a variable to a tensor."""
     # pylint: disable=protected-access
-    if enclosing_tpu_context() is None:
+    if tpu_util.enclosing_tpu_context() is None:
       return super(TPUVariableMixin, self)._dense_var_to_tensor(
           dtype=dtype, name=name, as_ref=as_ref)
     # pylint: enable=protected-access
@@ -187,23 +185,6 @@ class TPUVariableMixin(object):
       return self.handle if as_ref else self.read_value()
 
 
-def enclosing_tpu_context():
-  """Returns the TPUReplicateContext, which exists inside a tpu.rewrite()."""
-  graph = ops.get_default_graph()
-  while graph is not None:
-    # pylint: disable=protected-access
-    context_ = graph._get_control_flow_context()
-    # pylint: enable=protected-access
-    while context_ is not None:
-      if isinstance(context_, tpu.TPUReplicateContext):
-        return context_
-      context_ = context_.outer_context
-    # This may be a FuncGraph due to defuns or v2 control flow. We need to
-    # find the original graph with the XLAControlFlowContext.
-    graph = getattr(graph, "outer_graph", None)
-  return None
-
-
 class TPUDistributedVariable(TPUVariableMixin, values.DistributedVariable):
   """DistributedVariable subclass for TPUStrategy."""
 
@@ -274,9 +255,8 @@ class TPUDistributedVariable(TPUVariableMixin, values.DistributedVariable):
 class TPUMirroredVariable(TPUVariableMixin, values.MirroredVariable):
   """Holds a map from replica to TPU variables whose values are kept in sync."""
 
-  def assign_sub(self, value, use_locking=False, name=None,
-                 read_value=True):
-    if (enclosing_tpu_context() and
+  def assign_sub(self, value, use_locking=False, name=None, read_value=True):
+    if (tpu_util.enclosing_tpu_context() and
         self.aggregation == variable_scope.VariableAggregation.NONE):
       return _make_raw_assign_fn(
           gen_resource_variable_ops.assign_sub_variable_op)(
@@ -285,12 +265,11 @@ class TPUMirroredVariable(TPUVariableMixin, values.MirroredVariable):
               use_locking=use_locking,
               name=name,
               read_value=read_value)
-    return assign_sub(self, value, use_locking=use_locking, name=name,
-                      read_value=read_value)
+    return assign_sub(
+        self, value, use_locking=use_locking, name=name, read_value=read_value)
 
-  def assign_add(self, value, use_locking=False, name=None,
-                 read_value=True):
-    if (enclosing_tpu_context() and
+  def assign_add(self, value, use_locking=False, name=None, read_value=True):
+    if (tpu_util.enclosing_tpu_context() and
         self.aggregation == variable_scope.VariableAggregation.NONE):
       return _make_raw_assign_fn(
           gen_resource_variable_ops.assign_add_variable_op)(
@@ -299,21 +278,20 @@ class TPUMirroredVariable(TPUVariableMixin, values.MirroredVariable):
               use_locking=use_locking,
               name=name,
               read_value=read_value)
-    return assign_add(self, value, use_locking=use_locking, name=name,
-                      read_value=read_value)
+    return assign_add(
+        self, value, use_locking=use_locking, name=name, read_value=read_value)
 
   def assign(self, value, use_locking=False, name=None, read_value=True):
-    if (enclosing_tpu_context() and
+    if (tpu_util.enclosing_tpu_context() and
         self.aggregation == variable_scope.VariableAggregation.NONE):
-      return _make_raw_assign_fn(
-          gen_resource_variable_ops.assign_variable_op)(
-              self,
-              value=value,
-              use_locking=use_locking,
-              name=name,
-              read_value=read_value)
-    return assign(self, value, use_locking=use_locking, name=name,
-                  read_value=read_value)
+      return _make_raw_assign_fn(gen_resource_variable_ops.assign_variable_op)(
+          self,
+          value=value,
+          use_locking=use_locking,
+          name=name,
+          read_value=read_value)
+    return assign(
+        self, value, use_locking=use_locking, name=name, read_value=read_value)
 
   def scatter_sub(self, *args, **kwargs):
     if values_util.is_saving_non_distributed():
@@ -358,7 +336,7 @@ class TPUSyncOnReadVariable(TPUVariableMixin, values.SyncOnReadVariable):
   """Holds a map from replica to variables whose values are reduced on save."""
 
   def assign_sub(self, *args, **kwargs):
-    if enclosing_tpu_context() is None:
+    if tpu_util.enclosing_tpu_context() is None:
       return values.SyncOnReadVariable.assign_sub(self, *args, **kwargs)
     else:
       return _make_raw_assign_fn(
@@ -366,7 +344,7 @@ class TPUSyncOnReadVariable(TPUVariableMixin, values.SyncOnReadVariable):
                                                             **kwargs)
 
   def assign_add(self, *args, **kwargs):
-    if enclosing_tpu_context() is None:
+    if tpu_util.enclosing_tpu_context() is None:
       return values.SyncOnReadVariable.assign_add(self, *args, **kwargs)
     else:
       return _make_raw_assign_fn(
@@ -374,7 +352,7 @@ class TPUSyncOnReadVariable(TPUVariableMixin, values.SyncOnReadVariable):
                                                             **kwargs)
 
   def assign(self, *args, **kwargs):
-    if enclosing_tpu_context() is None:
+    if tpu_util.enclosing_tpu_context() is None:
       return values.SyncOnReadVariable.assign(self, *args, **kwargs)
     else:
       return _make_raw_assign_fn(gen_resource_variable_ops.assign_variable_op)(
@@ -408,8 +386,7 @@ def assign_add(var, value, use_locking=False, name=None, read_value=True):
 
 
 def assign(var, value, use_locking=False, name=None, read_value=True):
-  assign_fn = _make_raw_assign_fn(
-      gen_resource_variable_ops.assign_variable_op)
+  assign_fn = _make_raw_assign_fn(gen_resource_variable_ops.assign_variable_op)
   return var._update(  # pylint: disable=protected-access
       update_fn=assign_fn,
       value=value,
@@ -427,9 +404,13 @@ class TPUAutoPolicy(values.AutoPolicy):
   scope.
   """
 
-  def assign_sub(self, var, value, use_locking=False, name=None,
+  def assign_sub(self,
+                 var,
+                 value,
+                 use_locking=False,
+                 name=None,
                  read_value=True):
-    if enclosing_tpu_context():
+    if tpu_util.enclosing_tpu_context():
       return _make_raw_assign_fn(
           gen_resource_variable_ops.assign_sub_variable_op)(
               var,
@@ -437,12 +418,16 @@ class TPUAutoPolicy(values.AutoPolicy):
               use_locking=use_locking,
               name=name,
               read_value=read_value)
-    return assign_sub(var, value, use_locking=use_locking, name=name,
-                      read_value=read_value)
+    return assign_sub(
+        var, value, use_locking=use_locking, name=name, read_value=read_value)
 
-  def assign_add(self, var, value, use_locking=False, name=None,
+  def assign_add(self,
+                 var,
+                 value,
+                 use_locking=False,
+                 name=None,
                  read_value=True):
-    if enclosing_tpu_context():
+    if tpu_util.enclosing_tpu_context():
       return _make_raw_assign_fn(
           gen_resource_variable_ops.assign_add_variable_op)(
               var,
@@ -450,20 +435,19 @@ class TPUAutoPolicy(values.AutoPolicy):
               use_locking=use_locking,
               name=name,
               read_value=read_value)
-    return assign_add(var, value, use_locking=use_locking, name=name,
-                      read_value=read_value)
+    return assign_add(
+        var, value, use_locking=use_locking, name=name, read_value=read_value)
 
   def assign(self, var, value, use_locking=False, name=None, read_value=True):
-    if enclosing_tpu_context():
-      return _make_raw_assign_fn(
-          gen_resource_variable_ops.assign_variable_op)(
-              var,
-              value=value,
-              use_locking=use_locking,
-              name=name,
-              read_value=read_value)
-    return assign(var, value, use_locking=use_locking, name=name,
-                  read_value=read_value)
+    if tpu_util.enclosing_tpu_context():
+      return _make_raw_assign_fn(gen_resource_variable_ops.assign_variable_op)(
+          var,
+          value=value,
+          use_locking=use_locking,
+          name=name,
+          read_value=read_value)
+    return assign(
+        var, value, use_locking=use_locking, name=name, read_value=read_value)
 
   def scatter_sub(self, *args, **kwargs):
     raise NotImplementedError
@@ -504,19 +488,27 @@ class TPUOnWritePolicy(values.OnWritePolicy):
   values such as `NONE`, `SUM`, `MEAN` or `ONLY_FIRST_REPLICA`.
""" - def assign_sub(self, var, value, use_locking=False, name=None, + def assign_sub(self, + var, + value, + use_locking=False, + name=None, read_value=True): - return assign_sub(var, value, use_locking=use_locking, name=name, - read_value=read_value) + return assign_sub( + var, value, use_locking=use_locking, name=name, read_value=read_value) - def assign_add(self, var, value, use_locking=False, name=None, + def assign_add(self, + var, + value, + use_locking=False, + name=None, read_value=True): - return assign_add(var, value, use_locking=use_locking, name=name, - read_value=read_value) + return assign_add( + var, value, use_locking=use_locking, name=name, read_value=read_value) def assign(self, var, value, use_locking=False, name=None, read_value=True): - return assign(var, value, use_locking=use_locking, name=name, - read_value=read_value) + return assign( + var, value, use_locking=use_locking, name=name, read_value=read_value) def scatter_sub(self, *args, **kwargs): raise NotImplementedError @@ -554,7 +546,7 @@ class TPUOnReadPolicy(values.OnReadPolicy): """ def assign_sub(self, var, *args, **kwargs): - if enclosing_tpu_context() is None: + if tpu_util.enclosing_tpu_context() is None: return super(TPUOnReadPolicy, self).assign_sub(var, *args, **kwargs) else: return _make_raw_assign_fn( @@ -562,7 +554,7 @@ class TPUOnReadPolicy(values.OnReadPolicy): **kwargs) def assign_add(self, var, *args, **kwargs): - if enclosing_tpu_context() is None: + if tpu_util.enclosing_tpu_context() is None: return super(TPUOnReadPolicy, self).assign_add(var, *args, **kwargs) else: return _make_raw_assign_fn( @@ -570,7 +562,7 @@ class TPUOnReadPolicy(values.OnReadPolicy): **kwargs) def assign(self, var, *args, **kwargs): - if enclosing_tpu_context() is None: + if tpu_util.enclosing_tpu_context() is None: return super(TPUOnReadPolicy, self).assign(var, *args, **kwargs) else: return _make_raw_assign_fn(gen_resource_variable_ops.assign_variable_op)(