A round of moving some DistributionStrategy libraries from contrib to
core (just implementations, not public APIs): * cross_tower_ops -> cross_device_ops * cross_tower_utils -> cross_device_utils * input_ops (and test) * shared_variable_creator (and test) * values Also: * BUILD clean up * renaming cross_tower -> cross_device in a number of places. PiperOrigin-RevId: 221828945
This commit is contained in:
parent
74660a6db3
commit
36035b9230
@ -26,7 +26,6 @@ py_library(
|
||||
visibility = ["//tensorflow:internal"],
|
||||
deps = [
|
||||
"//tensorflow/contrib/distribute/python:collective_all_reduce_strategy",
|
||||
"//tensorflow/contrib/distribute/python:cross_tower_ops",
|
||||
"//tensorflow/contrib/distribute/python:mirrored_strategy",
|
||||
"//tensorflow/contrib/distribute/python:monitor",
|
||||
"//tensorflow/contrib/distribute/python:one_device_strategy",
|
||||
@ -35,6 +34,7 @@ py_library(
|
||||
"//tensorflow/contrib/distribute/python:tpu_strategy",
|
||||
"//tensorflow/python:training",
|
||||
"//tensorflow/python:util",
|
||||
"//tensorflow/python/distribute:cross_device_ops",
|
||||
"//tensorflow/python/distribute:distribute_config",
|
||||
"//tensorflow/python/distribute:distribute_coordinator",
|
||||
],
|
||||
|
@ -25,13 +25,13 @@ from __future__ import print_function
|
||||
|
||||
# pylint: disable=unused-import,wildcard-import
|
||||
from tensorflow.contrib.distribute.python.collective_all_reduce_strategy import CollectiveAllReduceStrategy
|
||||
from tensorflow.contrib.distribute.python.cross_tower_ops import *
|
||||
from tensorflow.contrib.distribute.python.mirrored_strategy import MirroredStrategy
|
||||
from tensorflow.contrib.distribute.python.monitor import Monitor
|
||||
from tensorflow.contrib.distribute.python.one_device_strategy import OneDeviceStrategy
|
||||
from tensorflow.contrib.distribute.python.parameter_server_strategy import ParameterServerStrategy
|
||||
from tensorflow.contrib.distribute.python.step_fn import *
|
||||
from tensorflow.contrib.distribute.python.tpu_strategy import TPUStrategy
|
||||
from tensorflow.python.distribute.cross_device_ops import *
|
||||
from tensorflow.python.distribute.distribute_config import DistributeConfig
|
||||
from tensorflow.python.distribute.distribute_coordinator import run_standard_tensorflow_server
|
||||
from tensorflow.python.training.distribute import *
|
||||
|
@ -16,45 +16,24 @@ load("//tensorflow:tensorflow.bzl", "cuda_py_test")
|
||||
# TODO(priyag): Figure out testonly issues that are preventing us from
|
||||
# including our tests in pip for now.
|
||||
|
||||
py_library(
|
||||
name = "values",
|
||||
srcs = ["values.py"],
|
||||
visibility = ["//tensorflow:internal"],
|
||||
deps = [
|
||||
":input_ops",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:control_flow_ops",
|
||||
"//tensorflow/python:device_util",
|
||||
"//tensorflow/python:distribute",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:resource_variable_ops",
|
||||
"//tensorflow/python:training",
|
||||
"//tensorflow/python:util",
|
||||
"//tensorflow/python/data/ops:multi_device_iterator_ops",
|
||||
"//tensorflow/python/eager:context",
|
||||
"//tensorflow/python/training/checkpointable:base",
|
||||
"@six_archive//:six",
|
||||
],
|
||||
)
|
||||
|
||||
cuda_py_test(
|
||||
name = "values_test",
|
||||
srcs = ["values_test.py"],
|
||||
additional_deps = [
|
||||
":mirrored_strategy",
|
||||
":multi_worker_test_base",
|
||||
":values",
|
||||
"//tensorflow/core:protos_all_py",
|
||||
"//tensorflow/python/data/ops:dataset_ops",
|
||||
"//tensorflow/python:errors",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:constant_op",
|
||||
"//tensorflow/python:device_util",
|
||||
"//tensorflow/python:errors",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:training",
|
||||
"//tensorflow/python:variable_scope",
|
||||
"//tensorflow/python/data/ops:dataset_ops",
|
||||
"//tensorflow/python/distribute:values",
|
||||
"//tensorflow/python/eager:context",
|
||||
"//tensorflow/python:device_util",
|
||||
"//tensorflow/python/eager:test",
|
||||
"//tensorflow/python/estimator:estimator_py",
|
||||
],
|
||||
@ -68,9 +47,6 @@ py_library(
|
||||
srcs = ["mirrored_strategy.py"],
|
||||
visibility = ["//tensorflow:internal"],
|
||||
deps = [
|
||||
":cross_tower_ops",
|
||||
":shared_variable_creator",
|
||||
":values",
|
||||
"//tensorflow/core:protos_all_py",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:constant_op",
|
||||
@ -86,8 +62,11 @@ py_library(
|
||||
"//tensorflow/python:util",
|
||||
"//tensorflow/python:variable_scope",
|
||||
"//tensorflow/python:variables",
|
||||
"//tensorflow/python/distribute:cross_device_ops",
|
||||
"//tensorflow/python/distribute:multi_worker_util",
|
||||
"//tensorflow/python/distribute:reduce_util",
|
||||
"//tensorflow/python/distribute:shared_variable_creator",
|
||||
"//tensorflow/python/distribute:values",
|
||||
"//tensorflow/python/eager:context",
|
||||
"//tensorflow/python/eager:tape",
|
||||
],
|
||||
@ -98,17 +77,17 @@ py_library(
|
||||
srcs = ["parameter_server_strategy.py"],
|
||||
visibility = ["//tensorflow:internal"],
|
||||
deps = [
|
||||
":cross_tower_ops",
|
||||
":mirrored_strategy",
|
||||
":values",
|
||||
"//tensorflow/core:protos_all_py",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:resource_variable_ops",
|
||||
"//tensorflow/python:training",
|
||||
"//tensorflow/python:util",
|
||||
"//tensorflow/python/distribute:cross_device_ops",
|
||||
"//tensorflow/python/distribute:multi_worker_util",
|
||||
"//tensorflow/python/distribute:reduce_util",
|
||||
"//tensorflow/python/distribute:values",
|
||||
"//tensorflow/python/eager:context",
|
||||
],
|
||||
)
|
||||
@ -121,7 +100,6 @@ cuda_py_test(
|
||||
":multi_worker_test_base",
|
||||
":parameter_server_strategy",
|
||||
":strategy_test_lib",
|
||||
":values",
|
||||
"@absl_py//absl/testing:parameterized",
|
||||
"//tensorflow/core:protos_all_py",
|
||||
"//tensorflow/python:array_ops",
|
||||
@ -137,6 +115,7 @@ cuda_py_test(
|
||||
"//tensorflow/python:variable_scope",
|
||||
"//tensorflow/python:variables",
|
||||
"//tensorflow/python/distribute:multi_worker_util",
|
||||
"//tensorflow/python/distribute:values",
|
||||
"//tensorflow/python/eager:context",
|
||||
"//tensorflow/python/estimator:estimator_py",
|
||||
],
|
||||
@ -151,13 +130,13 @@ py_library(
|
||||
srcs = ["one_device_strategy.py"],
|
||||
visibility = ["//tensorflow:internal"],
|
||||
deps = [
|
||||
":values",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:distribute",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python/distribute:reduce_util",
|
||||
"//tensorflow/python/distribute:values",
|
||||
"//tensorflow/python/eager:context",
|
||||
"@six_archive//:six",
|
||||
],
|
||||
@ -168,16 +147,16 @@ py_library(
|
||||
srcs = ["collective_all_reduce_strategy.py"],
|
||||
visibility = ["//tensorflow:internal"],
|
||||
deps = [
|
||||
":cross_tower_ops",
|
||||
":cross_tower_utils",
|
||||
":mirrored_strategy",
|
||||
":values",
|
||||
"//tensorflow/core:protos_all_py",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:collective_ops",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:training",
|
||||
"//tensorflow/python/distribute:cross_device_ops",
|
||||
"//tensorflow/python/distribute:cross_device_utils",
|
||||
"//tensorflow/python/distribute:multi_worker_util",
|
||||
"//tensorflow/python/distribute:values",
|
||||
"//tensorflow/python/eager:context",
|
||||
],
|
||||
)
|
||||
@ -283,16 +262,16 @@ cuda_py_test(
|
||||
additional_deps = [
|
||||
":mirrored_strategy",
|
||||
":multi_worker_test_base",
|
||||
":values",
|
||||
":strategy_test_lib",
|
||||
"//tensorflow/python:distribute",
|
||||
"//tensorflow/core:protos_all_py",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:constant_op",
|
||||
"//tensorflow/python:distribute",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:layers",
|
||||
"//tensorflow/python:state_ops",
|
||||
"//tensorflow/python:variable_scope",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python/distribute:values",
|
||||
"//tensorflow/python/eager:context",
|
||||
"//tensorflow/python/eager:test",
|
||||
],
|
||||
@ -344,7 +323,6 @@ py_library(
|
||||
visibility = ["//tensorflow:internal"],
|
||||
deps = [
|
||||
":one_device_strategy",
|
||||
":values",
|
||||
"//tensorflow/contrib/tpu:tpu_lib",
|
||||
"//tensorflow/python:constant_op",
|
||||
"//tensorflow/python:control_flow_ops",
|
||||
@ -353,6 +331,7 @@ py_library(
|
||||
"//tensorflow/python:tensor_util",
|
||||
"//tensorflow/python:util",
|
||||
"//tensorflow/python/distribute:reduce_util",
|
||||
"//tensorflow/python/distribute:values",
|
||||
],
|
||||
)
|
||||
|
||||
@ -362,7 +341,6 @@ cuda_py_test(
|
||||
additional_deps = [
|
||||
":collective_all_reduce_strategy",
|
||||
":combinations",
|
||||
":cross_tower_utils",
|
||||
":multi_worker_test_base",
|
||||
":strategy_test_lib",
|
||||
"@absl_py//absl/testing:parameterized",
|
||||
@ -378,6 +356,7 @@ cuda_py_test(
|
||||
"//tensorflow/python:layers",
|
||||
"//tensorflow/python:variable_scope",
|
||||
"//tensorflow/python:variables",
|
||||
"//tensorflow/python/distribute:cross_device_utils",
|
||||
"//tensorflow/python/eager:context",
|
||||
"//tensorflow/python/estimator:estimator_py",
|
||||
],
|
||||
@ -508,7 +487,9 @@ cuda_py_test(
|
||||
"//third_party/py/numpy",
|
||||
"//tensorflow/contrib/optimizer_v2:training",
|
||||
"//tensorflow/python/data/ops:dataset_ops",
|
||||
"//tensorflow/python/distribute",
|
||||
"//tensorflow/python/distribute:distribute_config",
|
||||
"//tensorflow/python/distribute:distribute_coordinator",
|
||||
"//tensorflow/python/distribute:distribute_coordinator_context",
|
||||
"//tensorflow/python/eager:test",
|
||||
"//tensorflow/python/estimator:estimator_py",
|
||||
"//tensorflow/python/feature_column",
|
||||
@ -600,52 +581,16 @@ cuda_py_test(
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "shared_variable_creator",
|
||||
srcs = ["shared_variable_creator.py"],
|
||||
visibility = ["//tensorflow:internal"],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "shared_variable_creator_test",
|
||||
srcs = ["shared_variable_creator_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":shared_variable_creator",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:variable_scope",
|
||||
"//tensorflow/python/eager:test",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "cross_tower_utils",
|
||||
srcs = ["cross_tower_utils.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":values",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:collective_ops",
|
||||
"//tensorflow/python:device",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:gradients",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:nccl_ops",
|
||||
"//tensorflow/python/distribute:all_reduce",
|
||||
],
|
||||
)
|
||||
|
||||
cuda_py_test(
|
||||
name = "cross_tower_utils_test",
|
||||
srcs = ["cross_tower_utils_test.py"],
|
||||
name = "cross_device_utils_test",
|
||||
srcs = ["cross_device_utils_test.py"],
|
||||
additional_deps = [
|
||||
":combinations",
|
||||
":cross_tower_utils",
|
||||
"@absl_py//absl/testing:parameterized",
|
||||
"//tensorflow/python:constant_op",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python/distribute:cross_device_utils",
|
||||
"//tensorflow/python/eager:context",
|
||||
"//tensorflow/python/eager:test",
|
||||
],
|
||||
@ -654,41 +599,20 @@ cuda_py_test(
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "cross_tower_ops",
|
||||
srcs = ["cross_tower_ops.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":cross_tower_utils",
|
||||
":values",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:device_lib",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:platform",
|
||||
"//tensorflow/python:resource_variable_ops",
|
||||
"//tensorflow/python:training",
|
||||
"//tensorflow/python:variable_scope",
|
||||
"//tensorflow/python/distribute:reduce_util",
|
||||
"//tensorflow/python/eager:context",
|
||||
"@six_archive//:six",
|
||||
],
|
||||
)
|
||||
|
||||
cuda_py_test(
|
||||
name = "cross_tower_ops_test",
|
||||
srcs = ["cross_tower_ops_test.py"],
|
||||
name = "cross_device_ops_test",
|
||||
srcs = ["cross_device_ops_test.py"],
|
||||
additional_deps = [
|
||||
":combinations",
|
||||
":cross_tower_ops",
|
||||
":multi_worker_test_base",
|
||||
":mirrored_strategy",
|
||||
":values",
|
||||
"@absl_py//absl/testing:parameterized",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:constant_op",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python/distribute:cross_device_ops",
|
||||
"//tensorflow/python/distribute:values",
|
||||
"//tensorflow/python/eager:context",
|
||||
"//tensorflow/python/eager:test",
|
||||
],
|
||||
@ -698,35 +622,6 @@ cuda_py_test(
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "input_ops",
|
||||
srcs = ["input_ops.py"],
|
||||
visibility = ["//tensorflow:internal"],
|
||||
deps = [
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python/data/util:nest",
|
||||
],
|
||||
)
|
||||
|
||||
cuda_py_test(
|
||||
name = "input_ops_test",
|
||||
srcs = ["input_ops_test.py"],
|
||||
additional_deps = [
|
||||
":input_ops",
|
||||
"//tensorflow/python/data/ops:dataset_ops",
|
||||
"//tensorflow/python:errors",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:io_ops",
|
||||
"//tensorflow/python/data/ops:readers",
|
||||
"//tensorflow/python:util",
|
||||
],
|
||||
tags = [
|
||||
"no_pip",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "keras_test_lib",
|
||||
testonly = 1,
|
||||
|
@ -18,12 +18,12 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_ops_lib
|
||||
from tensorflow.contrib.distribute.python import cross_tower_utils
|
||||
from tensorflow.contrib.distribute.python import mirrored_strategy
|
||||
from tensorflow.contrib.distribute.python import values
|
||||
from tensorflow.core.protobuf import rewriter_config_pb2
|
||||
from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
|
||||
from tensorflow.python.distribute import cross_device_utils
|
||||
from tensorflow.python.distribute import multi_worker_util
|
||||
from tensorflow.python.distribute import values
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
@ -79,11 +79,11 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
|
||||
else:
|
||||
local_devices = ["/device:CPU:0"]
|
||||
|
||||
self._collective_keys = cross_tower_utils.CollectiveKeys()
|
||||
self._collective_keys = cross_device_utils.CollectiveKeys()
|
||||
super(CollectiveAllReduceExtended, self).__init__(
|
||||
container_strategy,
|
||||
devices=local_devices,
|
||||
cross_device_ops=cross_tower_ops_lib.CollectiveAllReduce(
|
||||
cross_device_ops=cross_device_ops_lib.CollectiveAllReduce(
|
||||
num_workers=1,
|
||||
num_gpus_per_worker=num_gpus_per_worker,
|
||||
collective_keys=self._collective_keys))
|
||||
@ -123,11 +123,11 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
|
||||
else:
|
||||
local_devices = [worker_device]
|
||||
|
||||
self._collective_keys = cross_tower_utils.CollectiveKeys()
|
||||
self._collective_keys = cross_device_utils.CollectiveKeys()
|
||||
super(CollectiveAllReduceExtended, self).__init__(
|
||||
container_strategy,
|
||||
devices=local_devices,
|
||||
cross_device_ops=cross_tower_ops_lib.CollectiveAllReduce(
|
||||
cross_device_ops=cross_device_ops_lib.CollectiveAllReduce(
|
||||
num_workers=self._num_workers,
|
||||
num_gpus_per_worker=num_gpus_per_worker,
|
||||
collective_keys=self._collective_keys))
|
||||
|
@ -23,11 +23,11 @@ import numpy as np
|
||||
|
||||
from tensorflow.contrib.distribute.python import collective_all_reduce_strategy
|
||||
from tensorflow.contrib.distribute.python import combinations
|
||||
from tensorflow.contrib.distribute.python import cross_tower_utils
|
||||
from tensorflow.contrib.distribute.python import multi_worker_test_base
|
||||
from tensorflow.contrib.distribute.python import strategy_test_lib
|
||||
from tensorflow.core.protobuf import config_pb2
|
||||
from tensorflow.python import keras
|
||||
from tensorflow.python.distribute import cross_device_utils
|
||||
from tensorflow.python.distribute import reduce_util
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import constant_op
|
||||
@ -73,7 +73,7 @@ class CollectiveAllReduceStrategyTestBase(
|
||||
cluster_spec=self._cluster_spec,
|
||||
task_type=task_type,
|
||||
task_id=task_id)
|
||||
collective_keys = cross_tower_utils.CollectiveKeys(
|
||||
collective_keys = cross_device_utils.CollectiveKeys(
|
||||
group_key_start=10 * num_gpus +
|
||||
CollectiveAllReduceStrategyTestBase.collective_key_base,
|
||||
instance_key_start=num_gpus * 100 +
|
||||
@ -81,7 +81,7 @@ class CollectiveAllReduceStrategyTestBase(
|
||||
instance_key_with_id_start=num_gpus * 10000 +
|
||||
CollectiveAllReduceStrategyTestBase.collective_key_base)
|
||||
distribution.extended._collective_keys = collective_keys
|
||||
distribution.extended._cross_tower_ops._collective_keys = collective_keys
|
||||
distribution.extended._cross_device_ops._collective_keys = collective_keys
|
||||
if task_type and task_id is not None:
|
||||
return distribution, 'grpc://' + self._cluster_spec[task_type][
|
||||
task_id], session_config
|
||||
|
@ -24,13 +24,13 @@ from absl.testing import parameterized
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.contrib.distribute.python import combinations
|
||||
from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_ops_lib
|
||||
from tensorflow.contrib.distribute.python import cross_tower_utils
|
||||
from tensorflow.contrib.distribute.python import mirrored_strategy
|
||||
from tensorflow.contrib.distribute.python import multi_worker_test_base
|
||||
from tensorflow.contrib.distribute.python import values as value_lib
|
||||
from tensorflow.core.protobuf import config_pb2
|
||||
from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
|
||||
from tensorflow.python.distribute import cross_device_utils
|
||||
from tensorflow.python.distribute import reduce_util
|
||||
from tensorflow.python.distribute import values as value_lib
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import test
|
||||
from tensorflow.python.framework import constant_op
|
||||
@ -41,7 +41,7 @@ from tensorflow.python.training import device_util
|
||||
|
||||
|
||||
def _make_per_replica(values, devices, regroup=False):
|
||||
devices = cross_tower_ops_lib.get_devices_from(devices)
|
||||
devices = cross_device_ops_lib.get_devices_from(devices)
|
||||
assert len(values) == len(devices)
|
||||
|
||||
# We simulate the result of regroup called on PerReplica which strips the
|
||||
@ -66,7 +66,7 @@ def _fake_mirrored(value, devices):
|
||||
All components of the returned Mirrored have the same objects, which is not
|
||||
true in reality.
|
||||
"""
|
||||
devices = cross_tower_ops_lib.get_devices_from(devices)
|
||||
devices = cross_device_ops_lib.get_devices_from(devices)
|
||||
return value_lib.Mirrored(
|
||||
{d: v for d, v in zip(devices, [value] * len(devices))})
|
||||
|
||||
@ -118,7 +118,7 @@ class CrossDeviceOpsTestBase(test.TestCase, parameterized.TestCase):
|
||||
self.assertEqual(
|
||||
sess.run(list(left._index.values())), list(right._index.values()))
|
||||
|
||||
def _testReductionAndBroadcast(self, cross_tower_ops, distribution):
|
||||
def _testReductionAndBroadcast(self, cross_device_ops, distribution):
|
||||
devices = distribution.worker_devices
|
||||
|
||||
values = [constant_op.constant(float(d)) for d in range(len(devices))]
|
||||
@ -142,24 +142,24 @@ class CrossDeviceOpsTestBase(test.TestCase, parameterized.TestCase):
|
||||
# test reduce()
|
||||
for destinations in all_destinations:
|
||||
self._assert_values_equal(
|
||||
cross_tower_ops.reduce(
|
||||
cross_device_ops.reduce(
|
||||
reduce_util.ReduceOp.MEAN,
|
||||
per_replica,
|
||||
destinations=destinations),
|
||||
_fake_mirrored(mean, destinations))
|
||||
self._assert_values_equal(
|
||||
cross_tower_ops.reduce(
|
||||
cross_device_ops.reduce(
|
||||
reduce_util.ReduceOp.MEAN,
|
||||
per_replica_2,
|
||||
destinations=destinations),
|
||||
_fake_mirrored(mean_2, destinations))
|
||||
self._assert_values_equal(
|
||||
cross_tower_ops.reduce(
|
||||
cross_device_ops.reduce(
|
||||
reduce_util.ReduceOp.SUM, per_replica,
|
||||
destinations=destinations),
|
||||
_fake_mirrored(mean * len(devices), destinations))
|
||||
self._assert_values_equal(
|
||||
cross_tower_ops.reduce(
|
||||
cross_device_ops.reduce(
|
||||
reduce_util.ReduceOp.SUM,
|
||||
per_replica_2,
|
||||
destinations=destinations),
|
||||
@ -168,7 +168,7 @@ class CrossDeviceOpsTestBase(test.TestCase, parameterized.TestCase):
|
||||
# test batch_reduce()
|
||||
for d1, d2 in itertools.product(all_destinations, all_destinations):
|
||||
self._assert_values_equal(
|
||||
cross_tower_ops.batch_reduce(
|
||||
cross_device_ops.batch_reduce(
|
||||
reduce_util.ReduceOp.MEAN,
|
||||
[(per_replica, d1), (per_replica_2, d2)]),
|
||||
[
|
||||
@ -176,7 +176,7 @@ class CrossDeviceOpsTestBase(test.TestCase, parameterized.TestCase):
|
||||
_fake_mirrored(mean_2, d2)
|
||||
])
|
||||
self._assert_values_equal(
|
||||
cross_tower_ops.batch_reduce(
|
||||
cross_device_ops.batch_reduce(
|
||||
reduce_util.ReduceOp.SUM,
|
||||
[(per_replica, d1), (per_replica_2, d2)]),
|
||||
[
|
||||
@ -187,7 +187,7 @@ class CrossDeviceOpsTestBase(test.TestCase, parameterized.TestCase):
|
||||
# test broadcast()
|
||||
for destinations in all_destinations:
|
||||
self._assert_values_equal(
|
||||
cross_tower_ops.broadcast(constant_op.constant(1.), destinations),
|
||||
cross_device_ops.broadcast(constant_op.constant(1.), destinations),
|
||||
_fake_mirrored(1., destinations))
|
||||
|
||||
|
||||
@ -196,17 +196,17 @@ class SingleWorkerCrossDeviceOpsTest(CrossDeviceOpsTestBase):
|
||||
# combinations module so that we can pass in devices instead of a distribution
|
||||
# strategy.
|
||||
reduction_to_one_combinations = combinations.combine(
|
||||
cross_tower_ops=[
|
||||
cross_device_ops=[
|
||||
combinations.NamedObject(
|
||||
"DefaultReductionToOneDeviceCrossDeviceOps",
|
||||
cross_tower_ops_lib.ReductionToOneDeviceCrossDeviceOps()),
|
||||
cross_device_ops_lib.ReductionToOneDeviceCrossDeviceOps()),
|
||||
combinations.NamedObject(
|
||||
"ReductionToCPUDeviceCrossDeviceOps",
|
||||
cross_tower_ops_lib.ReductionToOneDeviceCrossDeviceOps(
|
||||
cross_device_ops_lib.ReductionToOneDeviceCrossDeviceOps(
|
||||
reduce_to_device=_cpu_device)),
|
||||
combinations.NamedObject(
|
||||
"AccumulateNCrossDeviceOp",
|
||||
cross_tower_ops_lib.ReductionToOneDeviceCrossDeviceOps(
|
||||
cross_device_ops_lib.ReductionToOneDeviceCrossDeviceOps(
|
||||
accumulation_fn=math_ops.accumulate_n)),
|
||||
],
|
||||
distribution=[
|
||||
@ -218,20 +218,20 @@ class SingleWorkerCrossDeviceOpsTest(CrossDeviceOpsTestBase):
|
||||
],
|
||||
mode=["graph", "eager"])
|
||||
allreduce_combinations = combinations.combine(
|
||||
cross_tower_ops=[
|
||||
cross_device_ops=[
|
||||
combinations.NamedObject(
|
||||
"AllReduce",
|
||||
cross_tower_ops_lib.AllReduceCrossDeviceOps("nccl", 1, 0, 0)),
|
||||
cross_device_ops_lib.AllReduceCrossDeviceOps("nccl", 1, 0, 0)),
|
||||
combinations.NamedObject(
|
||||
"HierarchicalCopy",
|
||||
cross_tower_ops_lib.AllReduceCrossDeviceOps(
|
||||
cross_device_ops_lib.AllReduceCrossDeviceOps(
|
||||
"hierarchical_copy", 8, 0, 0)),
|
||||
combinations.NamedObject(
|
||||
"AllReduceNoGradientRepacking",
|
||||
cross_tower_ops_lib.AllReduceCrossDeviceOps("nccl", 0, 0, 0)),
|
||||
cross_device_ops_lib.AllReduceCrossDeviceOps("nccl", 0, 0, 0)),
|
||||
combinations.NamedObject(
|
||||
"HierarchicalCopyAggregateSmallTensors",
|
||||
cross_tower_ops_lib.AllReduceCrossDeviceOps(
|
||||
cross_device_ops_lib.AllReduceCrossDeviceOps(
|
||||
"hierarchical_copy", 0, 100, 10))
|
||||
],
|
||||
distribution=[combinations.mirrored_strategy_with_two_gpus,
|
||||
@ -239,22 +239,22 @@ class SingleWorkerCrossDeviceOpsTest(CrossDeviceOpsTestBase):
|
||||
mode=["graph", "eager"])
|
||||
|
||||
@combinations.generate(reduction_to_one_combinations + allreduce_combinations)
|
||||
def testReductionAndBroadcast(self, cross_tower_ops, distribution):
|
||||
def testReductionAndBroadcast(self, cross_device_ops, distribution):
|
||||
with distribution.scope():
|
||||
self._testReductionAndBroadcast(cross_tower_ops, distribution)
|
||||
self._testReductionAndBroadcast(cross_device_ops, distribution)
|
||||
|
||||
def testChooseAlgorithm(self):
|
||||
device_links = [[1, 2, 3, 4], [0, 2, 3, 5], [0, 1, 3, 6], [0, 1, 2, 7],
|
||||
[0, 5, 6, 7], [1, 4, 6, 7], [2, 4, 5, 7], [3, 4, 5, 6]]
|
||||
result = cross_tower_ops_lib._choose_all_reduce_algorithm(device_links)
|
||||
self.assertIsInstance(result, cross_tower_ops_lib.AllReduceCrossDeviceOps)
|
||||
result = cross_device_ops_lib._choose_all_reduce_algorithm(device_links)
|
||||
self.assertIsInstance(result, cross_device_ops_lib.AllReduceCrossDeviceOps)
|
||||
self.assertEqual(result._all_reduce_alg, "hierarchical_copy")
|
||||
self.assertEqual(result._num_packs, 8)
|
||||
|
||||
# if there are only 4 devices
|
||||
device_links = [[1, 2, 3, 4], [0, 2, 3, 5], [0, 1, 3, 6], [0, 1, 2, 7]]
|
||||
result = cross_tower_ops_lib._choose_all_reduce_algorithm(device_links)
|
||||
self.assertIsInstance(result, cross_tower_ops_lib.AllReduceCrossDeviceOps)
|
||||
result = cross_device_ops_lib._choose_all_reduce_algorithm(device_links)
|
||||
self.assertIsInstance(result, cross_device_ops_lib.AllReduceCrossDeviceOps)
|
||||
self.assertEqual(result._all_reduce_alg, "nccl")
|
||||
self.assertEqual(result._num_packs, 1)
|
||||
|
||||
@ -262,16 +262,16 @@ class SingleWorkerCrossDeviceOpsTest(CrossDeviceOpsTestBase):
|
||||
device_links = [[0, 1, 2, 3, 4], [0, 1, 2, 3, 5], [0, 1, 2, 3, 6],
|
||||
[0, 1, 2, 3, 7], [0, 4, 5, 6, 7], [1, 4, 5, 6, 7],
|
||||
[2, 4, 5, 6, 7], [3, 4, 5, 6, 7]]
|
||||
result = cross_tower_ops_lib._choose_all_reduce_algorithm(device_links)
|
||||
self.assertIsInstance(result, cross_tower_ops_lib.AllReduceCrossDeviceOps)
|
||||
result = cross_device_ops_lib._choose_all_reduce_algorithm(device_links)
|
||||
self.assertIsInstance(result, cross_device_ops_lib.AllReduceCrossDeviceOps)
|
||||
self.assertEqual(result._all_reduce_alg, "hierarchical_copy")
|
||||
self.assertEqual(result._num_packs, 8)
|
||||
|
||||
# if not dgx1-like links
|
||||
device_links = [[0, 2, 3, 5], [0, 1, 3, 6], [0, 1, 2, 7], [0, 5, 6, 7],
|
||||
[1, 4, 6, 7], [2, 4, 5, 7], [3, 4, 5, 6], [1, 2, 3, 4]]
|
||||
result = cross_tower_ops_lib._choose_all_reduce_algorithm(device_links)
|
||||
self.assertIsInstance(result, cross_tower_ops_lib.AllReduceCrossDeviceOps)
|
||||
result = cross_device_ops_lib._choose_all_reduce_algorithm(device_links)
|
||||
self.assertIsInstance(result, cross_device_ops_lib.AllReduceCrossDeviceOps)
|
||||
self.assertEqual(result._all_reduce_alg, "nccl")
|
||||
self.assertEqual(result._num_packs, 1)
|
||||
|
||||
@ -283,7 +283,7 @@ class SingleWorkerCrossDeviceOpsTest(CrossDeviceOpsTestBase):
|
||||
t0 = _make_indexed_slices([[1., 2.]], [1], [5, 2], devices[0])
|
||||
t1 = _make_indexed_slices([[3., 4.], [5., 6.]], [1, 3], [5, 2], devices[1])
|
||||
per_replica = value_lib.PerReplica({devices[0]: t0, devices[1]: t1})
|
||||
result = cross_tower_ops_lib._simple_reduce(
|
||||
result = cross_device_ops_lib._simple_reduce(
|
||||
per_replica, devices[0], math_ops.add_n, reduce_util.ReduceOp.SUM)
|
||||
|
||||
# Test that the result is semantically equal to both the concatenated
|
||||
@ -297,19 +297,19 @@ class SingleWorkerCrossDeviceOpsTest(CrossDeviceOpsTestBase):
|
||||
|
||||
@combinations.generate(
|
||||
combinations.combine(
|
||||
cross_tower_ops_instance=[
|
||||
cross_device_ops_instance=[
|
||||
combinations.NamedObject(
|
||||
"ReductionToOneDeviceCrossDeviceOps",
|
||||
cross_tower_ops_lib.ReductionToOneDeviceCrossDeviceOps()),
|
||||
cross_device_ops_lib.ReductionToOneDeviceCrossDeviceOps()),
|
||||
combinations.NamedObject(
|
||||
"AllReduceCrossDeviceOps",
|
||||
cross_tower_ops_lib.AllReduceCrossDeviceOps())
|
||||
cross_device_ops_lib.AllReduceCrossDeviceOps())
|
||||
],
|
||||
reduce_op=[reduce_util.ReduceOp.SUM, reduce_util.ReduceOp.MEAN],
|
||||
batch_reduce=[True, False],
|
||||
mode=["graph", "eager"],
|
||||
required_gpus=1))
|
||||
def testIndexedSlicesAllReduce(self, cross_tower_ops_instance, reduce_op,
|
||||
def testIndexedSlicesAllReduce(self, cross_device_ops_instance, reduce_op,
|
||||
batch_reduce):
|
||||
devices = ["/cpu:0", "/gpu:0"]
|
||||
dense_shape = [5, 2]
|
||||
@ -319,10 +319,10 @@ class SingleWorkerCrossDeviceOpsTest(CrossDeviceOpsTestBase):
|
||||
per_replica = value_lib.PerReplica({devices[0]: t0, devices[1]: t1})
|
||||
|
||||
if batch_reduce:
|
||||
result = cross_tower_ops_instance.batch_reduce(
|
||||
result = cross_device_ops_instance.batch_reduce(
|
||||
reduce_op, [(per_replica, devices)])
|
||||
else:
|
||||
result = cross_tower_ops_instance.reduce(
|
||||
result = cross_device_ops_instance.reduce(
|
||||
reduce_op, per_replica, devices)
|
||||
|
||||
total_indices_with_dups = [1, 1, 3]
|
||||
@ -359,22 +359,22 @@ class MultiWorkerCrossDeviceOpsTest(multi_worker_test_base.MultiWorkerTestBase,
|
||||
"/job:worker/replica:0/task:0", "/job:worker/replica:0/task:1"
|
||||
]
|
||||
multi_worker_allreduce_combinations = combinations.combine(
|
||||
cross_tower_ops=[
|
||||
cross_device_ops=[
|
||||
combinations.NamedObject(
|
||||
"MultiWorkerAllReduce",
|
||||
cross_tower_ops_lib.MultiWorkerAllReduce(
|
||||
cross_device_ops_lib.MultiWorkerAllReduce(
|
||||
worker_devices, 2, ("pscpu/pscpu", 2, -1), 0, 0, 0)),
|
||||
combinations.NamedObject(
|
||||
"MultiWorkerAllReducePack",
|
||||
cross_tower_ops_lib.MultiWorkerAllReduce(
|
||||
cross_device_ops_lib.MultiWorkerAllReduce(
|
||||
worker_devices, 2, ("pscpu/pscpu", 2, -1), 1, 0, 0)),
|
||||
combinations.NamedObject(
|
||||
"MultiWorkerAllReduceAggregation",
|
||||
cross_tower_ops_lib.MultiWorkerAllReduce(
|
||||
cross_device_ops_lib.MultiWorkerAllReduce(
|
||||
worker_devices, 2, ("pscpu/pscpu", 2, -1), 0, 100, 10)),
|
||||
combinations.NamedObject(
|
||||
"MultiWorkerAllReduceMultipleSpecs",
|
||||
cross_tower_ops_lib.MultiWorkerAllReduce(
|
||||
cross_device_ops_lib.MultiWorkerAllReduce(
|
||||
worker_devices, 2, [("pscpu/pscpu", 2, 100),
|
||||
("xring", 2, -1)], 0, 0, 0)),
|
||||
],
|
||||
@ -395,13 +395,13 @@ class MultiWorkerCrossDeviceOpsTest(multi_worker_test_base.MultiWorkerTestBase,
|
||||
mode=["graph"])
|
||||
|
||||
@combinations.generate(multi_worker_allreduce_combinations)
|
||||
def testReductionAndBroadcast(self, cross_tower_ops, distribution):
|
||||
def testReductionAndBroadcast(self, cross_device_ops, distribution):
|
||||
distribution.configure(cluster_spec={
|
||||
"worker":
|
||||
["/job:worker/replica:0/task:0", "/job:worker/replica:0/task:1"]
|
||||
})
|
||||
with distribution.scope():
|
||||
self._testReductionAndBroadcast(cross_tower_ops, distribution)
|
||||
self._testReductionAndBroadcast(cross_device_ops, distribution)
|
||||
|
||||
|
||||
class MultiWorkerCollectiveAllReduceTest(
|
||||
@ -422,7 +422,7 @@ class MultiWorkerCollectiveAllReduceTest(
|
||||
MultiWorkerCollectiveAllReduceTest.collective_key_base += 100000
|
||||
|
||||
def _get_test_objects(self, task_type, task_id, num_gpus=0, local_mode=False):
|
||||
collective_keys = cross_tower_utils.CollectiveKeys(
|
||||
collective_keys = cross_device_utils.CollectiveKeys(
|
||||
group_key_start=10 * num_gpus +
|
||||
MultiWorkerCollectiveAllReduceTest.collective_key_base,
|
||||
instance_key_start=num_gpus * 100 +
|
||||
@ -430,7 +430,7 @@ class MultiWorkerCollectiveAllReduceTest(
|
||||
instance_key_with_id_start=num_gpus * 10000 +
|
||||
MultiWorkerCollectiveAllReduceTest.collective_key_base)
|
||||
if local_mode:
|
||||
collective_all_reduce_ops = cross_tower_ops_lib.CollectiveAllReduce(
|
||||
collective_all_reduce_ops = cross_device_ops_lib.CollectiveAllReduce(
|
||||
1, num_gpus, collective_keys=collective_keys)
|
||||
if num_gpus:
|
||||
devices = ["/device:GPU:%d" % i for i in range(num_gpus)]
|
||||
@ -438,7 +438,7 @@ class MultiWorkerCollectiveAllReduceTest(
|
||||
devices = ["/device:CPU:0"]
|
||||
return collective_all_reduce_ops, devices, ""
|
||||
else:
|
||||
collective_all_reduce_ops = cross_tower_ops_lib.CollectiveAllReduce(
|
||||
collective_all_reduce_ops = cross_device_ops_lib.CollectiveAllReduce(
|
||||
3, num_gpus, collective_keys=collective_keys)
|
||||
if num_gpus:
|
||||
devices = [
|
@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for cross_tower_utils."""
|
||||
"""Tests for cross_device_utils."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
@ -21,8 +21,8 @@ from __future__ import print_function
|
||||
from absl.testing import parameterized
|
||||
|
||||
from tensorflow.contrib.distribute.python import combinations
|
||||
from tensorflow.contrib.distribute.python import cross_tower_utils
|
||||
from tensorflow.contrib.distribute.python import values as value_lib
|
||||
from tensorflow.python.distribute import cross_device_utils
|
||||
from tensorflow.python.distribute import values as value_lib
|
||||
from tensorflow.python.eager import test
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import ops
|
||||
@ -43,7 +43,7 @@ class IndexedSlicesUtilsTest(test.TestCase, parameterized.TestCase):
|
||||
t0 = constant_op.constant([[1., 2.], [0, 0], [3., 4.]])
|
||||
t1 = constant_op.constant([[0., 0.], [5, 6], [7., 8.]])
|
||||
total = constant_op.constant([[1., 2.], [5, 6], [10., 12.]])
|
||||
result = cross_tower_utils.aggregate_tensors_or_indexed_slices([t0, t1])
|
||||
result = cross_device_utils.aggregate_tensors_or_indexed_slices([t0, t1])
|
||||
self._assert_values_equal(total, result)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
@ -53,7 +53,7 @@ class IndexedSlicesUtilsTest(test.TestCase, parameterized.TestCase):
|
||||
t1 = math_ops._as_indexed_slices(
|
||||
constant_op.constant([[0., 0.], [5, 6], [7., 8.]]))
|
||||
total = constant_op.constant([[1., 2.], [5, 6], [10., 12.]])
|
||||
result = cross_tower_utils.aggregate_tensors_or_indexed_slices([t0, t1])
|
||||
result = cross_device_utils.aggregate_tensors_or_indexed_slices([t0, t1])
|
||||
self.assertIsInstance(result, ops.IndexedSlices)
|
||||
self._assert_values_equal(total, result)
|
||||
|
||||
@ -62,7 +62,7 @@ class IndexedSlicesUtilsTest(test.TestCase, parameterized.TestCase):
|
||||
t = constant_op.constant([[1., 2.], [0, 0], [3., 4.]])
|
||||
n = 2
|
||||
expected = constant_op.constant([[0.5, 1.], [0, 0], [1.5, 2.]])
|
||||
result = cross_tower_utils.divide_by_n_tensors_or_indexed_slices(t, n)
|
||||
result = cross_device_utils.divide_by_n_tensors_or_indexed_slices(t, n)
|
||||
self._assert_values_equal(expected, result)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
@ -71,7 +71,7 @@ class IndexedSlicesUtilsTest(test.TestCase, parameterized.TestCase):
|
||||
constant_op.constant([[1., 2.], [0, 0], [3., 4.]]))
|
||||
n = 2
|
||||
expected = constant_op.constant([[0.5, 1.], [0, 0], [1.5, 2.]])
|
||||
result = cross_tower_utils.divide_by_n_tensors_or_indexed_slices(t, n)
|
||||
result = cross_device_utils.divide_by_n_tensors_or_indexed_slices(t, n)
|
||||
self.assertIsInstance(result, ops.IndexedSlices)
|
||||
self._assert_values_equal(expected, result)
|
||||
|
||||
@ -79,7 +79,7 @@ class IndexedSlicesUtilsTest(test.TestCase, parameterized.TestCase):
|
||||
def testIsIndexedSlices(self):
|
||||
t = math_ops._as_indexed_slices(
|
||||
constant_op.constant([[1., 2.], [0, 0], [3., 4.]]))
|
||||
self.assertTrue(cross_tower_utils.contains_indexed_slices(t))
|
||||
self.assertTrue(cross_device_utils.contains_indexed_slices(t))
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def testContainsIndexedSlices_List(self):
|
||||
@ -87,7 +87,7 @@ class IndexedSlicesUtilsTest(test.TestCase, parameterized.TestCase):
|
||||
constant_op.constant([[1., 2.], [0, 0], [3., 4.]]))
|
||||
t1 = math_ops._as_indexed_slices(
|
||||
constant_op.constant([[0., 0.], [5, 6], [7., 8.]]))
|
||||
self.assertTrue(cross_tower_utils.contains_indexed_slices([t0, t1]))
|
||||
self.assertTrue(cross_device_utils.contains_indexed_slices([t0, t1]))
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def testContainsIndexedSlices_Tuple(self):
|
||||
@ -95,7 +95,7 @@ class IndexedSlicesUtilsTest(test.TestCase, parameterized.TestCase):
|
||||
constant_op.constant([[1., 2.], [0, 0], [3., 4.]]))
|
||||
t1 = math_ops._as_indexed_slices(
|
||||
constant_op.constant([[0., 0.], [5, 6], [7., 8.]]))
|
||||
self.assertTrue(cross_tower_utils.contains_indexed_slices((t0, t1)))
|
||||
self.assertTrue(cross_device_utils.contains_indexed_slices((t0, t1)))
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def testContainsIndexedSlices_PerReplica(self):
|
||||
@ -104,7 +104,7 @@ class IndexedSlicesUtilsTest(test.TestCase, parameterized.TestCase):
|
||||
t1 = math_ops._as_indexed_slices(
|
||||
constant_op.constant([[0., 0.], [5, 6], [7., 8.]]))
|
||||
per_replica = value_lib.PerReplica({"/gpu:0": t0, "/cpu:0": t1})
|
||||
self.assertTrue(cross_tower_utils.contains_indexed_slices(per_replica))
|
||||
self.assertTrue(cross_device_utils.contains_indexed_slices(per_replica))
|
||||
|
||||
@combinations.generate(combinations.combine(
|
||||
mode=["graph", "eager"],
|
||||
@ -113,7 +113,7 @@ class IndexedSlicesUtilsTest(test.TestCase, parameterized.TestCase):
|
||||
with ops.device("/cpu:0"):
|
||||
t = constant_op.constant([[1., 2.], [0, 0], [3., 4.]])
|
||||
destination = "/gpu:0"
|
||||
result = cross_tower_utils.copy_tensor_or_indexed_slices_to_device(
|
||||
result = cross_device_utils.copy_tensor_or_indexed_slices_to_device(
|
||||
t, destination)
|
||||
|
||||
self._assert_values_equal(t, result)
|
||||
@ -128,7 +128,7 @@ class IndexedSlicesUtilsTest(test.TestCase, parameterized.TestCase):
|
||||
t = math_ops._as_indexed_slices(
|
||||
constant_op.constant([[1., 2.], [0, 0], [3., 4.]]))
|
||||
destination = "/gpu:0"
|
||||
result = cross_tower_utils.copy_tensor_or_indexed_slices_to_device(
|
||||
result = cross_device_utils.copy_tensor_or_indexed_slices_to_device(
|
||||
t, destination)
|
||||
|
||||
self.assertIsInstance(result, ops.IndexedSlices)
|
@ -24,9 +24,9 @@ import numpy as np
|
||||
from tensorflow.contrib.distribute.python import combinations
|
||||
from tensorflow.contrib.distribute.python import mirrored_strategy
|
||||
from tensorflow.contrib.distribute.python import tpu_strategy
|
||||
from tensorflow.contrib.distribute.python import values
|
||||
from tensorflow.python import keras
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.distribute import values
|
||||
from tensorflow.python.estimator import keras as keras_lib
|
||||
from tensorflow.python.estimator import run_config as run_config_lib
|
||||
from tensorflow.python.framework import constant_op
|
||||
|
@ -22,12 +22,12 @@ import contextlib
|
||||
from functools import partial
|
||||
import threading
|
||||
|
||||
from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_ops_lib
|
||||
from tensorflow.contrib.distribute.python import shared_variable_creator
|
||||
from tensorflow.contrib.distribute.python import values
|
||||
from tensorflow.python import pywrap_tensorflow
|
||||
from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
|
||||
from tensorflow.python.distribute import multi_worker_util
|
||||
from tensorflow.python.distribute import reduce_util
|
||||
from tensorflow.python.distribute import shared_variable_creator
|
||||
from tensorflow.python.distribute import values
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import tape
|
||||
from tensorflow.python.framework import constant_op
|
||||
@ -196,16 +196,16 @@ def _reduce_non_distributed_value(extended, reduce_op, value, destinations):
|
||||
if reduce_op == reduce_util.ReduceOp.MEAN:
|
||||
return value
|
||||
|
||||
cross_tower_ops_lib.validate_destinations(destinations)
|
||||
cross_device_ops_lib.validate_destinations(destinations)
|
||||
# We do not support a reduce op of SUM if the value is the same across
|
||||
# all replicas. We call this as part of assign functions for MirroredVariables
|
||||
# and summing up identical values across replicas is not clearly defined.
|
||||
if (len(extended.worker_devices) != 1 or
|
||||
not cross_tower_ops_lib.check_destinations(destinations)):
|
||||
not cross_device_ops_lib.check_destinations(destinations)):
|
||||
raise ValueError("A non-DistributedValues value %s cannot be reduced with "
|
||||
"the given reduce op %s." % (value, reduce_op))
|
||||
# TODO(anjalisridhar): Moves these methods to a device utility file?
|
||||
devices = cross_tower_ops_lib.get_devices_from(destinations)
|
||||
devices = cross_device_ops_lib.get_devices_from(destinations)
|
||||
if len(devices) == 1:
|
||||
with ops.device(devices[0]):
|
||||
return array_ops.identity(value)
|
||||
@ -369,8 +369,7 @@ class CoreMirroredExtended(distribute_lib.DistributionStrategyExtended):
|
||||
cross_device_ops=None,
|
||||
auto_shard_dataset=False):
|
||||
super(CoreMirroredExtended, self).__init__(container_strategy)
|
||||
# TODO(josh11b): Rename self._cross_tower_ops -> self._cross_device_ops
|
||||
self._cross_tower_ops = cross_device_ops
|
||||
self._cross_device_ops = cross_device_ops
|
||||
self._auto_shard_dataset = auto_shard_dataset
|
||||
# Remember num GPUs which might be needed by `configure` method.
|
||||
if num_gpus is not None and num_gpus_per_worker is not None:
|
||||
@ -588,7 +587,7 @@ class CoreMirroredExtended(distribute_lib.DistributionStrategyExtended):
|
||||
if isinstance(tensor, (float, int)): # Fast path for Python constants.
|
||||
return tensor
|
||||
# TODO(josh11b): In eager mode, use one thread per device, or async mode.
|
||||
return self._get_cross_tower_ops().broadcast(
|
||||
return self._get_cross_device_ops().broadcast(
|
||||
tensor, destinations or self._devices)
|
||||
|
||||
def _call_for_each_replica(self, fn, args, kwargs):
|
||||
@ -607,26 +606,27 @@ class CoreMirroredExtended(distribute_lib.DistributionStrategyExtended):
|
||||
if cluster_spec:
|
||||
self._initialize_multi_worker(self._num_gpus, cluster_spec)
|
||||
|
||||
if self._cross_tower_ops is None:
|
||||
if self._cross_device_ops is None:
|
||||
if self._cluster_spec:
|
||||
# It currently cannot detect the toplogy of remote workers. So we
|
||||
# hard-code the multi-worker all-reduce algorithm for now.
|
||||
if len(self._workers) == 1:
|
||||
# The default is "nccl".
|
||||
self._cross_tower_ops = cross_tower_ops_lib.AllReduceCrossDeviceOps()
|
||||
self._cross_device_ops = (
|
||||
cross_device_ops_lib.AllReduceCrossDeviceOps())
|
||||
else:
|
||||
# The default is hierarchical reduce and broadcast.
|
||||
self._cross_tower_ops = cross_tower_ops_lib.MultiWorkerAllReduce(
|
||||
self._cross_device_ops = cross_device_ops_lib.MultiWorkerAllReduce(
|
||||
self._workers, self._num_gpus)
|
||||
else:
|
||||
self._cross_tower_ops = cross_tower_ops_lib.choose_the_best(
|
||||
self._cross_device_ops = cross_device_ops_lib.choose_the_best(
|
||||
self._devices, session_config=session_config)
|
||||
|
||||
def _get_cross_tower_ops(self):
|
||||
if self._cross_tower_ops is None:
|
||||
self._cross_tower_ops = (
|
||||
cross_tower_ops_lib.ReductionToOneDeviceCrossDeviceOps())
|
||||
return self._cross_tower_ops
|
||||
def _get_cross_device_ops(self):
|
||||
if self._cross_device_ops is None:
|
||||
self._cross_device_ops = (
|
||||
cross_device_ops_lib.ReductionToOneDeviceCrossDeviceOps())
|
||||
return self._cross_device_ops
|
||||
|
||||
def _reduce_to(self, reduce_op, value, destinations):
|
||||
assert not isinstance(value, values.Mirrored)
|
||||
@ -637,12 +637,12 @@ class CoreMirroredExtended(distribute_lib.DistributionStrategyExtended):
|
||||
# be 0.
|
||||
return _reduce_non_distributed_value(self, reduce_op, value,
|
||||
destinations)
|
||||
return self._get_cross_tower_ops().reduce(
|
||||
return self._get_cross_device_ops().reduce(
|
||||
reduce_op, value, destinations=destinations)
|
||||
|
||||
def _batch_reduce_to(self, reduce_op, value_destination_pairs):
|
||||
return self._get_cross_tower_ops().batch_reduce(reduce_op,
|
||||
value_destination_pairs)
|
||||
return self._get_cross_device_ops().batch_reduce(reduce_op,
|
||||
value_destination_pairs)
|
||||
|
||||
def _update(self, var, fn, args, kwargs, group):
|
||||
# TODO(josh11b): In eager mode, use one thread per device.
|
||||
@ -723,7 +723,7 @@ class CoreMirroredExtended(distribute_lib.DistributionStrategyExtended):
|
||||
if colocate_with is None:
|
||||
return self._devices
|
||||
else:
|
||||
return cross_tower_ops_lib.get_devices_from(colocate_with)
|
||||
return cross_device_ops_lib.get_devices_from(colocate_with)
|
||||
|
||||
class _MirroredReplicaThread(threading.Thread):
|
||||
"""A thread that runs() a function on a device."""
|
||||
|
@ -25,10 +25,10 @@ import numpy as np
|
||||
from tensorflow.contrib.distribute.python import mirrored_strategy
|
||||
from tensorflow.contrib.distribute.python import multi_worker_test_base
|
||||
from tensorflow.contrib.distribute.python import strategy_test_lib
|
||||
from tensorflow.contrib.distribute.python import values
|
||||
from tensorflow.core.protobuf import config_pb2
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.distribute import reduce_util
|
||||
from tensorflow.python.distribute import values
|
||||
from tensorflow.python.eager import backprop
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import function
|
||||
|
@ -20,7 +20,7 @@ from __future__ import print_function
|
||||
|
||||
import six
|
||||
|
||||
from tensorflow.contrib.distribute.python import values
|
||||
from tensorflow.python.distribute import values
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
|
@ -18,10 +18,10 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_ops_lib
|
||||
from tensorflow.contrib.distribute.python import mirrored_strategy
|
||||
from tensorflow.contrib.distribute.python import values
|
||||
from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
|
||||
from tensorflow.python.distribute import multi_worker_util
|
||||
from tensorflow.python.distribute import values
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import device as tf_device
|
||||
from tensorflow.python.framework import ops
|
||||
@ -107,8 +107,8 @@ class ParameterServerExtended(distribute_lib.DistributionStrategyExtended):
|
||||
self._initialize_local(num_gpus_per_worker)
|
||||
|
||||
# We typically don't need to do all-reduce in this strategy.
|
||||
self._cross_tower_ops = (
|
||||
cross_tower_ops_lib.ReductionToOneDeviceCrossDeviceOps(
|
||||
self._cross_device_ops = (
|
||||
cross_device_ops_lib.ReductionToOneDeviceCrossDeviceOps(
|
||||
reduce_to_device=_LOCAL_CPU))
|
||||
|
||||
def _initialize_multi_worker(self, num_gpus_per_worker, cluster_spec,
|
||||
@ -256,9 +256,9 @@ class ParameterServerExtended(distribute_lib.DistributionStrategyExtended):
|
||||
True)
|
||||
|
||||
def _broadcast_to(self, tensor, destinations):
|
||||
if not cross_tower_ops_lib.check_destinations(destinations):
|
||||
if not cross_device_ops_lib.check_destinations(destinations):
|
||||
destinations = self._compute_devices
|
||||
return self._cross_tower_ops.broadcast(tensor, destinations)
|
||||
return self._cross_device_ops.broadcast(tensor, destinations)
|
||||
|
||||
def _allow_variable_partition(self):
|
||||
return not context.executing_eagerly()
|
||||
@ -330,7 +330,7 @@ class ParameterServerExtended(distribute_lib.DistributionStrategyExtended):
|
||||
return
|
||||
if destinations is None:
|
||||
return
|
||||
for d in cross_tower_ops_lib.get_devices_from(destinations):
|
||||
for d in cross_device_ops_lib.get_devices_from(destinations):
|
||||
d_spec = tf_device.DeviceSpec.from_string(d)
|
||||
if d_spec.job == self._task_type and d_spec.task != self._task_id:
|
||||
raise ValueError(
|
||||
@ -343,14 +343,14 @@ class ParameterServerExtended(distribute_lib.DistributionStrategyExtended):
|
||||
# pylint: disable=protected-access
|
||||
return mirrored_strategy._reduce_non_distributed_value(
|
||||
self, reduce_op, value, destinations)
|
||||
return self._cross_tower_ops.reduce(
|
||||
return self._cross_device_ops.reduce(
|
||||
reduce_op, value, destinations=destinations)
|
||||
|
||||
def _batch_reduce_to(self, reduce_op, value_destination_pairs):
|
||||
for _, destinations in value_destination_pairs:
|
||||
self._verify_destinations_not_different_worker(destinations)
|
||||
return self._cross_tower_ops.batch_reduce(reduce_op,
|
||||
value_destination_pairs)
|
||||
return self._cross_device_ops.batch_reduce(reduce_op,
|
||||
value_destination_pairs)
|
||||
|
||||
def _select_single_value(self, structured):
|
||||
"""Select any single values in `structured`."""
|
||||
|
@ -26,10 +26,10 @@ from tensorflow.contrib.distribute.python import combinations
|
||||
from tensorflow.contrib.distribute.python import multi_worker_test_base
|
||||
from tensorflow.contrib.distribute.python import parameter_server_strategy
|
||||
from tensorflow.contrib.distribute.python import strategy_test_lib
|
||||
from tensorflow.contrib.distribute.python import values
|
||||
from tensorflow.core.protobuf import config_pb2
|
||||
from tensorflow.python.distribute import multi_worker_util
|
||||
from tensorflow.python.distribute import reduce_util
|
||||
from tensorflow.python.distribute import values
|
||||
from tensorflow.python.eager import backprop
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.estimator import run_config
|
||||
|
@ -23,15 +23,15 @@ from __future__ import print_function
|
||||
|
||||
import functools
|
||||
|
||||
from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_ops_lib
|
||||
from tensorflow.contrib.distribute.python import values
|
||||
from tensorflow.contrib.tpu.python.ops import tpu_ops
|
||||
from tensorflow.contrib.tpu.python.tpu import tpu
|
||||
from tensorflow.contrib.tpu.python.tpu import tpu_system_metadata as tpu_system_metadata_lib
|
||||
from tensorflow.contrib.tpu.python.tpu import training_loop
|
||||
from tensorflow.python.data.experimental.ops import batching
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
|
||||
from tensorflow.python.distribute import reduce_util
|
||||
from tensorflow.python.distribute import values
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import tape
|
||||
from tensorflow.python.framework import constant_op
|
||||
@ -467,7 +467,7 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended):
|
||||
# Validate that the destination is same as the host device
|
||||
# Note we don't do this when in replicate context as the reduction is
|
||||
# performed on the TPU device itself.
|
||||
devices = cross_tower_ops_lib.get_devices_from(destinations)
|
||||
devices = cross_device_ops_lib.get_devices_from(destinations)
|
||||
if len(devices) == 1:
|
||||
assert device_util.canonicalize(devices[0]) == device_util.canonicalize(
|
||||
self._host_device)
|
||||
|
@ -22,9 +22,9 @@ import os
|
||||
|
||||
from tensorflow.contrib.distribute.python import mirrored_strategy
|
||||
from tensorflow.contrib.distribute.python import multi_worker_test_base
|
||||
from tensorflow.contrib.distribute.python import values
|
||||
from tensorflow.core.protobuf import config_pb2
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.distribute import values
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import test
|
||||
from tensorflow.python.estimator import model_fn as model_fn_lib
|
||||
|
@ -8,6 +8,7 @@ exports_files(["LICENSE"])
|
||||
|
||||
load("//tensorflow:tensorflow.bzl", "py_test")
|
||||
load("//tensorflow:tensorflow.bzl", "tf_py_test")
|
||||
load("//tensorflow:tensorflow.bzl", "cuda_py_test")
|
||||
|
||||
py_library(
|
||||
name = "all_reduce",
|
||||
@ -44,13 +45,41 @@ tf_py_test(
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "distribute",
|
||||
name = "cross_device_ops",
|
||||
srcs = ["cross_device_ops.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":distribute_config",
|
||||
":distribute_coordinator",
|
||||
":distribute_coordinator_context",
|
||||
":cross_device_utils",
|
||||
":reduce_util",
|
||||
":values",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:device_lib",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:platform",
|
||||
"//tensorflow/python:resource_variable_ops",
|
||||
"//tensorflow/python:training",
|
||||
"//tensorflow/python:variable_scope",
|
||||
"//tensorflow/python/eager:context",
|
||||
"@six_archive//:six",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "cross_device_utils",
|
||||
srcs = ["cross_device_utils.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":all_reduce",
|
||||
":values",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:collective_ops",
|
||||
"//tensorflow/python:device",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:gradients",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:nccl_ops",
|
||||
],
|
||||
)
|
||||
|
||||
@ -123,6 +152,34 @@ py_library(
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "input_ops",
|
||||
srcs = ["input_ops.py"],
|
||||
deps = [
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python/data/util:nest",
|
||||
],
|
||||
)
|
||||
|
||||
cuda_py_test(
|
||||
name = "input_ops_test",
|
||||
srcs = ["input_ops_test.py"],
|
||||
additional_deps = [
|
||||
":input_ops",
|
||||
"//tensorflow/python/data/ops:dataset_ops",
|
||||
"//tensorflow/python:errors",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:io_ops",
|
||||
"//tensorflow/python/data/ops:readers",
|
||||
"//tensorflow/python:util",
|
||||
],
|
||||
tags = [
|
||||
"no_pip",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "multi_worker_util_test",
|
||||
srcs = ["multi_worker_util_test.py"],
|
||||
@ -158,8 +215,43 @@ py_library(

py_library(
    name = "reduce_util",
    srcs = [
        "reduce_util.py",
    ],
    srcs = ["reduce_util.py"],
    deps = [],
)

py_library(
    name = "shared_variable_creator",
    srcs = ["shared_variable_creator.py"],
)

py_test(
    name = "shared_variable_creator_test",
    srcs = ["shared_variable_creator_test.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":shared_variable_creator",
        "//tensorflow/python:framework_test_lib",
        "//tensorflow/python:variable_scope",
        "//tensorflow/python/eager:test",
    ],
)

py_library(
    name = "values",
    srcs = ["values.py"],
    deps = [
        ":input_ops",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:control_flow_ops",
        "//tensorflow/python:device_util",
        "//tensorflow/python:distribute",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:resource_variable_ops",
        "//tensorflow/python:training",
        "//tensorflow/python:util",
        "//tensorflow/python/data/ops:multi_device_iterator_ops",
        "//tensorflow/python/eager:context",
        "//tensorflow/python/training/checkpointable:base",
        "@six_archive//:six",
    ],
)
@ -21,10 +21,10 @@ from __future__ import print_function
import collections
import six

from tensorflow.contrib.distribute.python import cross_tower_utils
from tensorflow.contrib.distribute.python import values as value_lib
from tensorflow.python.client import device_lib
from tensorflow.python.distribute import cross_device_utils
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import values as value_lib
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
@ -144,7 +144,7 @@ def _simple_broadcast(value, destinations):
  index = {}
  devices = get_devices_from(destinations)
  for d in devices:
    index[d] = cross_tower_utils.copy_tensor_or_indexed_slices_to_device(
    index[d] = cross_device_utils.copy_tensor_or_indexed_slices_to_device(
        value, d)
  return value_lib.Mirrored(index)
@ -162,10 +162,10 @@ def _simple_reduce(per_replica_value, reduce_to_device, accumulation_fn,

  with ops.device(reduce_to_device):
    with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT):
      reduced = cross_tower_utils.aggregate_tensors_or_indexed_slices(
      reduced = cross_device_utils.aggregate_tensors_or_indexed_slices(
          all_values, accumulation_fn)
      if reduce_op == reduce_util.ReduceOp.MEAN:
        reduced = cross_tower_utils.divide_by_n_tensors_or_indexed_slices(
        reduced = cross_device_utils.divide_by_n_tensors_or_indexed_slices(
            reduced, count)
      elif reduce_op != reduce_util.ReduceOp.SUM:
        raise ValueError("`reduce_op` must be Reduce.SUM or Reduce.MEAN.")
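The hunk above only swaps the utility module name; the reduction semantics are untouched. A standalone sketch of the SUM/MEAN dispatch it performs, in plain Python with ordinary numbers standing in for per-replica tensors (illustrative, not the TF implementation):

def simple_reduce_sketch(all_values, reduce_op):
  # Sum the per-replica contributions, then divide only for MEAN;
  # any other reduce_op is rejected, mirroring the ValueError above.
  reduced = sum(all_values)
  if reduce_op == "MEAN":
    reduced = reduced / len(all_values)
  elif reduce_op != "SUM":
    raise ValueError("`reduce_op` must be ReduceOp.SUM or ReduceOp.MEAN.")
  return reduced

assert simple_reduce_sketch([1.0, 2.0, 3.0], "MEAN") == 2.0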
@ -332,7 +332,7 @@ def _ungroup_and_make_mirrored(grouped_reduced,
  Args:
    grouped_reduced: a list of lists, each sublist has components for each
      device, paired with a None. It is the result from
      cross_tower_utils.aggregate_gradients_using*.
      cross_device_utils.aggregate_gradients_using*.
    destinations: a list of device strings for returned Mirrored objects.
    reduce_op: Indicates how values will be aggregated. Accepted values
      are `tf.distribute.ReduceOp.SUM`, `tf.distribute.ReduceOp.MEAN`.
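To make the Args description above concrete, here is a toy illustration of the shape `grouped_reduced` takes and how it lines up with `destinations` (hypothetical numbers; the real entries are tensors produced by cross_device_utils.aggregate_gradients_using*):

# Each sublist has one (reduced_value, None) pair per device.
grouped_reduced = [
    [(0.5, None), (0.5, None)],  # first reduced gradient, on device 0 and device 1
    [(1.5, None), (1.5, None)],  # second reduced gradient, on device 0 and device 1
]
destinations = ["/device:GPU:0", "/device:GPU:1"]
# Ungrouping collects, for each destination device, its column of reduced values.
index = {d: [pairs[i][0] for pairs in grouped_reduced]
         for i, d in enumerate(destinations)}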
@ -485,7 +485,7 @@ class AggregateSmallTensorPacker(object):
    """Aggregate small tensors."""
    if (self.agg_small_grads_max_bytes > 0 and
        self.agg_small_grads_max_group > 0):
      device_grads, self.packing = cross_tower_utils.pack_small_tensors(
      device_grads, self.packing = cross_device_utils.pack_small_tensors(
          grouped_grads_and_vars,
          max_bytes=self.agg_small_grads_max_bytes,
          max_group=self.agg_small_grads_max_group)
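The pack/unpack pair these hunks rename is a grouping optimization: many small gradients are concatenated into a few larger buffers before the cross-device all-reduce, then split back afterwards. A self-contained toy sketch of that idea on plain Python lists (thresholds and helper names are illustrative; the real code operates on tensors and keeps a `packing` record precisely so unpack can reverse the operation):

def pack_small_tensors_sketch(grads, max_group=2):
  """Group consecutive gradients into packs of at most `max_group` items."""
  packs, packing = [], []
  for start in range(0, len(grads), max_group):
    group = grads[start:start + max_group]
    packing.append([len(g) for g in group])      # remember original sizes
    packs.append([x for g in group for x in g])  # one concatenated buffer
  return packs, packing

def unpack_small_tensors_sketch(packs, packing):
  """Reverse the packing using the recorded sizes."""
  grads = []
  for pack, sizes in zip(packs, packing):
    offset = 0
    for size in sizes:
      grads.append(pack[offset:offset + size])
      offset += size
  return grads

grads = [[1.0], [2.0, 3.0], [4.0]]
packs, packing = pack_small_tensors_sketch(grads)
assert unpack_small_tensors_sketch(packs, packing) == grads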
@ -493,8 +493,8 @@ class AggregateSmallTensorPacker(object):

  def unpack(self, summed_device_grad_packs):
    """Reverse the aggregation process."""
    return cross_tower_utils.unpack_small_tensors(summed_device_grad_packs,
                                                  self.packing)
    return cross_device_utils.unpack_small_tensors(summed_device_grad_packs,
                                                   self.packing)


def _pack_tensors(device_grads,
@ -557,7 +557,7 @@ class AllReduceCrossDeviceOps(CrossDeviceOps):
    super(AllReduceCrossDeviceOps, self).__init__()

  def _reduce(self, reduce_op, per_replica_value, destinations):
    contains_indexed_slices = cross_tower_utils.contains_indexed_slices(
    contains_indexed_slices = cross_device_utils.contains_indexed_slices(
        per_replica_value)
    if (_devices_match(per_replica_value, destinations)
        and not context.executing_eagerly()
@ -580,7 +580,7 @@ class AllReduceCrossDeviceOps(CrossDeviceOps):

  def _batch_reduce(self, reduce_op, value_destination_pairs):
    all_devices_match = _all_devices_match(value_destination_pairs)
    contains_indexed_slices = cross_tower_utils.contains_indexed_slices(
    contains_indexed_slices = cross_device_utils.contains_indexed_slices(
        value_destination_pairs)
    if (all_devices_match and not context.executing_eagerly()
        and not contains_indexed_slices):
@ -618,13 +618,13 @@ class AllReduceCrossDeviceOps(CrossDeviceOps):
    # the balance on num_splits.
    if self._all_reduce_alg == "nccl":
      # TODO(yuefengz): merge this into the all-reduce library.
      reduced = cross_tower_utils.aggregate_gradients_using_nccl(
      reduced = cross_device_utils.aggregate_gradients_using_nccl(
          device_grad_packs)
    else:
      # TODO(yuefengz): check that gpu ids in `destinations` are in ascending
      # order.
      reduced = (
          cross_tower_utils.aggregate_gradients_using_hierarchical_copy(
          cross_device_utils.aggregate_gradients_using_hierarchical_copy(
              destinations, device_grad_packs))

    reduced = _unpack_tensors(reduced, tensor_packer)
@ -740,13 +740,13 @@ class MultiWorkerAllReduce(AllReduceCrossDeviceOps):
        this_grads = remaining_grads
        remaining_grads = []
      else:
        (this_grads, remaining_grads) = cross_tower_utils.split_grads_by_size(
        (this_grads, remaining_grads) = cross_device_utils.split_grads_by_size(
            spec_tuple.limit, remaining_grads)
      if this_grads:
        device_grad_packs, tensor_packer = _pack_tensors(
            this_grads, self._num_packs, self._agg_small_grads_max_bytes,
            self._agg_small_grads_max_group)
        range_agg_grads = cross_tower_utils.sum_gradients_all_reduce(
        range_agg_grads = cross_device_utils.sum_gradients_all_reduce(
            self._worker_devices, device_grad_packs, len(self._worker_devices),
            spec_tuple.alg, spec_tuple.shards, range(self._num_gpus_per_worker))
        range_agg_grads = _unpack_tensors(range_agg_grads, tensor_packer)
@ -789,13 +789,13 @@ class CollectiveAllReduce(CrossDeviceOps):
    self._num_workers = num_workers
    self._num_gpus_per_worker = num_gpus_per_worker
    self._all_reduce_merge_scope = all_reduce_merge_scope
    self._collective_keys = collective_keys or cross_tower_utils.CollectiveKeys(
    )
    self._collective_keys = (collective_keys or
                             cross_device_utils.CollectiveKeys())
    super(CollectiveAllReduce, self).__init__()

  # TODO(yuefengz, tucker): is indexed slices supported by collective ops?
  def _reduce(self, reduce_op, per_replica_value, destinations):
    if cross_tower_utils.contains_indexed_slices(per_replica_value):
    if cross_device_utils.contains_indexed_slices(per_replica_value):
      raise ValueError(
          "`IndexSlices` is not supported for Collective All-Reduce.")
    if context.executing_eagerly():
@ -819,7 +819,7 @@ class CollectiveAllReduce(CrossDeviceOps):
    return value_lib.Mirrored(index)

  def _batch_reduce(self, reduce_op, value_destination_pairs):
    if cross_tower_utils.contains_indexed_slices(value_destination_pairs):
    if cross_device_utils.contains_indexed_slices(value_destination_pairs):
      raise ValueError(
          "`IndexSlices` is not supported for Collective All-Reduce.")
    if context.executing_eagerly():
@ -870,7 +870,7 @@ class CollectiveAllReduce(CrossDeviceOps):
      with ops.name_scope("allreduce"):
        for grad_and_vars in chunk:
          scaled_grads = [g for g, _ in grad_and_vars]
          collective_reduced = cross_tower_utils.build_collective_reduce(
          collective_reduced = cross_device_utils.build_collective_reduce(
              scaled_grads, self._num_workers, self._collective_keys, "Add",
              "Id")
          result = []
@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for cross_tower_ops."""
"""Utilities for cross_device_ops."""

from __future__ import absolute_import
from __future__ import division
@ -21,8 +21,8 @@ from __future__ import print_function
import collections as pycoll
import threading

from tensorflow.contrib.distribute.python import values as value_lib
from tensorflow.python.distribute import all_reduce
from tensorflow.python.distribute import values as value_lib
from tensorflow.python.framework import device as pydev
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@ -20,9 +20,9 @@ from __future__ import print_function

import os

from tensorflow.contrib.distribute.python import input_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import readers
from tensorflow.python.distribute import input_ops
from tensorflow.python.framework import errors
from tensorflow.python.lib.io import python_io
from tensorflow.python.platform import test
@ -18,7 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.distribute.python import shared_variable_creator
from tensorflow.python.distribute import shared_variable_creator
from tensorflow.python.eager import test
from tensorflow.python.framework import test_util
from tensorflow.python.ops import variable_scope
@ -27,9 +27,9 @@ import operator
import weakref
import six

from tensorflow.contrib.distribute.python import input_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import multi_device_iterator_ops
from tensorflow.python.distribute import input_ops
from tensorflow.python.distribute import reduce_util
from tensorflow.python.eager import context
from tensorflow.python.eager import tape