Refactor values.py into a utility file and a PS values file.

PiperOrigin-RevId: 313617630
Change-Id: Ie51b0f69af65b3f85701f58190da2c7eb46e1d29

parent 393e92ae5f
commit 8d31fb4b76
tensorflow/python/distribute/BUILD

@@ -67,6 +67,7 @@ py_library(
        ":collective_util",
        ":cross_device_utils",
        ":device_util",
+       ":ps_values",
        ":reduce_util",
        ":tpu_values",
        ":values",
@@ -78,7 +79,9 @@ py_library(
        "//tensorflow/python:platform",
        "//tensorflow/python:resource_variable_ops",
        "//tensorflow/python:tensor_util",
        "//tensorflow/python:tf_export",
        "//tensorflow/python/eager:context",
        "//tensorflow/python/eager:executor",
        "//tensorflow/tools/docs:doc_controls",
        "@six_archive//:six",
    ],
@@ -315,18 +318,23 @@ py_library(
    visibility = ["//tensorflow:internal"],
    deps = [
        ":cross_device_ops",
        ":device_util",
        ":distribute_lib",
        ":input_lib",
        ":mirrored_run",
        ":multi_worker_util",
        ":numpy_dataset",
        ":reduce_util",
+       ":ps_values",
        ":values",
        "//tensorflow/core:protos_all_py",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:device",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:platform",
        "//tensorflow/python:resource_variable_ops",
        "//tensorflow/python:tf_export",
        "//tensorflow/python:training",
        "//tensorflow/python:util",
        "//tensorflow/python:variable_scope",
        "//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib",
        "//tensorflow/python/eager:context",
    ],
@@ -671,12 +679,14 @@ py_library(
        ":device_util",
        ":distribute_lib",
        ":reduce_util",
+       ":values_util",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:composite_tensor",
        "//tensorflow/python:control_flow_ops",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:tensor_util",
        "//tensorflow/python:tf_export",
        "//tensorflow/python:type_spec",
        "//tensorflow/python:util",
        "//tensorflow/python:variable_scope",
@@ -686,6 +696,34 @@ py_library(
        "//tensorflow/python/training/saving:saveable_object",
        "//tensorflow/python/training/saving:saveable_object_util",
        "//tensorflow/python/training/tracking:base",
        "//tensorflow/python/types",
    ],
)

+py_library(
+    name = "ps_values",
+    srcs = ["ps_values.py"],
+    deps = [
+        ":distribute_lib",
+        ":values",
+        ":values_util",
+        "//tensorflow/python:framework_ops",
+        "//tensorflow/python:variable_scope",
+        "//tensorflow/python:variables",
+        "//tensorflow/python/training/tracking:base",
+        "//tensorflow/python/types",
+    ],
+)
+
+py_library(
+    name = "values_util",
+    srcs = ["values_util.py"],
+    deps = [
+        ":distribute_lib",
+        ":reduce_util",
+        "//tensorflow/python:framework_ops",
+        "//tensorflow/python:tensor_util",
+        "//tensorflow/python:variable_scope",
+    ],
+)
+
@@ -1037,23 +1075,57 @@ distribute_py_test(
    ],
    deps = [
        ":combinations",
        ":device_util",
        ":distribute_lib",
        ":mirrored_strategy",
        ":parameter_server_strategy",
        ":strategy_combinations",
        ":tpu_strategy",
        ":tpu_values",
        ":values",
        "//tensorflow/core:protos_all_py",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:constant_op",
        "//tensorflow/python:control_flow_ops",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:framework_test_lib",
        "//tensorflow/python:indexed_slices",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:random_ops",
        "//tensorflow/python:saver",
        "//tensorflow/python:sparse_ops",
        "//tensorflow/python:sparse_tensor",
        "//tensorflow/python:tensor_spec",
        "//tensorflow/python:tf2",
        "//tensorflow/python:training",
        "//tensorflow/python:util",
        "//tensorflow/python:variable_scope",
        "//tensorflow/python:variables",
        "//tensorflow/python/distribute/cluster_resolver:tpu_cluster_resolver_py",
        "//tensorflow/python/eager:context",
        "//tensorflow/python/eager:def_function",
        "//tensorflow/python/eager:test",
        "//tensorflow/python/saved_model/model_utils:mode_keys",
        "//tensorflow/python/tpu:tpu_lib",
        "//tensorflow/python/types",
        "@absl_py//absl/testing:parameterized",
    ],
)

+distribute_py_test(
+    name = "ps_values_test",
+    size = "medium",
+    srcs = ["ps_values_test.py"],
+    main = "ps_values_test.py",
+    tags = [
+        "multi_and_single_gpu",
+    ],
+    deps = [
+        ":combinations",
+        ":ps_values",
+        ":strategy_combinations",
+        "//tensorflow/python:variable_scope",
+        "//tensorflow/python:variables",
+        "//tensorflow/python/eager:def_function",
+        "//tensorflow/python/eager:test",
+        "@absl_py//absl/testing:parameterized",
+    ],
+)
@@ -1570,10 +1642,13 @@ cuda_py_test(
    deps = [
        ":central_storage_strategy",
        ":combinations",
        ":device_util",
        ":distribute_lib",
        ":multi_worker_test_base",
        ":multi_worker_util",
        ":parameter_server_strategy",
        ":strategy_combinations",
+       ":ps_values",
        ":reduce_util",
        ":strategy_test_lib",
        ":values",
        "//tensorflow/core:protos_all_py",
@@ -1581,16 +1656,22 @@ cuda_py_test(
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:constant_op",
        "//tensorflow/python:control_flow_ops",
        "//tensorflow/python:errors",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:gradients",
        "//tensorflow/python:session",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:partitioned_variables",
        "//tensorflow/python:resource_variable_ops",
        "//tensorflow/python:tensor_util",
        "//tensorflow/python:training",
        "//tensorflow/python:training_util",
        "//tensorflow/python:variable_scope",
        "//tensorflow/python:variables",
        "//tensorflow/python/data/ops:dataset_ops",
        "//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib",
        "//tensorflow/python/eager:backprop",
        "//tensorflow/python/eager:context",
        "//tensorflow/python/estimator:estimator_py",
        "//tensorflow/python/keras/layers",
        "//tensorflow/python/estimator:run_config",
        "//tensorflow/python/keras/layers:core",
        "@absl_py//absl/testing:parameterized",
    ],
)
tensorflow/python/distribute/cross_device_ops.py

@@ -19,15 +19,16 @@ from __future__ import division
from __future__ import print_function

import collections
+import enum
import threading

-import enum
import six

from tensorflow.python.client import device_lib
from tensorflow.python.distribute import collective_util
from tensorflow.python.distribute import cross_device_utils
from tensorflow.python.distribute import device_util
+from tensorflow.python.distribute import ps_values
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import tpu_values
from tensorflow.python.distribute import values as value_lib
@@ -64,7 +65,7 @@ def validate_destinations(destinations):
  """Validates the `destination` is one of expected types."""
  if not isinstance(
      destinations,
-      (value_lib.DistributedValues, ops.Tensor, value_lib.AggregatingVariable,
+      (value_lib.DistributedValues, ops.Tensor, ps_values.AggregatingVariable,
       six.string_types, tpu_values.TPUMirroredVariable
      )) and not resource_variable_ops.is_resource_variable(destinations):
    raise ValueError("destinations must be one of a `DistributedValues` object,"
tensorflow/python/distribute/parameter_server_strategy.py

@@ -28,6 +28,7 @@ from tensorflow.python.distribute import input_lib
from tensorflow.python.distribute import mirrored_run
from tensorflow.python.distribute import multi_worker_util
from tensorflow.python.distribute import numpy_dataset
+from tensorflow.python.distribute import ps_values
from tensorflow.python.distribute import values
from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
from tensorflow.python.distribute.cluster_resolver import TFConfigClusterResolver
@@ -441,8 +442,8 @@ class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1):

        # Create and wrap the variable.
        v = next_creator(**kwargs)
-       wrapped = values.AggregatingVariable(
-           self._container_strategy(), v, aggregation)
+       wrapped = ps_values.AggregatingVariable(self._container_strategy(), v,
+                                               aggregation)

        # Add the wrapped variable to the requested collections.
        # The handling of eager mode and the global step matches
@@ -539,7 +540,7 @@ class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1):
    return nest.map_structure(_select_fn, structured)

  def _update(self, var, fn, args, kwargs, group):
-   if isinstance(var, values.AggregatingVariable):
+   if isinstance(var, ps_values.AggregatingVariable):
      var = var.get()
    if not resource_variable_ops.is_resource_variable(var):
      raise ValueError(
@@ -569,7 +570,7 @@ class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1):

  def value_container(self, val):
    if (hasattr(val, "_aggregating_container") and
-       not isinstance(val, values.AggregatingVariable)):
+       not isinstance(val, ps_values.AggregatingVariable)):
      wrapper = val._aggregating_container()  # pylint: disable=protected-access
      if wrapper is not None:
        return wrapper
tensorflow/python/distribute/parameter_server_strategy_test.py

@@ -31,6 +31,7 @@ from tensorflow.python.distribute import distribution_strategy_context as ds_con
from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.distribute import multi_worker_util
from tensorflow.python.distribute import parameter_server_strategy
+from tensorflow.python.distribute import ps_values
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import strategy_test_lib
from tensorflow.python.distribute import values
@@ -796,8 +797,8 @@ class ParameterServerStrategyWithChiefTest(ParameterServerStrategyTestBase,
        msg=('created_step %s type %s vs. get_step %s type %s' %
             (id(created_step), created_step.__class__.__name__,
              id(get_step), get_step.__class__.__name__)))
-   self.assertIs(values.AggregatingVariable, type(created_step))
-   self.assertIs(values.AggregatingVariable, type(get_step))
+   self.assertIs(ps_values.AggregatingVariable, type(created_step))
+   self.assertIs(ps_values.AggregatingVariable, type(get_step))
    self.assertIs(strategy, created_step.distribute_strategy)

  @combinations.generate(combinations.combine(mode=['graph']))
@@ -828,7 +829,7 @@ class ParameterServerStrategyWithChiefTest(ParameterServerStrategyTestBase,
      _ = v * v
      v, = tape.watched_variables()
      w = strategy.extended.value_container(v)
-     self.assertIs(values.AggregatingVariable, type(w))
+     self.assertIs(ps_values.AggregatingVariable, type(w))

    strategy.extended.call_for_each_replica(f)
tensorflow/python/distribute/ps_values.py (new file, 304 lines)

@@ -0,0 +1,304 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Various classes representing distributed values for PS."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import weakref

from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.distribute import values
from tensorflow.python.distribute import values_util
from tensorflow.python.framework import ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.types import core


# Variable used in PSStrategy TF 1 and CentralStorageStrategy.
class AggregatingVariable(variables_lib.Variable, core.Tensor):
  """A wrapper around a variable that aggregates updates across replicas."""

  def __init__(self, strategy, v, aggregation):
    self._distribute_strategy = strategy
    self._v = v
    # NOTE: We don't use "_distributed_container" here because we don't want
    # to trigger that code path in regroup().
    v._aggregating_container = weakref.ref(self)  # pylint: disable=protected-access
    self._aggregation = aggregation

  def get(self):
    return self._v

  @property
  def distribute_strategy(self):
    return self._distribute_strategy

  def __getattr__(self, name):
    return getattr(self._v, name)

  def _assign_func(self, *args, **kwargs):
    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
      f = kwargs.pop("f")
      if ds_context.in_cross_replica_context():
        if distribute_lib.get_update_replica_id() is not None:
          # We are calling an assign function in an update context.
          return f(self._v, *args, **kwargs)

        # We are calling an assign function in cross replica context, wrap it in
        # an update call.
        return self._distribute_strategy.extended.update(
            self, f, args=args, kwargs=kwargs)
      else:
        replica_context = ds_context.get_replica_context()
        assert replica_context
        # We are calling an assign function in replica context.
        # We reduce the value we want to assign/add/sub. More details about how
        # we handle the different use cases can be found in the _reduce method.
        # We call the function with the reduced value.
        if self._aggregation == vs.VariableAggregation.NONE:
          raise ValueError(
              values_util.aggregation_error_msg.format(
                  variable_type="AggregatingVariable"))

        def merge_fn(strategy,
                     value,
                     use_locking=False,
                     name=None,
                     read_value=True):
          v = values_util.apply_aggregation(strategy, value, self._aggregation,
                                            self)
          if name and isinstance(name, values.PerReplica):
            name = name.values[0]
          return strategy.extended.update(
              self,
              f,
              args=(v,),
              kwargs={
                  "use_locking": use_locking,
                  "name": name,
                  "read_value": read_value
              })

        return replica_context.merge_call(merge_fn, args=args, kwargs=kwargs)

  def assign_sub(self, *args, **kwargs):
    assign_sub_fn = lambda var, *a, **kw: var.assign_sub(*a, **kw)
    return self._assign_func(f=assign_sub_fn, *args, **kwargs)

  def assign_add(self, *args, **kwargs):
    assign_add_fn = lambda var, *a, **kw: var.assign_add(*a, **kw)
    return self._assign_func(f=assign_add_fn, *args, **kwargs)

  def assign(self, *args, **kwargs):
    assign_fn = lambda var, *a, **kw: var.assign(*a, **kw)
    return self._assign_func(f=assign_fn, *args, **kwargs)

  @property
  def initializer(self):
    return self._v.initializer

  def initialized_value(self):
    return self._v.initialized_value()

  @property
  def initial_value(self):
    return self._v.initial_value

  @property
  def op(self):
    return self._v.op

  def read_value(self):
    return self._v.read_value()

  def eval(self, session=None):
    return self._v.eval(session)

  @property
  def graph(self):
    return self._v.graph

  @property
  def device(self):
    return self._v.device

  @property
  def shape(self):
    return self._v.shape

  @property
  def aggregation(self):
    return self._aggregation

  @property
  def synchronization(self):
    return self._v.synchronization

  @property
  def name(self):
    return self._v.name

  @property
  def trainable(self):
    return self._v.trainable

  @property
  def dtype(self):
    return self._v.dtype

  # TODO(josh11b): Test saving & restoring.
  def _gather_saveables_for_checkpoint(self):
    return {trackable.VARIABLE_VALUE_KEY: self._v}

  # pylint: disable=multiple-statements
  def __add__(self, o):
    return self._v + o

  def __radd__(self, o):
    return o + self._v

  def __sub__(self, o):
    return self._v - o

  def __rsub__(self, o):
    return o - self._v

  def __mul__(self, o):
    return self._v * o

  def __rmul__(self, o):
    return o * self._v

  def __truediv__(self, o):
    return self._v / o

  def __rtruediv__(self, o):
    return o / self._v

  def __floordiv__(self, o):
    return self._v // o

  def __rfloordiv__(self, o):
    return o // self._v

  def __mod__(self, o):
    return self._v % o

  def __rmod__(self, o):
    return o % self._v

  def __lt__(self, o):
    return self._v < o

  def __le__(self, o):
    return self._v <= o

  def __gt__(self, o):
    return self._v > o

  def __ge__(self, o):
    return self._v >= o

  def __and__(self, o):
    return self._v & o

  def __rand__(self, o):
    return o & self._v

  def __or__(self, o):
    return self._v | o

  def __ror__(self, o):
    return o | self._v

  def __xor__(self, o):
    return self._v ^ o

  def __rxor__(self, o):
    return o ^ self._v

  def __getitem__(self, o):
    return self._v[o]

  def __pow__(self, o, modulo=None):
    return pow(self._v, o, modulo)

  def __rpow__(self, o):
    return pow(o, self._v)

  def __invert__(self):
    return ~self._v

  def __neg__(self):
    return -self._v

  def __abs__(self):
    return abs(self._v)

  def __div__(self, o):
    try:
      return self._v.__div__(o)
    except AttributeError:
      # See https://docs.python.org/3/library/constants.html#NotImplemented
      return NotImplemented

  def __rdiv__(self, o):
    try:
      return self._v.__rdiv__(o)
    except AttributeError:
      # See https://docs.python.org/3/library/constants.html#NotImplemented
      return NotImplemented

  def __matmul__(self, o):
    try:
      return self._v.__matmul__(o)
    except AttributeError:
      # See https://docs.python.org/3/library/constants.html#NotImplemented
      return NotImplemented

  def __rmatmul__(self, o):
    try:
      return self._v.__rmatmul__(o)
    except AttributeError:
      # See https://docs.python.org/3/library/constants.html#NotImplemented
      return NotImplemented

  def __str__(self):
    return str(self._v)

  def __repr__(self):
    return repr(self._v)

  def _should_act_as_resource_variable(self):
    """Pass resource_variable_ops.is_resource_variable check."""
    pass

  def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
    return ops.convert_to_tensor(self.get(), dtype=dtype, name=name,
                                 as_ref=as_ref)


# Register a conversion function which reads the value of the variable,
# allowing instances of the class to be used as tensors.
def _tensor_conversion_aggregate(var, dtype=None, name=None, as_ref=False):
  return var._dense_var_to_tensor(dtype, name, as_ref)  # pylint: disable=protected-access


ops.register_tensor_conversion_function(AggregatingVariable,
                                        _tensor_conversion_aggregate)
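The new file keeps `AggregatingVariable` a thin proxy: reads and attribute access go straight to the wrapped variable, while `assign*` calls are routed through `_assign_func` so replica-context updates get aggregated via `merge_call`. Below is a minimal sketch of the observable behavior, assuming a TF 2.x runtime and the public `CentralStorageStrategy` API (the strategy that wraps its variables in this class); it is an illustration, not part of the commit.

    import tensorflow as tf

    # CentralStorageStrategy stores variables on a single device and wraps
    # them in AggregatingVariable so per-replica updates are aggregated.
    strategy = tf.distribute.experimental.CentralStorageStrategy()

    with strategy.scope():
      # MEAN aggregation: per-replica updates are averaged before being
      # applied to the centrally stored variable.
      v = tf.Variable(1.0, aggregation=tf.VariableAggregation.MEAN)

    @tf.function
    def step():
      # Runs in replica context, so assign_add goes through _assign_func,
      # which merge-calls into cross-replica context and aggregates first.
      return v.assign_add(2.0)

    strategy.run(step)
    print(v.read_value())  # 3.0 regardless of replica count (mean of the +2s).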
tensorflow/python/distribute/ps_values_test.py (new file, 65 lines)

@@ -0,0 +1,65 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the distributed values library."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl.testing import parameterized

from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import ps_values
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.eager import def_function
from tensorflow.python.eager import test
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables as variables_lib


@combinations.generate(
    combinations.combine(
        distribution=[
            strategy_combinations.central_storage_strategy_with_two_gpus
        ],
        mode=["graph", "eager"]))
class AggregatingVariableTest(test.TestCase, parameterized.TestCase):

  def testAssignOutOfScope(self, distribution):
    with distribution.scope():
      aggregating = variables_lib.Variable(1.)
    self.assertIsInstance(aggregating, ps_values.AggregatingVariable)
    self.evaluate(aggregating.assign(3.))
    self.assertEqual(self.evaluate(aggregating.read_value()), 3.)
    self.assertEqual(self.evaluate(aggregating._v.read_value()), 3.)

  def testAssignAdd(self, distribution):
    with distribution.scope():
      v = variable_scope.variable(
          1, aggregation=variables_lib.VariableAggregation.MEAN)
    self.evaluate(variables_lib.global_variables_initializer())

    @def_function.function
    def assign():
      return v.assign_add(2)

    per_replica_results = self.evaluate(
        distribution.experimental_local_results(
            distribution.experimental_run_v2(assign)))
    self.assertAllEqual([3], per_replica_results)


if __name__ == "__main__":
  test.main()
tensorflow/python/distribute/values.py

@@ -18,12 +18,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

-import weakref

from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.distribute import reduce_util
+from tensorflow.python.distribute import values_util
from tensorflow.python.eager import context
from tensorflow.python.eager import tape
from tensorflow.python.framework import composite_tensor
@@ -43,72 +43,6 @@ from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export


-# Utility functions used by the different classes below.
-def _get_current_replica_id_as_int():
-  """Returns the current replica ID as an integer, or `None`."""
-  replica_context = ds_context.get_replica_context()
-  if replica_context:
-    replica_id = replica_context.replica_id_in_sync_group
-    if not isinstance(replica_id, int):
-      replica_id = tensor_util.constant_value(replica_id)
-  else:
-    replica_id = distribute_lib.get_update_replica_id()
-  return replica_id
-
-
-def _assign_on_device(device, variable, tensor):
-  with ops.device(device):
-    return variable.assign(tensor)
-
-
-def _assign_add_on_device(device, variable, tensor):
-  with ops.device(device):
-    return variable.assign_add(tensor)
-
-
-def _assign_sub_on_device(device, variable, tensor):
-  with ops.device(device):
-    return variable.assign_sub(tensor)
-
-
-def _assert_replica_context(strategy):
-  replica_context = ds_context.get_replica_context()
-  if not replica_context:
-    raise RuntimeError(
-        "Replica-local variables may only be assigned in a replica context.")
-  if replica_context.strategy is not strategy:
-    raise RuntimeError(
-        "Replica-local variables may only be assigned in a replica context.")
-
-
-def _apply_aggregation(strategy, value, aggregation, destinations):
-  if aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA:
-    return strategy.extended.broadcast_to(
-        strategy.experimental_local_results(value)[0],
-        destinations=destinations)
-  reduce_op = reduce_util.ReduceOp.from_variable_aggregation(aggregation)
-  return strategy.extended.reduce_to(reduce_op, value, destinations)
-
-
-_aggregation_error_msg = (
-    "You must specify an aggregation method to update a "
-    "{variable_type} in Replica Context. You can do so by passing "
-    "an explicit value for argument `aggregation` to tf.Variable(..)."
-    "e.g. `tf.Variable(..., aggregation=tf.VariableAggregation.SUM)`"
-    "`tf.VariableAggregation` lists the possible aggregation methods."
-    "This is required because {variable_type} should always be "
-    "kept in sync. When updating them or assigning to them in a "
-    "replica context, we automatically try to aggregate the values "
-    "before updating the variable. For this aggregation, we need to "
-    "know the aggregation method. "
-    "Another alternative is to not try to update such "
-    "{variable_type} in replica context, but in cross replica "
-    "context. You can enter cross replica context by calling "
-    "`tf.distribute.get_replica_context().merge_call(merge_fn, ..)`."
-    "Inside `merge_fn`, you can then update the {variable_type} "
-    "using `tf.distribute.StrategyExtended.update()`.")
-
-
@tf_export("distribute.DistributedValues", v1=[])
class DistributedValues(object):
  """Base class for representing distributed values.
@@ -182,7 +116,7 @@ class DistributedValues(object):

  def _get(self):
    """Returns the value for the current device or raises a ValueError."""
-   replica_id = _get_current_replica_id_as_int()
+   replica_id = values_util.get_current_replica_id_as_int()
    if replica_id is None:
      return self._get_cross_replica()
    else:
@@ -195,7 +129,7 @@ class DistributedValues(object):

  def _get_on_device_or_primary(self):
    """Returns value in same replica or device if possible, else the _primary."""
-   replica_id = _get_current_replica_id_as_int()
+   replica_id = values_util.get_current_replica_id_as_int()
    if replica_id is None:
      # Try to find a value on the current device.
      current_device = device_util.canonicalize(device_util.current())
@@ -568,7 +502,7 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable,

  @property
  def handle(self):
-   replica_id = _get_current_replica_id_as_int()
+   replica_id = values_util.get_current_replica_id_as_int()
    if replica_id is None:
      raise ValueError("`handle` is not available outside the replica context"
                       " or a `tf.distribute.Strategy.update()` call.")
@@ -774,7 +708,7 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable,
          return update_fn(self._values[update_replica_id], value, **kwargs)
        return self._update_cross_replica(update_fn, value, **kwargs)
      else:
-       _assert_replica_context(self.distribute_strategy)
+       values_util.assert_replica_context(self.distribute_strategy)
        return self._update_replica(update_fn, value, **kwargs)

  def _should_act_as_resource_variable(self):
@@ -794,7 +728,7 @@ class _MirroredSaveable(saveable_object_util.ResourceVariableSaveable):
    tensor, = restored_tensors
    return control_flow_ops.group(
        tuple(
-           _assign_on_device(v.device, v, tensor)
+           values_util.assign_on_device(v.device, v, tensor)
            for v in self._mirrored_variable.values))


@@ -804,7 +738,8 @@ class MirroredVariable(DistributedVariable, Mirrored):

  def _update_replica(self, update_fn, value, **kwargs):
    if self.aggregation == vs.VariableAggregation.NONE:
      raise ValueError(
-         _aggregation_error_msg.format(variable_type="MirroredVariable"))
+         values_util.aggregation_error_msg.format(
+             variable_type="MirroredVariable"))

    def merge_fn(strategy, value, **kwargs):
      """Aggregate values and update all variables in cross replica context."""
@@ -824,7 +759,7 @@ class MirroredVariable(DistributedVariable, Mirrored):
            "cross-replica context.")

      assert strategy == self.distribute_strategy
-     v = _apply_aggregation(strategy, value, self.aggregation, self)
+     v = values_util.apply_aggregation(strategy, value, self.aggregation, self)
      return self._update_cross_replica(update_fn, v, **kwargs)

    return ds_context.get_replica_context().merge_call(
@@ -930,7 +865,7 @@ class _SyncOnReadSaveable(saveable_object.SaveableObject):
                         self._sync_on_read_variable.dtype)
    return control_flow_ops.group(
        tuple(
-           _assign_on_device(v.device, v, tensor)
+           values_util.assign_on_device(v.device, v, tensor)
            for v in self._sync_on_read_variable.values))


@@ -960,8 +895,8 @@ class SyncOnReadVariable(DistributedVariable):
            "SyncOnReadVariable does not support `assign_sub` in "
            "cross-replica context when aggregation is set to "
            "`tf.VariableAggregation.SUM`.")
-     return self._assign_on_each_device(_assign_sub_on_device, value,
-                                        read_value)
+     return self._assign_on_each_device(values_util.assign_sub_on_device,
+                                        value, read_value)
    else:
      return super(SyncOnReadVariable,
                   self).assign_sub(value, use_locking, name, read_value)
@@ -974,8 +909,8 @@ class SyncOnReadVariable(DistributedVariable):
            "SyncOnReadVariable does not support `assign_add` in "
            "cross-replica context when aggregation is set to "
            "`tf.VariableAggregation.SUM`.")
-     return self._assign_on_each_device(_assign_add_on_device, value,
-                                        read_value)
+     return self._assign_on_each_device(values_util.assign_add_on_device,
+                                        value, read_value)
    else:
      return super(SyncOnReadVariable,
                   self).assign_add(value, use_locking, name, read_value)
@@ -988,7 +923,7 @@ class SyncOnReadVariable(DistributedVariable):
      # when saving.
      if self._aggregation == vs.VariableAggregation.SUM:
        value = math_ops.cast(value / len(self._values), self.dtype)
-     return self._assign_on_each_device(_assign_on_device, value,
+     return self._assign_on_each_device(values_util.assign_on_device, value,
                                         read_value)
    else:
      return super(SyncOnReadVariable,
@@ -1388,275 +1323,3 @@ def validate_colocate(v, extended):
        "`colocate_vars_with` must only be passed a variable created in this "
        "tf.distribute.Strategy.scope(), not: %r" % (v,))
  _validate_colocate_extended(v, extended)
-
-
-# Variable used in PSStrategy TF 1 and CentralStorageStrategy.
-class AggregatingVariable(variables_lib.Variable, core.Tensor):
-  """A wrapper around a variable that aggregates updates across replicas."""
-
-  def __init__(self, strategy, v, aggregation):
-    self._distribute_strategy = strategy
-    self._v = v
-    # NOTE: We don't use "_distributed_container" here because we don't want
-    # to trigger that code path in regroup().
-    v._aggregating_container = weakref.ref(self)  # pylint: disable=protected-access
-    self._aggregation = aggregation
-
-  def get(self):
-    return self._v
-
-  @property
-  def distribute_strategy(self):
-    return self._distribute_strategy
-
-  def __getattr__(self, name):
-    return getattr(self._v, name)
-
-  def _assign_func(self, *args, **kwargs):
-    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
-      f = kwargs.pop("f")
-      if ds_context.in_cross_replica_context():
-        if distribute_lib.get_update_replica_id() is not None:
-          # We are calling an assign function in an update context.
-          return f(self._v, *args, **kwargs)
-
-        # We are calling an assign function in cross replica context, wrap it in
-        # an update call.
-        return self._distribute_strategy.extended.update(
-            self, f, args=args, kwargs=kwargs)
-      else:
-        replica_context = ds_context.get_replica_context()
-        assert replica_context
-        # We are calling an assign function in replica context.
-        # We reduce the value we want to assign/add/sub. More details about how
-        # we handle the different use cases can be found in the _reduce method.
-        # We call the function with the reduced value.
-        if self._aggregation == vs.VariableAggregation.NONE:
-          raise ValueError(
-              _aggregation_error_msg.format(
-                  variable_type="AggregatingVariable"))
-
-        def merge_fn(strategy,
-                     value,
-                     use_locking=False,
-                     name=None,
-                     read_value=True):
-          v = _apply_aggregation(strategy, value, self._aggregation, self)
-          if name and isinstance(name, PerReplica):
-            name = name.values[0]
-          return strategy.extended.update(
-              self,
-              f,
-              args=(v,),
-              kwargs={
-                  "use_locking": use_locking,
-                  "name": name,
-                  "read_value": read_value
-              })
-
-        return replica_context.merge_call(merge_fn, args=args, kwargs=kwargs)
-
-  def assign_sub(self, *args, **kwargs):
-    assign_sub_fn = lambda var, *a, **kw: var.assign_sub(*a, **kw)
-    return self._assign_func(f=assign_sub_fn, *args, **kwargs)
-
-  def assign_add(self, *args, **kwargs):
-    assign_add_fn = lambda var, *a, **kw: var.assign_add(*a, **kw)
-    return self._assign_func(f=assign_add_fn, *args, **kwargs)
-
-  def assign(self, *args, **kwargs):
-    assign_fn = lambda var, *a, **kw: var.assign(*a, **kw)
-    return self._assign_func(f=assign_fn, *args, **kwargs)
-
-  @property
-  def initializer(self):
-    return self._v.initializer
-
-  def initialized_value(self):
-    return self._v.initialized_value()
-
-  @property
-  def initial_value(self):
-    return self._v.initial_value
-
-  @property
-  def op(self):
-    return self._v.op
-
-  def read_value(self):
-    return self._v.read_value()
-
-  def eval(self, session=None):
-    return self._v.eval(session)
-
-  @property
-  def graph(self):
-    return self._v.graph
-
-  @property
-  def device(self):
-    return self._v.device
-
-  @property
-  def shape(self):
-    return self._v.shape
-
-  @property
-  def aggregation(self):
-    return self._aggregation
-
-  @property
-  def synchronization(self):
-    return self._v.synchronization
-
-  @property
-  def name(self):
-    return self._v.name
-
-  @property
-  def trainable(self):
-    return self._v.trainable
-
-  @property
-  def dtype(self):
-    return self._v.dtype
-
-  # TODO(josh11b): Test saving & restoring.
-  def _gather_saveables_for_checkpoint(self):
-    return {trackable.VARIABLE_VALUE_KEY: self._v}
-
-  # pylint: disable=multiple-statements
-  def __add__(self, o):
-    return self._v + o
-
-  def __radd__(self, o):
-    return o + self._v
-
-  def __sub__(self, o):
-    return self._v - o
-
-  def __rsub__(self, o):
-    return o - self._v
-
-  def __mul__(self, o):
-    return self._v * o
-
-  def __rmul__(self, o):
-    return o * self._v
-
-  def __truediv__(self, o):
-    return self._v / o
-
-  def __rtruediv__(self, o):
-    return o / self._v
-
-  def __floordiv__(self, o):
-    return self._v // o
-
-  def __rfloordiv__(self, o):
-    return o // self._v
-
-  def __mod__(self, o):
-    return self._v % o
-
-  def __rmod__(self, o):
-    return o % self._v
-
-  def __lt__(self, o):
-    return self._v < o
-
-  def __le__(self, o):
-    return self._v <= o
-
-  def __gt__(self, o):
-    return self._v > o
-
-  def __ge__(self, o):
-    return self._v >= o
-
-  def __and__(self, o):
-    return self._v & o
-
-  def __rand__(self, o):
-    return o & self._v
-
-  def __or__(self, o):
-    return self._v | o
-
-  def __ror__(self, o):
-    return o | self._v
-
-  def __xor__(self, o):
-    return self._v ^ o
-
-  def __rxor__(self, o):
-    return o ^ self._v
-
-  def __getitem__(self, o):
-    return self._v[o]
-
-  def __pow__(self, o, modulo=None):
-    return pow(self._v, o, modulo)
-
-  def __rpow__(self, o):
-    return pow(o, self._v)
-
-  def __invert__(self):
-    return ~self._v
-
-  def __neg__(self):
-    return -self._v
-
-  def __abs__(self):
-    return abs(self._v)
-
-  def __div__(self, o):
-    try:
-      return self._v.__div__(o)
-    except AttributeError:
-      # See https://docs.python.org/3/library/constants.html#NotImplemented
-      return NotImplemented
-
-  def __rdiv__(self, o):
-    try:
-      return self._v.__rdiv__(o)
-    except AttributeError:
-      # See https://docs.python.org/3/library/constants.html#NotImplemented
-      return NotImplemented
-
-  def __matmul__(self, o):
-    try:
-      return self._v.__matmul__(o)
-    except AttributeError:
-      # See https://docs.python.org/3/library/constants.html#NotImplemented
-      return NotImplemented
-
-  def __rmatmul__(self, o):
-    try:
-      return self._v.__rmatmul__(o)
-    except AttributeError:
-      # See https://docs.python.org/3/library/constants.html#NotImplemented
-      return NotImplemented
-
-  def __str__(self):
-    return str(self._v)
-
-  def __repr__(self):
-    return repr(self._v)
-
-  def _should_act_as_resource_variable(self):
-    """Pass resource_variable_ops.is_resource_variable check."""
-    pass
-
-  def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
-    return ops.convert_to_tensor(self.get(), dtype=dtype, name=name,
-                                 as_ref=as_ref)
-
-
-# Register a conversion function which reads the value of the variable,
-# allowing instances of the class to be used as tensors.
-def _tensor_conversion_aggregate(var, dtype=None, name=None, as_ref=False):
-  return var._dense_var_to_tensor(dtype, name, as_ref)  # pylint: disable=protected-access
-
-
-ops.register_tensor_conversion_function(AggregatingVariable,
-                                        _tensor_conversion_aggregate)
tensorflow/python/distribute/values_test.py

@@ -2008,38 +2008,6 @@ class SyncOnReadScatterReplicaTest(test.TestCase, parameterized.TestCase):
      self.evaluate(distribution.run(v.scatter_min, args=(delta,)))


-@combinations.generate(
-    combinations.combine(
-        distribution=[
-            strategy_combinations.central_storage_strategy_with_two_gpus
-        ],
-        mode=["graph", "eager"]))
-class AggregatingVariableTest(test.TestCase, parameterized.TestCase):
-
-  def testAssignOutOfScope(self, distribution):
-    with distribution.scope():
-      aggregating = variables_lib.Variable(1.)
-    self.assertIsInstance(aggregating, values.AggregatingVariable)
-    self.evaluate(aggregating.assign(3.))
-    self.assertEqual(self.evaluate(aggregating.read_value()), 3.)
-    self.assertEqual(self.evaluate(aggregating._v.read_value()), 3.)
-
-  def testAssignAdd(self, distribution):
-    with distribution.scope():
-      v = variable_scope.variable(
-          1, aggregation=variables_lib.VariableAggregation.MEAN)
-    self.evaluate(variables_lib.global_variables_initializer())
-
-    @def_function.function
-    def assign():
-      return v.assign_add(2)
-
-    per_replica_results = self.evaluate(
-        distribution.experimental_local_results(
-            distribution.experimental_run_v2(assign)))
-    self.assertAllEqual([3], per_replica_results)
-
-
class MirroredTest(test.TestCase):

  def testAddOp(self):
tensorflow/python/distribute/values_util.py (new file, 91 lines)

@@ -0,0 +1,91 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utility functions used by values.py and ps_values.py."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.distribute import reduce_util
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import variable_scope as vs


def get_current_replica_id_as_int():
  """Returns the current replica ID as an integer, or `None`."""
  replica_context = ds_context.get_replica_context()
  if replica_context:
    replica_id = replica_context.replica_id_in_sync_group
    if not isinstance(replica_id, int):
      replica_id = tensor_util.constant_value(replica_id)
  else:
    replica_id = distribute_lib.get_update_replica_id()
  return replica_id


def assign_on_device(device, variable, tensor):
  with ops.device(device):
    return variable.assign(tensor)


def assign_add_on_device(device, variable, tensor):
  with ops.device(device):
    return variable.assign_add(tensor)


def assign_sub_on_device(device, variable, tensor):
  with ops.device(device):
    return variable.assign_sub(tensor)


def assert_replica_context(strategy):
  replica_context = ds_context.get_replica_context()
  if not replica_context:
    raise RuntimeError(
        "Replica-local variables may only be assigned in a replica context.")
  if replica_context.strategy is not strategy:
    raise RuntimeError(
        "Replica-local variables may only be assigned in a replica context.")


def apply_aggregation(strategy, value, aggregation, destinations):
  if aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA:
    return strategy.extended.broadcast_to(
        strategy.experimental_local_results(value)[0],
        destinations=destinations)
  reduce_op = reduce_util.ReduceOp.from_variable_aggregation(aggregation)
  return strategy.extended.reduce_to(reduce_op, value, destinations)


aggregation_error_msg = (
    "You must specify an aggregation method to update a "
    "{variable_type} in Replica Context. You can do so by passing "
    "an explicit value for argument `aggregation` to tf.Variable(..)."
    "e.g. `tf.Variable(..., aggregation=tf.VariableAggregation.SUM)`"
    "`tf.VariableAggregation` lists the possible aggregation methods."
    "This is required because {variable_type} should always be "
    "kept in sync. When updating them or assigning to them in a "
    "replica context, we automatically try to aggregate the values "
    "before updating the variable. For this aggregation, we need to "
    "know the aggregation method. "
    "Another alternative is to not try to update such "
    "{variable_type} in replica context, but in cross replica "
    "context. You can enter cross replica context by calling "
    "`tf.distribute.get_replica_context().merge_call(merge_fn, ..)`."
    "Inside `merge_fn`, you can then update the {variable_type} "
    "using `tf.distribute.StrategyExtended.update()`.")
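`apply_aggregation` is what makes replica-context assignments well defined: it broadcasts the first replica's value for `ONLY_FIRST_REPLICA`, and otherwise reduces across replicas using the `ReduceOp` derived from the variable's aggregation setting. A rough user-level sketch of that contract follows, assuming a TF 2.x `MirroredStrategy`; the helper itself is internal, so this goes through the public `tf.Variable` aggregation argument and is not part of the commit.

    import tensorflow as tf

    strategy = tf.distribute.MirroredStrategy()

    with strategy.scope():
      # Without an aggregation argument, a replica-context assignment raises
      # a ValueError built from aggregation_error_msg above; SUM makes it
      # well defined.
      total = tf.Variable(0.0, aggregation=tf.VariableAggregation.SUM)

    def replica_fn():
      # Each replica contributes 1.0; the values are reduced with SUM
      # (via ReduceOp.from_variable_aggregation) before the update applies.
      return total.assign_add(1.0)

    strategy.run(replica_fn)
    print(total.read_value())  # N * 1.0 for N replicas.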
tensorflow/python/keras/mixed_precision/experimental/BUILD

@@ -125,11 +125,14 @@ py_library(
    ],
    srcs_version = "PY2AND3",
    deps = [
        "//tensorflow/python:framework",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:resource_variable_ops",
        "//tensorflow/python:variables",
+       "//tensorflow/python/distribute:ps_values",
        "//tensorflow/python/distribute:values",
        "//tensorflow/python/eager:context",
        "//tensorflow/python/types",
    ],
)
tensorflow/python/keras/mixed_precision/experimental/autocast_variable.py

@@ -17,6 +17,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

+from tensorflow.python.distribute import ps_values as ps_distribute_values
from tensorflow.python.distribute import values as distribute_values
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
@@ -437,7 +438,7 @@ def create_autocast_variable(variable):
    An AutoCastVariable that wraps the variable.
  """
  if not isinstance(variable, (distribute_values.DistributedVariable,
-                              distribute_values.AggregatingVariable)):
+                              ps_distribute_values.AggregatingVariable)):
    return AutoCastVariable(variable)

  class AutoCastDistributedVariable(AutoCastVariable, variable.__class__):
@@ -448,7 +449,8 @@ def create_autocast_variable(variable):
    """

    def __repr__(self):
-     if issubclass(distribute_values.AggregatingVariable, variable.__class__):
+     if issubclass(ps_distribute_values.AggregatingVariable,
+                   variable.__class__):
        # AggregatingVariable's __repr__ simply calls super.__repr__. So we do
        # the same here for consistency, which calls AutoCastVariable.__repr__.
        return super(AutoCastDistributedVariable, self).__repr__()
tensorflow/python/module/module_test.py

@@ -26,6 +26,7 @@ from absl.testing import parameterized
import six

from tensorflow.python import tf2
+from tensorflow.python.distribute import ps_values
from tensorflow.python.distribute import tpu_values
from tensorflow.python.distribute import values as distributed_values
from tensorflow.python.eager import context
@@ -250,7 +251,7 @@ class VariableTrackingTest(test_util.TensorFlowTestCase):
        None, [variables.Variable(1.)], variables.VariableAggregation.SUM)
    tpu = tpu_values.TPUMirroredVariable(
        strategy=None, values=[variables.Variable(42.)], aggregation=None)
-   aggregating = distributed_values.AggregatingVariable(
+   aggregating = ps_values.AggregatingVariable(
        strategy=None, v=variables.Variable(1.), aggregation=None)

    m = module.Module()
@@ -514,8 +515,8 @@ class FlattenTest(parameterized.TestCase, test_util.TensorFlowTestCase):

    m = module.Module()
    m.layers = {non_orderable(): None, non_orderable(): None}
-   with self.assertRaisesRegexp(ValueError,
-                                "Error processing property 'layers'"):
+   with self.assertRaisesRegex(ValueError,
+                               "Error processing property 'layers'"):
      m.variables  # pylint: disable=pointless-statement