Move enclosing_tpu_context to a separate util file

This is part of the variable refactor work to avoid dependency cycles.

PiperOrigin-RevId: 355414142
Change-Id: I36651a7be6462c198aae477923bc2ef0f7e7d0fb
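The cycle being avoided: enclosing_tpu_context() used to live in tpu_values.py, which also depends on the heavier distributed-values machinery, so any module that only needed the context check still had to pull in all of tpu_values. After this change the helper lives in a leaf module whose only deps are framework_ops and tpu_py. A minimal usage sketch of the new entry point (hypothetical caller, not part of this diff):

from tensorflow.python.distribute import tpu_util


def inside_tpu_rewrite():
  """True iff the current graph is nested under a TPUReplicateContext."""
  return tpu_util.enclosing_tpu_context() is not None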
parent 0d7fe54134
commit 428ce93ee4

tensorflow/python/distribute/BUILD
@@ -299,6 +299,15 @@ py_library(
     ],
 )
 
+py_library(
+    name = "tpu_util",
+    srcs = ["tpu_util.py"],
+    deps = [
+        "//tensorflow/python:framework_ops",
+        "//tensorflow/python/tpu:tpu_py",
+    ],
+)
+
 py_library(
     name = "mirrored_strategy",
     srcs = ["mirrored_strategy.py"],
@@ -579,6 +588,7 @@ py_library(
         ":input_lib",
         ":numpy_dataset",
         ":reduce_util",
+        ":tpu_util",
         ":tpu_values",
         ":values",
         "//tensorflow/compiler/xla/experimental/xla_sharding",
@@ -809,6 +819,7 @@ py_library(
     srcs_version = "PY3",
     deps = [
         ":packed_distributed_variable",
+        ":tpu_util",
         ":values",
         ":values_util",
         "//tensorflow/python:framework_ops",
@@ -816,7 +827,6 @@ py_library(
         "//tensorflow/python:resource_variable_ops_gen",
        "//tensorflow/python/eager:context",
         "//tensorflow/python/eager:tape",
-        "//tensorflow/python/tpu:tpu_py",
     ],
 )
 
tensorflow/python/distribute/tpu_strategy.py

@@ -38,6 +38,7 @@ from tensorflow.python.distribute import distribute_utils
 from tensorflow.python.distribute import input_lib
 from tensorflow.python.distribute import numpy_dataset
 from tensorflow.python.distribute import reduce_util
+from tensorflow.python.distribute import tpu_util
 from tensorflow.python.distribute import tpu_values
 from tensorflow.python.distribute import values
 from tensorflow.python.distribute.cluster_resolver import TPUClusterResolver
@@ -1099,7 +1100,7 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
 
     self._logical_device_stack.append(logical_device_id)
     try:
-      if tpu_values.enclosing_tpu_context() is None:
+      if tpu_util.enclosing_tpu_context() is None:
         yield
       else:
         with ops.device(tpu.core(logical_device_id)):
@@ -1213,7 +1214,7 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
   def _reduce_to(self, reduce_op, value, destinations, options):
     if (isinstance(value, values.DistributedValues) or
         tensor_util.is_tf_type(value)
-       ) and tpu_values.enclosing_tpu_context() is not None:
+       ) and tpu_util.enclosing_tpu_context() is not None:
       if reduce_op == reduce_util.ReduceOp.MEAN:
         # TODO(jhseu): Revisit once we support model-parallelism.
         value *= (1. / self._num_replicas_in_sync)
@@ -1260,7 +1261,7 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
   def _update(self, var, fn, args, kwargs, group):
     assert isinstance(var, tpu_values.TPUVariableMixin) or isinstance(
         var, resource_variable_ops.BaseResourceVariable)
-    if tpu_values.enclosing_tpu_context() is not None:
+    if tpu_util.enclosing_tpu_context() is not None:
       if group:
         return fn(var, *args, **kwargs)
       else:
@@ -1317,7 +1318,7 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
     # since the `1` gets broadcast as an int32 but global_step is int64.
     if isinstance(tensor, (float, int)):
       return tensor
-    if tpu_values.enclosing_tpu_context() is not None:
+    if tpu_util.enclosing_tpu_context() is not None:
       broadcast_tensor = [tensor for _ in range(self._num_replicas_in_sync)]
       result = tpu_ops.all_to_all(
           broadcast_tensor,
tensorflow/python/distribute/tpu_util.py (new file, 35 lines)

@@ -0,0 +1,35 @@
+# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utility functions for TPU."""
+
+from tensorflow.python.framework import ops
+from tensorflow.python.tpu import tpu
+
+
+def enclosing_tpu_context():
+  """Returns the TPUReplicateContext, which exists inside a tpu.rewrite()."""
+  graph = ops.get_default_graph()
+  while graph is not None:
+    # pylint: disable=protected-access
+    context_ = graph._get_control_flow_context()
+    # pylint: enable=protected-access
+    while context_ is not None:
+      if isinstance(context_, tpu.TPUReplicateContext):
+        return context_
+      context_ = context_.outer_context
+    # This may be a FuncGraph due to defuns or v2 control flow. We need to
+    # find the original graph with the XLAControlFlowContext.
+    graph = getattr(graph, "outer_graph", None)
+  return None
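Note the shape of the lookup: the inner loop walks the control-flow context chain via outer_context, and when the current graph is a FuncGraph (from a defun/tf.function or v2 control flow) the outer loop pops out one graph level via outer_graph and repeats, so a TPUReplicateContext is found even from inside nested function graphs. A quick sanity sketch (hypothetical, not part of the commit):

from tensorflow.python.distribute import tpu_util

# Outside any tpu.rewrite()/tpu.replicate() scope there is no
# TPUReplicateContext anywhere on the graph chain, so this returns None.
assert tpu_util.enclosing_tpu_context() is None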
tensorflow/python/distribute/tpu_values.py

@@ -25,6 +25,7 @@ from __future__ import print_function
 import contextlib
 
 from tensorflow.python.distribute import packed_distributed_variable as packed
+from tensorflow.python.distribute import tpu_util
 from tensorflow.python.distribute import values
 from tensorflow.python.distribute import values_util
 from tensorflow.python.eager import context
@@ -33,7 +34,6 @@ from tensorflow.python.framework import ops
 from tensorflow.python.ops import gen_resource_variable_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import variable_scope
-from tensorflow.python.tpu import tpu
 
 
 @contextlib.contextmanager
@@ -66,9 +66,7 @@ def _make_raw_assign_fn(raw_assign_fn):  # pylint: disable=missing-docstring
     handle = var.handle
     with _maybe_enter_graph(handle), _maybe_on_device(var):
       op = raw_assign_fn(
-          handle,
-          ops.convert_to_tensor(value, dtype=var.dtype),
-          name=name)
+          handle, ops.convert_to_tensor(value, dtype=var.dtype), name=name)
       with ops.control_dependencies([op]):
         return var._read_variable_op() if read_value else op  # pylint: disable=protected-access
 
@@ -89,14 +87,14 @@ class TPUVariableMixin(object):
       self._handle_id = self._common_name
 
   def __getattr__(self, name):
-    if enclosing_tpu_context() is None:
+    if tpu_util.enclosing_tpu_context() is None:
       return super(TPUVariableMixin, self).__getattr__(name)
     else:
       raise AttributeError(
           "'{}' not accessible within a TPU context.".format(name))
 
   def get(self):
-    if enclosing_tpu_context() is None:
+    if tpu_util.enclosing_tpu_context() is None:
       return super(TPUVariableMixin, self).get()
     else:
       raise NotImplementedError(
@@ -113,7 +111,7 @@ class TPUVariableMixin(object):
   def handle(self):
     """The handle by which this variable can be accessed."""
     # If we're in a tpu.rewrite(), return the replicated handle.
-    tpu_context = enclosing_tpu_context()
+    tpu_context = tpu_util.enclosing_tpu_context()
     if tpu_context is None or context.executing_eagerly():
       var = self._get_on_device_or_primary()
       if isinstance(var, packed.PackedVarAndDevice):
@@ -148,19 +146,19 @@ class TPUVariableMixin(object):
     return gen_resource_variable_ops.read_variable_op(handle, self.dtype)
 
   def read_value(self):
-    if enclosing_tpu_context() is None:
+    if tpu_util.enclosing_tpu_context() is None:
       return super(TPUVariableMixin, self).read_value()
     else:
       return self._read_variable_op()
 
   def value(self):
-    if enclosing_tpu_context() is None:
+    if tpu_util.enclosing_tpu_context() is None:
       return super(TPUVariableMixin, self).value()
     else:
       return self._read_variable_op()
 
   def _as_graph_element(self):
-    if enclosing_tpu_context() is None:
+    if tpu_util.enclosing_tpu_context() is None:
       return super(TPUVariableMixin, self)._as_graph_element()  # pylint: disable=protected-access
     else:
       return None
@@ -177,7 +175,7 @@ class TPUVariableMixin(object):
   def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
     """Converts a variable to a tensor."""
     # pylint: disable=protected-access
-    if enclosing_tpu_context() is None:
+    if tpu_util.enclosing_tpu_context() is None:
       return super(TPUVariableMixin, self)._dense_var_to_tensor(
           dtype=dtype, name=name, as_ref=as_ref)
     # pylint: enable=protected-access
@@ -187,23 +185,6 @@ class TPUVariableMixin(object):
     return self.handle if as_ref else self.read_value()
 
 
-def enclosing_tpu_context():
-  """Returns the TPUReplicateContext, which exists inside a tpu.rewrite()."""
-  graph = ops.get_default_graph()
-  while graph is not None:
-    # pylint: disable=protected-access
-    context_ = graph._get_control_flow_context()
-    # pylint: enable=protected-access
-    while context_ is not None:
-      if isinstance(context_, tpu.TPUReplicateContext):
-        return context_
-      context_ = context_.outer_context
-    # This may be a FuncGraph due to defuns or v2 control flow. We need to
-    # find the original graph with the XLAControlFlowContext.
-    graph = getattr(graph, "outer_graph", None)
-  return None
-
-
 class TPUDistributedVariable(TPUVariableMixin, values.DistributedVariable):
   """DistributedVariable subclass for TPUStrategy."""
 
@@ -274,9 +255,8 @@ class TPUDistributedVariable(TPUVariableMixin, values.DistributedVariable):
 class TPUMirroredVariable(TPUVariableMixin, values.MirroredVariable):
   """Holds a map from replica to TPU variables whose values are kept in sync."""
 
-  def assign_sub(self, value, use_locking=False, name=None,
-                 read_value=True):
-    if (enclosing_tpu_context() and
+  def assign_sub(self, value, use_locking=False, name=None, read_value=True):
+    if (tpu_util.enclosing_tpu_context() and
         self.aggregation == variable_scope.VariableAggregation.NONE):
       return _make_raw_assign_fn(
           gen_resource_variable_ops.assign_sub_variable_op)(
@@ -285,12 +265,11 @@ class TPUMirroredVariable(TPUVariableMixin, values.MirroredVariable):
           use_locking=use_locking,
           name=name,
           read_value=read_value)
-    return assign_sub(self, value, use_locking=use_locking, name=name,
-                      read_value=read_value)
+    return assign_sub(
+        self, value, use_locking=use_locking, name=name, read_value=read_value)
 
-  def assign_add(self, value, use_locking=False, name=None,
-                 read_value=True):
-    if (enclosing_tpu_context() and
+  def assign_add(self, value, use_locking=False, name=None, read_value=True):
+    if (tpu_util.enclosing_tpu_context() and
         self.aggregation == variable_scope.VariableAggregation.NONE):
       return _make_raw_assign_fn(
           gen_resource_variable_ops.assign_add_variable_op)(
@@ -299,21 +278,20 @@ class TPUMirroredVariable(TPUVariableMixin, values.MirroredVariable):
           use_locking=use_locking,
           name=name,
           read_value=read_value)
-    return assign_add(self, value, use_locking=use_locking, name=name,
-                      read_value=read_value)
+    return assign_add(
+        self, value, use_locking=use_locking, name=name, read_value=read_value)
 
   def assign(self, value, use_locking=False, name=None, read_value=True):
-    if (enclosing_tpu_context() and
+    if (tpu_util.enclosing_tpu_context() and
         self.aggregation == variable_scope.VariableAggregation.NONE):
-      return _make_raw_assign_fn(
-          gen_resource_variable_ops.assign_variable_op)(
-              self,
-              value=value,
-              use_locking=use_locking,
-              name=name,
-              read_value=read_value)
-    return assign(self, value, use_locking=use_locking, name=name,
-                  read_value=read_value)
+      return _make_raw_assign_fn(gen_resource_variable_ops.assign_variable_op)(
+          self,
+          value=value,
+          use_locking=use_locking,
+          name=name,
+          read_value=read_value)
+    return assign(
+        self, value, use_locking=use_locking, name=name, read_value=read_value)
 
   def scatter_sub(self, *args, **kwargs):
     if values_util.is_saving_non_distributed():
@@ -358,7 +336,7 @@ class TPUSyncOnReadVariable(TPUVariableMixin, values.SyncOnReadVariable):
   """Holds a map from replica to variables whose values are reduced on save."""
 
   def assign_sub(self, *args, **kwargs):
-    if enclosing_tpu_context() is None:
+    if tpu_util.enclosing_tpu_context() is None:
       return values.SyncOnReadVariable.assign_sub(self, *args, **kwargs)
     else:
       return _make_raw_assign_fn(
@@ -366,7 +344,7 @@ class TPUSyncOnReadVariable(TPUVariableMixin, values.SyncOnReadVariable):
           **kwargs)
 
   def assign_add(self, *args, **kwargs):
-    if enclosing_tpu_context() is None:
+    if tpu_util.enclosing_tpu_context() is None:
       return values.SyncOnReadVariable.assign_add(self, *args, **kwargs)
     else:
       return _make_raw_assign_fn(
@@ -374,7 +352,7 @@ class TPUSyncOnReadVariable(TPUVariableMixin, values.SyncOnReadVariable):
           **kwargs)
 
   def assign(self, *args, **kwargs):
-    if enclosing_tpu_context() is None:
+    if tpu_util.enclosing_tpu_context() is None:
       return values.SyncOnReadVariable.assign(self, *args, **kwargs)
     else:
       return _make_raw_assign_fn(gen_resource_variable_ops.assign_variable_op)(
@@ -408,8 +386,7 @@ def assign_add(var, value, use_locking=False, name=None, read_value=True):
 
 
 def assign(var, value, use_locking=False, name=None, read_value=True):
-  assign_fn = _make_raw_assign_fn(
-      gen_resource_variable_ops.assign_variable_op)
+  assign_fn = _make_raw_assign_fn(gen_resource_variable_ops.assign_variable_op)
   return var._update(  # pylint: disable=protected-access
       update_fn=assign_fn,
       value=value,
@@ -427,9 +404,13 @@ class TPUAutoPolicy(values.AutoPolicy):
     scope.
   """
 
-  def assign_sub(self, var, value, use_locking=False, name=None,
-                 read_value=True):
-    if enclosing_tpu_context():
+  def assign_sub(self,
+                 var,
+                 value,
+                 use_locking=False,
+                 name=None,
+                 read_value=True):
+    if tpu_util.enclosing_tpu_context():
       return _make_raw_assign_fn(
           gen_resource_variable_ops.assign_sub_variable_op)(
               var,
@@ -437,12 +418,16 @@ class TPUAutoPolicy(values.AutoPolicy):
               use_locking=use_locking,
               name=name,
               read_value=read_value)
-    return assign_sub(var, value, use_locking=use_locking, name=name,
-                      read_value=read_value)
+    return assign_sub(
+        var, value, use_locking=use_locking, name=name, read_value=read_value)
 
-  def assign_add(self, var, value, use_locking=False, name=None,
-                 read_value=True):
-    if enclosing_tpu_context():
+  def assign_add(self,
+                 var,
+                 value,
+                 use_locking=False,
+                 name=None,
+                 read_value=True):
+    if tpu_util.enclosing_tpu_context():
       return _make_raw_assign_fn(
           gen_resource_variable_ops.assign_add_variable_op)(
               var,
@@ -450,20 +435,19 @@ class TPUAutoPolicy(values.AutoPolicy):
               use_locking=use_locking,
               name=name,
               read_value=read_value)
-    return assign_add(var, value, use_locking=use_locking, name=name,
-                      read_value=read_value)
+    return assign_add(
+        var, value, use_locking=use_locking, name=name, read_value=read_value)
 
   def assign(self, var, value, use_locking=False, name=None, read_value=True):
-    if enclosing_tpu_context():
-      return _make_raw_assign_fn(
-          gen_resource_variable_ops.assign_variable_op)(
-              var,
-              value=value,
-              use_locking=use_locking,
-              name=name,
-              read_value=read_value)
-    return assign(var, value, use_locking=use_locking, name=name,
-                  read_value=read_value)
+    if tpu_util.enclosing_tpu_context():
+      return _make_raw_assign_fn(gen_resource_variable_ops.assign_variable_op)(
+          var,
+          value=value,
+          use_locking=use_locking,
+          name=name,
+          read_value=read_value)
+    return assign(
+        var, value, use_locking=use_locking, name=name, read_value=read_value)
 
   def scatter_sub(self, *args, **kwargs):
     raise NotImplementedError
@@ -504,19 +488,27 @@ class TPUOnWritePolicy(values.OnWritePolicy):
     values such as `NONE`, `SUM`, `MEAN` or `ONLY_FIRST_REPLICA`.
   """
 
-  def assign_sub(self, var, value, use_locking=False, name=None,
-                 read_value=True):
-    return assign_sub(var, value, use_locking=use_locking, name=name,
-                      read_value=read_value)
+  def assign_sub(self,
+                 var,
+                 value,
+                 use_locking=False,
+                 name=None,
+                 read_value=True):
+    return assign_sub(
+        var, value, use_locking=use_locking, name=name, read_value=read_value)
 
-  def assign_add(self, var, value, use_locking=False, name=None,
-                 read_value=True):
-    return assign_add(var, value, use_locking=use_locking, name=name,
-                      read_value=read_value)
+  def assign_add(self,
+                 var,
+                 value,
+                 use_locking=False,
+                 name=None,
+                 read_value=True):
+    return assign_add(
+        var, value, use_locking=use_locking, name=name, read_value=read_value)
 
   def assign(self, var, value, use_locking=False, name=None, read_value=True):
-    return assign(var, value, use_locking=use_locking, name=name,
-                  read_value=read_value)
+    return assign(
+        var, value, use_locking=use_locking, name=name, read_value=read_value)
 
   def scatter_sub(self, *args, **kwargs):
     raise NotImplementedError
@@ -554,7 +546,7 @@ class TPUOnReadPolicy(values.OnReadPolicy):
   """
 
   def assign_sub(self, var, *args, **kwargs):
-    if enclosing_tpu_context() is None:
+    if tpu_util.enclosing_tpu_context() is None:
       return super(TPUOnReadPolicy, self).assign_sub(var, *args, **kwargs)
     else:
       return _make_raw_assign_fn(
@@ -562,7 +554,7 @@ class TPUOnReadPolicy(values.OnReadPolicy):
           **kwargs)
 
   def assign_add(self, var, *args, **kwargs):
-    if enclosing_tpu_context() is None:
+    if tpu_util.enclosing_tpu_context() is None:
       return super(TPUOnReadPolicy, self).assign_add(var, *args, **kwargs)
     else:
       return _make_raw_assign_fn(
@@ -570,7 +562,7 @@ class TPUOnReadPolicy(values.OnReadPolicy):
           **kwargs)
 
   def assign(self, var, *args, **kwargs):
-    if enclosing_tpu_context() is None:
+    if tpu_util.enclosing_tpu_context() is None:
       return super(TPUOnReadPolicy, self).assign(var, *args, **kwargs)
     else:
       return _make_raw_assign_fn(gen_resource_variable_ops.assign_variable_op)(
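Every call site above follows the same dispatch idiom, which is why this change is mostly a pure qualifier swap: check tpu_util.enclosing_tpu_context() and either defer to the regular distribute-variable path or emit the raw resource-variable op on the replicated handle. A hedged distillation of that idiom (standalone helper with illustrative names, not code from this commit):

from tensorflow.python.distribute import tpu_util


def tpu_aware_dispatch(var, in_context_fn, out_of_context_fn):
  """Pick the raw-op path under tpu.rewrite(), the normal path otherwise."""
  if tpu_util.enclosing_tpu_context() is None:
    return out_of_context_fn(var)  # e.g. values.SyncOnReadVariable.assign
  return in_context_fn(var)        # e.g. a _make_raw_assign_fn(...) wrapper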