A round of moving some DistributionStrategy libraries from contrib to core (just implementations, not public APIs):
  * cross_tower_ops -> cross_device_ops
  * cross_tower_utils -> cross_device_utils
  * input_ops (and test)
  * shared_variable_creator (and test)
  * values

Also:
  * BUILD cleanup
  * Renaming cross_tower -> cross_device in a number of places (see the import sketch below).
PiperOrigin-RevId: 221828945
A. Unique TensorFlower authored on 2018-11-16 11:54:10 -08:00, committed by TensorFlower Gardener
parent 74660a6db3
commit 36035b9230
23 changed files with 268 additions and 281 deletions
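
For orientation only, here is a minimal import sketch inferred from the rewrites that appear repeatedly in the diffs below; it is not part of the commit itself, and the modules remain internal implementations rather than public APIs:

    # Old contrib-era imports, removed throughout the diffs below:
    # from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_ops_lib
    # from tensorflow.contrib.distribute.python import cross_tower_utils
    # from tensorflow.contrib.distribute.python import values as value_lib

    # New core locations after this change:
    from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
    from tensorflow.python.distribute import cross_device_utils
    from tensorflow.python.distribute import values as value_lib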

View File

@ -26,7 +26,6 @@ py_library(
visibility = ["//tensorflow:internal"],
deps = [
"//tensorflow/contrib/distribute/python:collective_all_reduce_strategy",
"//tensorflow/contrib/distribute/python:cross_tower_ops",
"//tensorflow/contrib/distribute/python:mirrored_strategy",
"//tensorflow/contrib/distribute/python:monitor",
"//tensorflow/contrib/distribute/python:one_device_strategy",
@ -35,6 +34,7 @@ py_library(
"//tensorflow/contrib/distribute/python:tpu_strategy",
"//tensorflow/python:training",
"//tensorflow/python:util",
"//tensorflow/python/distribute:cross_device_ops",
"//tensorflow/python/distribute:distribute_config",
"//tensorflow/python/distribute:distribute_coordinator",
],

View File

@ -25,13 +25,13 @@ from __future__ import print_function
# pylint: disable=unused-import,wildcard-import
from tensorflow.contrib.distribute.python.collective_all_reduce_strategy import CollectiveAllReduceStrategy
from tensorflow.contrib.distribute.python.cross_tower_ops import *
from tensorflow.contrib.distribute.python.mirrored_strategy import MirroredStrategy
from tensorflow.contrib.distribute.python.monitor import Monitor
from tensorflow.contrib.distribute.python.one_device_strategy import OneDeviceStrategy
from tensorflow.contrib.distribute.python.parameter_server_strategy import ParameterServerStrategy
from tensorflow.contrib.distribute.python.step_fn import *
from tensorflow.contrib.distribute.python.tpu_strategy import TPUStrategy
from tensorflow.python.distribute.cross_device_ops import *
from tensorflow.python.distribute.distribute_config import DistributeConfig
from tensorflow.python.distribute.distribute_coordinator import run_standard_tensorflow_server
from tensorflow.python.training.distribute import *

View File

@ -16,45 +16,24 @@ load("//tensorflow:tensorflow.bzl", "cuda_py_test")
# TODO(priyag): Figure out testonly issues that are preventing us from
# including our tests in pip for now.
py_library(
name = "values",
srcs = ["values.py"],
visibility = ["//tensorflow:internal"],
deps = [
":input_ops",
"//tensorflow/python:array_ops",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:device_util",
"//tensorflow/python:distribute",
"//tensorflow/python:framework_ops",
"//tensorflow/python:resource_variable_ops",
"//tensorflow/python:training",
"//tensorflow/python:util",
"//tensorflow/python/data/ops:multi_device_iterator_ops",
"//tensorflow/python/eager:context",
"//tensorflow/python/training/checkpointable:base",
"@six_archive//:six",
],
)
cuda_py_test(
name = "values_test",
srcs = ["values_test.py"],
additional_deps = [
":mirrored_strategy",
":multi_worker_test_base",
":values",
"//tensorflow/core:protos_all_py",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python:errors",
"//tensorflow/python:array_ops",
"//tensorflow/python:constant_op",
"//tensorflow/python:device_util",
"//tensorflow/python:errors",
"//tensorflow/python:framework_ops",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:training",
"//tensorflow/python:variable_scope",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/distribute:values",
"//tensorflow/python/eager:context",
"//tensorflow/python:device_util",
"//tensorflow/python/eager:test",
"//tensorflow/python/estimator:estimator_py",
],
@ -68,9 +47,6 @@ py_library(
srcs = ["mirrored_strategy.py"],
visibility = ["//tensorflow:internal"],
deps = [
":cross_tower_ops",
":shared_variable_creator",
":values",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:constant_op",
@ -86,8 +62,11 @@ py_library(
"//tensorflow/python:util",
"//tensorflow/python:variable_scope",
"//tensorflow/python:variables",
"//tensorflow/python/distribute:cross_device_ops",
"//tensorflow/python/distribute:multi_worker_util",
"//tensorflow/python/distribute:reduce_util",
"//tensorflow/python/distribute:shared_variable_creator",
"//tensorflow/python/distribute:values",
"//tensorflow/python/eager:context",
"//tensorflow/python/eager:tape",
],
@ -98,17 +77,17 @@ py_library(
srcs = ["parameter_server_strategy.py"],
visibility = ["//tensorflow:internal"],
deps = [
":cross_tower_ops",
":mirrored_strategy",
":values",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_ops",
"//tensorflow/python:resource_variable_ops",
"//tensorflow/python:training",
"//tensorflow/python:util",
"//tensorflow/python/distribute:cross_device_ops",
"//tensorflow/python/distribute:multi_worker_util",
"//tensorflow/python/distribute:reduce_util",
"//tensorflow/python/distribute:values",
"//tensorflow/python/eager:context",
],
)
@ -121,7 +100,6 @@ cuda_py_test(
":multi_worker_test_base",
":parameter_server_strategy",
":strategy_test_lib",
":values",
"@absl_py//absl/testing:parameterized",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:array_ops",
@ -137,6 +115,7 @@ cuda_py_test(
"//tensorflow/python:variable_scope",
"//tensorflow/python:variables",
"//tensorflow/python/distribute:multi_worker_util",
"//tensorflow/python/distribute:values",
"//tensorflow/python/eager:context",
"//tensorflow/python/estimator:estimator_py",
],
@ -151,13 +130,13 @@ py_library(
srcs = ["one_device_strategy.py"],
visibility = ["//tensorflow:internal"],
deps = [
":values",
"//tensorflow/python:array_ops",
"//tensorflow/python:distribute",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python/distribute:reduce_util",
"//tensorflow/python/distribute:values",
"//tensorflow/python/eager:context",
"@six_archive//:six",
],
@ -168,16 +147,16 @@ py_library(
srcs = ["collective_all_reduce_strategy.py"],
visibility = ["//tensorflow:internal"],
deps = [
":cross_tower_ops",
":cross_tower_utils",
":mirrored_strategy",
":values",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:collective_ops",
"//tensorflow/python:framework_ops",
"//tensorflow/python:training",
"//tensorflow/python/distribute:cross_device_ops",
"//tensorflow/python/distribute:cross_device_utils",
"//tensorflow/python/distribute:multi_worker_util",
"//tensorflow/python/distribute:values",
"//tensorflow/python/eager:context",
],
)
@ -283,16 +262,16 @@ cuda_py_test(
additional_deps = [
":mirrored_strategy",
":multi_worker_test_base",
":values",
":strategy_test_lib",
"//tensorflow/python:distribute",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:constant_op",
"//tensorflow/python:distribute",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:layers",
"//tensorflow/python:state_ops",
"//tensorflow/python:variable_scope",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python/distribute:values",
"//tensorflow/python/eager:context",
"//tensorflow/python/eager:test",
],
@ -344,7 +323,6 @@ py_library(
visibility = ["//tensorflow:internal"],
deps = [
":one_device_strategy",
":values",
"//tensorflow/contrib/tpu:tpu_lib",
"//tensorflow/python:constant_op",
"//tensorflow/python:control_flow_ops",
@ -353,6 +331,7 @@ py_library(
"//tensorflow/python:tensor_util",
"//tensorflow/python:util",
"//tensorflow/python/distribute:reduce_util",
"//tensorflow/python/distribute:values",
],
)
@ -362,7 +341,6 @@ cuda_py_test(
additional_deps = [
":collective_all_reduce_strategy",
":combinations",
":cross_tower_utils",
":multi_worker_test_base",
":strategy_test_lib",
"@absl_py//absl/testing:parameterized",
@ -378,6 +356,7 @@ cuda_py_test(
"//tensorflow/python:layers",
"//tensorflow/python:variable_scope",
"//tensorflow/python:variables",
"//tensorflow/python/distribute:cross_device_utils",
"//tensorflow/python/eager:context",
"//tensorflow/python/estimator:estimator_py",
],
@ -508,7 +487,9 @@ cuda_py_test(
"//third_party/py/numpy",
"//tensorflow/contrib/optimizer_v2:training",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/distribute",
"//tensorflow/python/distribute:distribute_config",
"//tensorflow/python/distribute:distribute_coordinator",
"//tensorflow/python/distribute:distribute_coordinator_context",
"//tensorflow/python/eager:test",
"//tensorflow/python/estimator:estimator_py",
"//tensorflow/python/feature_column",
@ -600,52 +581,16 @@ cuda_py_test(
],
)
py_library(
name = "shared_variable_creator",
srcs = ["shared_variable_creator.py"],
visibility = ["//tensorflow:internal"],
)
py_test(
name = "shared_variable_creator_test",
srcs = ["shared_variable_creator_test.py"],
srcs_version = "PY2AND3",
deps = [
":shared_variable_creator",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:variable_scope",
"//tensorflow/python/eager:test",
],
)
py_library(
name = "cross_tower_utils",
srcs = ["cross_tower_utils.py"],
srcs_version = "PY2AND3",
deps = [
":values",
"//tensorflow/python:array_ops",
"//tensorflow/python:collective_ops",
"//tensorflow/python:device",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:gradients",
"//tensorflow/python:math_ops",
"//tensorflow/python:nccl_ops",
"//tensorflow/python/distribute:all_reduce",
],
)
cuda_py_test(
name = "cross_tower_utils_test",
srcs = ["cross_tower_utils_test.py"],
name = "cross_device_utils_test",
srcs = ["cross_device_utils_test.py"],
additional_deps = [
":combinations",
":cross_tower_utils",
"@absl_py//absl/testing:parameterized",
"//tensorflow/python:constant_op",
"//tensorflow/python:framework_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python/distribute:cross_device_utils",
"//tensorflow/python/eager:context",
"//tensorflow/python/eager:test",
],
@ -654,41 +599,20 @@ cuda_py_test(
],
)
py_library(
name = "cross_tower_ops",
srcs = ["cross_tower_ops.py"],
srcs_version = "PY2AND3",
deps = [
":cross_tower_utils",
":values",
"//tensorflow/python:array_ops",
"//tensorflow/python:device_lib",
"//tensorflow/python:framework_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform",
"//tensorflow/python:resource_variable_ops",
"//tensorflow/python:training",
"//tensorflow/python:variable_scope",
"//tensorflow/python/distribute:reduce_util",
"//tensorflow/python/eager:context",
"@six_archive//:six",
],
)
cuda_py_test(
name = "cross_tower_ops_test",
srcs = ["cross_tower_ops_test.py"],
name = "cross_device_ops_test",
srcs = ["cross_device_ops_test.py"],
additional_deps = [
":combinations",
":cross_tower_ops",
":multi_worker_test_base",
":mirrored_strategy",
":values",
"@absl_py//absl/testing:parameterized",
"//tensorflow/python:array_ops",
"//tensorflow/python:constant_op",
"//tensorflow/python:framework_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python/distribute:cross_device_ops",
"//tensorflow/python/distribute:values",
"//tensorflow/python/eager:context",
"//tensorflow/python/eager:test",
],
@ -698,35 +622,6 @@ cuda_py_test(
],
)
py_library(
name = "input_ops",
srcs = ["input_ops.py"],
visibility = ["//tensorflow:internal"],
deps = [
"//tensorflow/python:framework_ops",
"//tensorflow/python/data/util:nest",
],
)
cuda_py_test(
name = "input_ops_test",
srcs = ["input_ops_test.py"],
additional_deps = [
":input_ops",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python:errors",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_ops",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:io_ops",
"//tensorflow/python/data/ops:readers",
"//tensorflow/python:util",
],
tags = [
"no_pip",
],
)
py_library(
name = "keras_test_lib",
testonly = 1,

View File

@ -18,12 +18,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_ops_lib
from tensorflow.contrib.distribute.python import cross_tower_utils
from tensorflow.contrib.distribute.python import mirrored_strategy
from tensorflow.contrib.distribute.python import values
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
from tensorflow.python.distribute import cross_device_utils
from tensorflow.python.distribute import multi_worker_util
from tensorflow.python.distribute import values
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
@ -79,11 +79,11 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
else:
local_devices = ["/device:CPU:0"]
self._collective_keys = cross_tower_utils.CollectiveKeys()
self._collective_keys = cross_device_utils.CollectiveKeys()
super(CollectiveAllReduceExtended, self).__init__(
container_strategy,
devices=local_devices,
cross_device_ops=cross_tower_ops_lib.CollectiveAllReduce(
cross_device_ops=cross_device_ops_lib.CollectiveAllReduce(
num_workers=1,
num_gpus_per_worker=num_gpus_per_worker,
collective_keys=self._collective_keys))
@ -123,11 +123,11 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
else:
local_devices = [worker_device]
self._collective_keys = cross_tower_utils.CollectiveKeys()
self._collective_keys = cross_device_utils.CollectiveKeys()
super(CollectiveAllReduceExtended, self).__init__(
container_strategy,
devices=local_devices,
cross_device_ops=cross_tower_ops_lib.CollectiveAllReduce(
cross_device_ops=cross_device_ops_lib.CollectiveAllReduce(
num_workers=self._num_workers,
num_gpus_per_worker=num_gpus_per_worker,
collective_keys=self._collective_keys))

View File

@ -23,11 +23,11 @@ import numpy as np
from tensorflow.contrib.distribute.python import collective_all_reduce_strategy
from tensorflow.contrib.distribute.python import combinations
from tensorflow.contrib.distribute.python import cross_tower_utils
from tensorflow.contrib.distribute.python import multi_worker_test_base
from tensorflow.contrib.distribute.python import strategy_test_lib
from tensorflow.core.protobuf import config_pb2
from tensorflow.python import keras
from tensorflow.python.distribute import cross_device_utils
from tensorflow.python.distribute import reduce_util
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
@ -73,7 +73,7 @@ class CollectiveAllReduceStrategyTestBase(
cluster_spec=self._cluster_spec,
task_type=task_type,
task_id=task_id)
collective_keys = cross_tower_utils.CollectiveKeys(
collective_keys = cross_device_utils.CollectiveKeys(
group_key_start=10 * num_gpus +
CollectiveAllReduceStrategyTestBase.collective_key_base,
instance_key_start=num_gpus * 100 +
@ -81,7 +81,7 @@ class CollectiveAllReduceStrategyTestBase(
instance_key_with_id_start=num_gpus * 10000 +
CollectiveAllReduceStrategyTestBase.collective_key_base)
distribution.extended._collective_keys = collective_keys
distribution.extended._cross_tower_ops._collective_keys = collective_keys
distribution.extended._cross_device_ops._collective_keys = collective_keys
if task_type and task_id is not None:
return distribution, 'grpc://' + self._cluster_spec[task_type][
task_id], session_config

View File

@ -24,13 +24,13 @@ from absl.testing import parameterized
import numpy as np
from tensorflow.contrib.distribute.python import combinations
from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_ops_lib
from tensorflow.contrib.distribute.python import cross_tower_utils
from tensorflow.contrib.distribute.python import mirrored_strategy
from tensorflow.contrib.distribute.python import multi_worker_test_base
from tensorflow.contrib.distribute.python import values as value_lib
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
from tensorflow.python.distribute import cross_device_utils
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import values as value_lib
from tensorflow.python.eager import context
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
@ -41,7 +41,7 @@ from tensorflow.python.training import device_util
def _make_per_replica(values, devices, regroup=False):
devices = cross_tower_ops_lib.get_devices_from(devices)
devices = cross_device_ops_lib.get_devices_from(devices)
assert len(values) == len(devices)
# We simulate the result of regroup called on PerReplica which strips the
@ -66,7 +66,7 @@ def _fake_mirrored(value, devices):
All components of the returned Mirrored have the same objects, which is not
true in reality.
"""
devices = cross_tower_ops_lib.get_devices_from(devices)
devices = cross_device_ops_lib.get_devices_from(devices)
return value_lib.Mirrored(
{d: v for d, v in zip(devices, [value] * len(devices))})
@ -118,7 +118,7 @@ class CrossDeviceOpsTestBase(test.TestCase, parameterized.TestCase):
self.assertEqual(
sess.run(list(left._index.values())), list(right._index.values()))
def _testReductionAndBroadcast(self, cross_tower_ops, distribution):
def _testReductionAndBroadcast(self, cross_device_ops, distribution):
devices = distribution.worker_devices
values = [constant_op.constant(float(d)) for d in range(len(devices))]
@ -142,24 +142,24 @@ class CrossDeviceOpsTestBase(test.TestCase, parameterized.TestCase):
# test reduce()
for destinations in all_destinations:
self._assert_values_equal(
cross_tower_ops.reduce(
cross_device_ops.reduce(
reduce_util.ReduceOp.MEAN,
per_replica,
destinations=destinations),
_fake_mirrored(mean, destinations))
self._assert_values_equal(
cross_tower_ops.reduce(
cross_device_ops.reduce(
reduce_util.ReduceOp.MEAN,
per_replica_2,
destinations=destinations),
_fake_mirrored(mean_2, destinations))
self._assert_values_equal(
cross_tower_ops.reduce(
cross_device_ops.reduce(
reduce_util.ReduceOp.SUM, per_replica,
destinations=destinations),
_fake_mirrored(mean * len(devices), destinations))
self._assert_values_equal(
cross_tower_ops.reduce(
cross_device_ops.reduce(
reduce_util.ReduceOp.SUM,
per_replica_2,
destinations=destinations),
@ -168,7 +168,7 @@ class CrossDeviceOpsTestBase(test.TestCase, parameterized.TestCase):
# test batch_reduce()
for d1, d2 in itertools.product(all_destinations, all_destinations):
self._assert_values_equal(
cross_tower_ops.batch_reduce(
cross_device_ops.batch_reduce(
reduce_util.ReduceOp.MEAN,
[(per_replica, d1), (per_replica_2, d2)]),
[
@ -176,7 +176,7 @@ class CrossDeviceOpsTestBase(test.TestCase, parameterized.TestCase):
_fake_mirrored(mean_2, d2)
])
self._assert_values_equal(
cross_tower_ops.batch_reduce(
cross_device_ops.batch_reduce(
reduce_util.ReduceOp.SUM,
[(per_replica, d1), (per_replica_2, d2)]),
[
@ -187,7 +187,7 @@ class CrossDeviceOpsTestBase(test.TestCase, parameterized.TestCase):
# test broadcast()
for destinations in all_destinations:
self._assert_values_equal(
cross_tower_ops.broadcast(constant_op.constant(1.), destinations),
cross_device_ops.broadcast(constant_op.constant(1.), destinations),
_fake_mirrored(1., destinations))
@ -196,17 +196,17 @@ class SingleWorkerCrossDeviceOpsTest(CrossDeviceOpsTestBase):
# combinations module so that we can pass in devices instead of a distribution
# strategy.
reduction_to_one_combinations = combinations.combine(
cross_tower_ops=[
cross_device_ops=[
combinations.NamedObject(
"DefaultReductionToOneDeviceCrossDeviceOps",
cross_tower_ops_lib.ReductionToOneDeviceCrossDeviceOps()),
cross_device_ops_lib.ReductionToOneDeviceCrossDeviceOps()),
combinations.NamedObject(
"ReductionToCPUDeviceCrossDeviceOps",
cross_tower_ops_lib.ReductionToOneDeviceCrossDeviceOps(
cross_device_ops_lib.ReductionToOneDeviceCrossDeviceOps(
reduce_to_device=_cpu_device)),
combinations.NamedObject(
"AccumulateNCrossDeviceOp",
cross_tower_ops_lib.ReductionToOneDeviceCrossDeviceOps(
cross_device_ops_lib.ReductionToOneDeviceCrossDeviceOps(
accumulation_fn=math_ops.accumulate_n)),
],
distribution=[
@ -218,20 +218,20 @@ class SingleWorkerCrossDeviceOpsTest(CrossDeviceOpsTestBase):
],
mode=["graph", "eager"])
allreduce_combinations = combinations.combine(
cross_tower_ops=[
cross_device_ops=[
combinations.NamedObject(
"AllReduce",
cross_tower_ops_lib.AllReduceCrossDeviceOps("nccl", 1, 0, 0)),
cross_device_ops_lib.AllReduceCrossDeviceOps("nccl", 1, 0, 0)),
combinations.NamedObject(
"HierarchicalCopy",
cross_tower_ops_lib.AllReduceCrossDeviceOps(
cross_device_ops_lib.AllReduceCrossDeviceOps(
"hierarchical_copy", 8, 0, 0)),
combinations.NamedObject(
"AllReduceNoGradientRepacking",
cross_tower_ops_lib.AllReduceCrossDeviceOps("nccl", 0, 0, 0)),
cross_device_ops_lib.AllReduceCrossDeviceOps("nccl", 0, 0, 0)),
combinations.NamedObject(
"HierarchicalCopyAggregateSmallTensors",
cross_tower_ops_lib.AllReduceCrossDeviceOps(
cross_device_ops_lib.AllReduceCrossDeviceOps(
"hierarchical_copy", 0, 100, 10))
],
distribution=[combinations.mirrored_strategy_with_two_gpus,
@ -239,22 +239,22 @@ class SingleWorkerCrossDeviceOpsTest(CrossDeviceOpsTestBase):
mode=["graph", "eager"])
@combinations.generate(reduction_to_one_combinations + allreduce_combinations)
def testReductionAndBroadcast(self, cross_tower_ops, distribution):
def testReductionAndBroadcast(self, cross_device_ops, distribution):
with distribution.scope():
self._testReductionAndBroadcast(cross_tower_ops, distribution)
self._testReductionAndBroadcast(cross_device_ops, distribution)
def testChooseAlgorithm(self):
device_links = [[1, 2, 3, 4], [0, 2, 3, 5], [0, 1, 3, 6], [0, 1, 2, 7],
[0, 5, 6, 7], [1, 4, 6, 7], [2, 4, 5, 7], [3, 4, 5, 6]]
result = cross_tower_ops_lib._choose_all_reduce_algorithm(device_links)
self.assertIsInstance(result, cross_tower_ops_lib.AllReduceCrossDeviceOps)
result = cross_device_ops_lib._choose_all_reduce_algorithm(device_links)
self.assertIsInstance(result, cross_device_ops_lib.AllReduceCrossDeviceOps)
self.assertEqual(result._all_reduce_alg, "hierarchical_copy")
self.assertEqual(result._num_packs, 8)
# if there are only 4 devices
device_links = [[1, 2, 3, 4], [0, 2, 3, 5], [0, 1, 3, 6], [0, 1, 2, 7]]
result = cross_tower_ops_lib._choose_all_reduce_algorithm(device_links)
self.assertIsInstance(result, cross_tower_ops_lib.AllReduceCrossDeviceOps)
result = cross_device_ops_lib._choose_all_reduce_algorithm(device_links)
self.assertIsInstance(result, cross_device_ops_lib.AllReduceCrossDeviceOps)
self.assertEqual(result._all_reduce_alg, "nccl")
self.assertEqual(result._num_packs, 1)
@ -262,16 +262,16 @@ class SingleWorkerCrossDeviceOpsTest(CrossDeviceOpsTestBase):
device_links = [[0, 1, 2, 3, 4], [0, 1, 2, 3, 5], [0, 1, 2, 3, 6],
[0, 1, 2, 3, 7], [0, 4, 5, 6, 7], [1, 4, 5, 6, 7],
[2, 4, 5, 6, 7], [3, 4, 5, 6, 7]]
result = cross_tower_ops_lib._choose_all_reduce_algorithm(device_links)
self.assertIsInstance(result, cross_tower_ops_lib.AllReduceCrossDeviceOps)
result = cross_device_ops_lib._choose_all_reduce_algorithm(device_links)
self.assertIsInstance(result, cross_device_ops_lib.AllReduceCrossDeviceOps)
self.assertEqual(result._all_reduce_alg, "hierarchical_copy")
self.assertEqual(result._num_packs, 8)
# if not dgx1-like links
device_links = [[0, 2, 3, 5], [0, 1, 3, 6], [0, 1, 2, 7], [0, 5, 6, 7],
[1, 4, 6, 7], [2, 4, 5, 7], [3, 4, 5, 6], [1, 2, 3, 4]]
result = cross_tower_ops_lib._choose_all_reduce_algorithm(device_links)
self.assertIsInstance(result, cross_tower_ops_lib.AllReduceCrossDeviceOps)
result = cross_device_ops_lib._choose_all_reduce_algorithm(device_links)
self.assertIsInstance(result, cross_device_ops_lib.AllReduceCrossDeviceOps)
self.assertEqual(result._all_reduce_alg, "nccl")
self.assertEqual(result._num_packs, 1)
@ -283,7 +283,7 @@ class SingleWorkerCrossDeviceOpsTest(CrossDeviceOpsTestBase):
t0 = _make_indexed_slices([[1., 2.]], [1], [5, 2], devices[0])
t1 = _make_indexed_slices([[3., 4.], [5., 6.]], [1, 3], [5, 2], devices[1])
per_replica = value_lib.PerReplica({devices[0]: t0, devices[1]: t1})
result = cross_tower_ops_lib._simple_reduce(
result = cross_device_ops_lib._simple_reduce(
per_replica, devices[0], math_ops.add_n, reduce_util.ReduceOp.SUM)
# Test that the result is semantically equal to both the concatenated
@ -297,19 +297,19 @@ class SingleWorkerCrossDeviceOpsTest(CrossDeviceOpsTestBase):
@combinations.generate(
combinations.combine(
cross_tower_ops_instance=[
cross_device_ops_instance=[
combinations.NamedObject(
"ReductionToOneDeviceCrossDeviceOps",
cross_tower_ops_lib.ReductionToOneDeviceCrossDeviceOps()),
cross_device_ops_lib.ReductionToOneDeviceCrossDeviceOps()),
combinations.NamedObject(
"AllReduceCrossDeviceOps",
cross_tower_ops_lib.AllReduceCrossDeviceOps())
cross_device_ops_lib.AllReduceCrossDeviceOps())
],
reduce_op=[reduce_util.ReduceOp.SUM, reduce_util.ReduceOp.MEAN],
batch_reduce=[True, False],
mode=["graph", "eager"],
required_gpus=1))
def testIndexedSlicesAllReduce(self, cross_tower_ops_instance, reduce_op,
def testIndexedSlicesAllReduce(self, cross_device_ops_instance, reduce_op,
batch_reduce):
devices = ["/cpu:0", "/gpu:0"]
dense_shape = [5, 2]
@ -319,10 +319,10 @@ class SingleWorkerCrossDeviceOpsTest(CrossDeviceOpsTestBase):
per_replica = value_lib.PerReplica({devices[0]: t0, devices[1]: t1})
if batch_reduce:
result = cross_tower_ops_instance.batch_reduce(
result = cross_device_ops_instance.batch_reduce(
reduce_op, [(per_replica, devices)])
else:
result = cross_tower_ops_instance.reduce(
result = cross_device_ops_instance.reduce(
reduce_op, per_replica, devices)
total_indices_with_dups = [1, 1, 3]
@ -359,22 +359,22 @@ class MultiWorkerCrossDeviceOpsTest(multi_worker_test_base.MultiWorkerTestBase,
"/job:worker/replica:0/task:0", "/job:worker/replica:0/task:1"
]
multi_worker_allreduce_combinations = combinations.combine(
cross_tower_ops=[
cross_device_ops=[
combinations.NamedObject(
"MultiWorkerAllReduce",
cross_tower_ops_lib.MultiWorkerAllReduce(
cross_device_ops_lib.MultiWorkerAllReduce(
worker_devices, 2, ("pscpu/pscpu", 2, -1), 0, 0, 0)),
combinations.NamedObject(
"MultiWorkerAllReducePack",
cross_tower_ops_lib.MultiWorkerAllReduce(
cross_device_ops_lib.MultiWorkerAllReduce(
worker_devices, 2, ("pscpu/pscpu", 2, -1), 1, 0, 0)),
combinations.NamedObject(
"MultiWorkerAllReduceAggregation",
cross_tower_ops_lib.MultiWorkerAllReduce(
cross_device_ops_lib.MultiWorkerAllReduce(
worker_devices, 2, ("pscpu/pscpu", 2, -1), 0, 100, 10)),
combinations.NamedObject(
"MultiWorkerAllReduceMultipleSpecs",
cross_tower_ops_lib.MultiWorkerAllReduce(
cross_device_ops_lib.MultiWorkerAllReduce(
worker_devices, 2, [("pscpu/pscpu", 2, 100),
("xring", 2, -1)], 0, 0, 0)),
],
@ -395,13 +395,13 @@ class MultiWorkerCrossDeviceOpsTest(multi_worker_test_base.MultiWorkerTestBase,
mode=["graph"])
@combinations.generate(multi_worker_allreduce_combinations)
def testReductionAndBroadcast(self, cross_tower_ops, distribution):
def testReductionAndBroadcast(self, cross_device_ops, distribution):
distribution.configure(cluster_spec={
"worker":
["/job:worker/replica:0/task:0", "/job:worker/replica:0/task:1"]
})
with distribution.scope():
self._testReductionAndBroadcast(cross_tower_ops, distribution)
self._testReductionAndBroadcast(cross_device_ops, distribution)
class MultiWorkerCollectiveAllReduceTest(
@ -422,7 +422,7 @@ class MultiWorkerCollectiveAllReduceTest(
MultiWorkerCollectiveAllReduceTest.collective_key_base += 100000
def _get_test_objects(self, task_type, task_id, num_gpus=0, local_mode=False):
collective_keys = cross_tower_utils.CollectiveKeys(
collective_keys = cross_device_utils.CollectiveKeys(
group_key_start=10 * num_gpus +
MultiWorkerCollectiveAllReduceTest.collective_key_base,
instance_key_start=num_gpus * 100 +
@ -430,7 +430,7 @@ class MultiWorkerCollectiveAllReduceTest(
instance_key_with_id_start=num_gpus * 10000 +
MultiWorkerCollectiveAllReduceTest.collective_key_base)
if local_mode:
collective_all_reduce_ops = cross_tower_ops_lib.CollectiveAllReduce(
collective_all_reduce_ops = cross_device_ops_lib.CollectiveAllReduce(
1, num_gpus, collective_keys=collective_keys)
if num_gpus:
devices = ["/device:GPU:%d" % i for i in range(num_gpus)]
@ -438,7 +438,7 @@ class MultiWorkerCollectiveAllReduceTest(
devices = ["/device:CPU:0"]
return collective_all_reduce_ops, devices, ""
else:
collective_all_reduce_ops = cross_tower_ops_lib.CollectiveAllReduce(
collective_all_reduce_ops = cross_device_ops_lib.CollectiveAllReduce(
3, num_gpus, collective_keys=collective_keys)
if num_gpus:
devices = [

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for cross_tower_utils."""
"""Tests for cross_device_utils."""
from __future__ import absolute_import
from __future__ import division
@ -21,8 +21,8 @@ from __future__ import print_function
from absl.testing import parameterized
from tensorflow.contrib.distribute.python import combinations
from tensorflow.contrib.distribute.python import cross_tower_utils
from tensorflow.contrib.distribute.python import values as value_lib
from tensorflow.python.distribute import cross_device_utils
from tensorflow.python.distribute import values as value_lib
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
@ -43,7 +43,7 @@ class IndexedSlicesUtilsTest(test.TestCase, parameterized.TestCase):
t0 = constant_op.constant([[1., 2.], [0, 0], [3., 4.]])
t1 = constant_op.constant([[0., 0.], [5, 6], [7., 8.]])
total = constant_op.constant([[1., 2.], [5, 6], [10., 12.]])
result = cross_tower_utils.aggregate_tensors_or_indexed_slices([t0, t1])
result = cross_device_utils.aggregate_tensors_or_indexed_slices([t0, t1])
self._assert_values_equal(total, result)
@test_util.run_in_graph_and_eager_modes
@ -53,7 +53,7 @@ class IndexedSlicesUtilsTest(test.TestCase, parameterized.TestCase):
t1 = math_ops._as_indexed_slices(
constant_op.constant([[0., 0.], [5, 6], [7., 8.]]))
total = constant_op.constant([[1., 2.], [5, 6], [10., 12.]])
result = cross_tower_utils.aggregate_tensors_or_indexed_slices([t0, t1])
result = cross_device_utils.aggregate_tensors_or_indexed_slices([t0, t1])
self.assertIsInstance(result, ops.IndexedSlices)
self._assert_values_equal(total, result)
@ -62,7 +62,7 @@ class IndexedSlicesUtilsTest(test.TestCase, parameterized.TestCase):
t = constant_op.constant([[1., 2.], [0, 0], [3., 4.]])
n = 2
expected = constant_op.constant([[0.5, 1.], [0, 0], [1.5, 2.]])
result = cross_tower_utils.divide_by_n_tensors_or_indexed_slices(t, n)
result = cross_device_utils.divide_by_n_tensors_or_indexed_slices(t, n)
self._assert_values_equal(expected, result)
@test_util.run_in_graph_and_eager_modes
@ -71,7 +71,7 @@ class IndexedSlicesUtilsTest(test.TestCase, parameterized.TestCase):
constant_op.constant([[1., 2.], [0, 0], [3., 4.]]))
n = 2
expected = constant_op.constant([[0.5, 1.], [0, 0], [1.5, 2.]])
result = cross_tower_utils.divide_by_n_tensors_or_indexed_slices(t, n)
result = cross_device_utils.divide_by_n_tensors_or_indexed_slices(t, n)
self.assertIsInstance(result, ops.IndexedSlices)
self._assert_values_equal(expected, result)
@ -79,7 +79,7 @@ class IndexedSlicesUtilsTest(test.TestCase, parameterized.TestCase):
def testIsIndexedSlices(self):
t = math_ops._as_indexed_slices(
constant_op.constant([[1., 2.], [0, 0], [3., 4.]]))
self.assertTrue(cross_tower_utils.contains_indexed_slices(t))
self.assertTrue(cross_device_utils.contains_indexed_slices(t))
@test_util.run_in_graph_and_eager_modes
def testContainsIndexedSlices_List(self):
@ -87,7 +87,7 @@ class IndexedSlicesUtilsTest(test.TestCase, parameterized.TestCase):
constant_op.constant([[1., 2.], [0, 0], [3., 4.]]))
t1 = math_ops._as_indexed_slices(
constant_op.constant([[0., 0.], [5, 6], [7., 8.]]))
self.assertTrue(cross_tower_utils.contains_indexed_slices([t0, t1]))
self.assertTrue(cross_device_utils.contains_indexed_slices([t0, t1]))
@test_util.run_in_graph_and_eager_modes
def testContainsIndexedSlices_Tuple(self):
@ -95,7 +95,7 @@ class IndexedSlicesUtilsTest(test.TestCase, parameterized.TestCase):
constant_op.constant([[1., 2.], [0, 0], [3., 4.]]))
t1 = math_ops._as_indexed_slices(
constant_op.constant([[0., 0.], [5, 6], [7., 8.]]))
self.assertTrue(cross_tower_utils.contains_indexed_slices((t0, t1)))
self.assertTrue(cross_device_utils.contains_indexed_slices((t0, t1)))
@test_util.run_in_graph_and_eager_modes
def testContainsIndexedSlices_PerReplica(self):
@ -104,7 +104,7 @@ class IndexedSlicesUtilsTest(test.TestCase, parameterized.TestCase):
t1 = math_ops._as_indexed_slices(
constant_op.constant([[0., 0.], [5, 6], [7., 8.]]))
per_replica = value_lib.PerReplica({"/gpu:0": t0, "/cpu:0": t1})
self.assertTrue(cross_tower_utils.contains_indexed_slices(per_replica))
self.assertTrue(cross_device_utils.contains_indexed_slices(per_replica))
@combinations.generate(combinations.combine(
mode=["graph", "eager"],
@ -113,7 +113,7 @@ class IndexedSlicesUtilsTest(test.TestCase, parameterized.TestCase):
with ops.device("/cpu:0"):
t = constant_op.constant([[1., 2.], [0, 0], [3., 4.]])
destination = "/gpu:0"
result = cross_tower_utils.copy_tensor_or_indexed_slices_to_device(
result = cross_device_utils.copy_tensor_or_indexed_slices_to_device(
t, destination)
self._assert_values_equal(t, result)
@ -128,7 +128,7 @@ class IndexedSlicesUtilsTest(test.TestCase, parameterized.TestCase):
t = math_ops._as_indexed_slices(
constant_op.constant([[1., 2.], [0, 0], [3., 4.]]))
destination = "/gpu:0"
result = cross_tower_utils.copy_tensor_or_indexed_slices_to_device(
result = cross_device_utils.copy_tensor_or_indexed_slices_to_device(
t, destination)
self.assertIsInstance(result, ops.IndexedSlices)

View File

@ -24,9 +24,9 @@ import numpy as np
from tensorflow.contrib.distribute.python import combinations
from tensorflow.contrib.distribute.python import mirrored_strategy
from tensorflow.contrib.distribute.python import tpu_strategy
from tensorflow.contrib.distribute.python import values
from tensorflow.python import keras
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import values
from tensorflow.python.estimator import keras as keras_lib
from tensorflow.python.estimator import run_config as run_config_lib
from tensorflow.python.framework import constant_op

View File

@ -22,12 +22,12 @@ import contextlib
from functools import partial
import threading
from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_ops_lib
from tensorflow.contrib.distribute.python import shared_variable_creator
from tensorflow.contrib.distribute.python import values
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
from tensorflow.python.distribute import multi_worker_util
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import shared_variable_creator
from tensorflow.python.distribute import values
from tensorflow.python.eager import context
from tensorflow.python.eager import tape
from tensorflow.python.framework import constant_op
@ -196,16 +196,16 @@ def _reduce_non_distributed_value(extended, reduce_op, value, destinations):
if reduce_op == reduce_util.ReduceOp.MEAN:
return value
cross_tower_ops_lib.validate_destinations(destinations)
cross_device_ops_lib.validate_destinations(destinations)
# We do not support a reduce op of SUM if the value is the same across
# all replicas. We call this as part of assign functions for MirroredVariables
# and summing up identical values across replicas is not clearly defined.
if (len(extended.worker_devices) != 1 or
not cross_tower_ops_lib.check_destinations(destinations)):
not cross_device_ops_lib.check_destinations(destinations)):
raise ValueError("A non-DistributedValues value %s cannot be reduced with "
"the given reduce op %s." % (value, reduce_op))
# TODO(anjalisridhar): Moves these methods to a device utility file?
devices = cross_tower_ops_lib.get_devices_from(destinations)
devices = cross_device_ops_lib.get_devices_from(destinations)
if len(devices) == 1:
with ops.device(devices[0]):
return array_ops.identity(value)
@ -369,8 +369,7 @@ class CoreMirroredExtended(distribute_lib.DistributionStrategyExtended):
cross_device_ops=None,
auto_shard_dataset=False):
super(CoreMirroredExtended, self).__init__(container_strategy)
# TODO(josh11b): Rename self._cross_tower_ops -> self._cross_device_ops
self._cross_tower_ops = cross_device_ops
self._cross_device_ops = cross_device_ops
self._auto_shard_dataset = auto_shard_dataset
# Remember num GPUs which might be needed by `configure` method.
if num_gpus is not None and num_gpus_per_worker is not None:
@ -588,7 +587,7 @@ class CoreMirroredExtended(distribute_lib.DistributionStrategyExtended):
if isinstance(tensor, (float, int)): # Fast path for Python constants.
return tensor
# TODO(josh11b): In eager mode, use one thread per device, or async mode.
return self._get_cross_tower_ops().broadcast(
return self._get_cross_device_ops().broadcast(
tensor, destinations or self._devices)
def _call_for_each_replica(self, fn, args, kwargs):
@ -607,26 +606,27 @@ class CoreMirroredExtended(distribute_lib.DistributionStrategyExtended):
if cluster_spec:
self._initialize_multi_worker(self._num_gpus, cluster_spec)
if self._cross_tower_ops is None:
if self._cross_device_ops is None:
if self._cluster_spec:
# It currently cannot detect the toplogy of remote workers. So we
# hard-code the multi-worker all-reduce algorithm for now.
if len(self._workers) == 1:
# The default is "nccl".
self._cross_tower_ops = cross_tower_ops_lib.AllReduceCrossDeviceOps()
self._cross_device_ops = (
cross_device_ops_lib.AllReduceCrossDeviceOps())
else:
# The default is hierarchical reduce and broadcast.
self._cross_tower_ops = cross_tower_ops_lib.MultiWorkerAllReduce(
self._cross_device_ops = cross_device_ops_lib.MultiWorkerAllReduce(
self._workers, self._num_gpus)
else:
self._cross_tower_ops = cross_tower_ops_lib.choose_the_best(
self._cross_device_ops = cross_device_ops_lib.choose_the_best(
self._devices, session_config=session_config)
def _get_cross_tower_ops(self):
if self._cross_tower_ops is None:
self._cross_tower_ops = (
cross_tower_ops_lib.ReductionToOneDeviceCrossDeviceOps())
return self._cross_tower_ops
def _get_cross_device_ops(self):
if self._cross_device_ops is None:
self._cross_device_ops = (
cross_device_ops_lib.ReductionToOneDeviceCrossDeviceOps())
return self._cross_device_ops
def _reduce_to(self, reduce_op, value, destinations):
assert not isinstance(value, values.Mirrored)
@ -637,12 +637,12 @@ class CoreMirroredExtended(distribute_lib.DistributionStrategyExtended):
# be 0.
return _reduce_non_distributed_value(self, reduce_op, value,
destinations)
return self._get_cross_tower_ops().reduce(
return self._get_cross_device_ops().reduce(
reduce_op, value, destinations=destinations)
def _batch_reduce_to(self, reduce_op, value_destination_pairs):
return self._get_cross_tower_ops().batch_reduce(reduce_op,
value_destination_pairs)
return self._get_cross_device_ops().batch_reduce(reduce_op,
value_destination_pairs)
def _update(self, var, fn, args, kwargs, group):
# TODO(josh11b): In eager mode, use one thread per device.
@ -723,7 +723,7 @@ class CoreMirroredExtended(distribute_lib.DistributionStrategyExtended):
if colocate_with is None:
return self._devices
else:
return cross_tower_ops_lib.get_devices_from(colocate_with)
return cross_device_ops_lib.get_devices_from(colocate_with)
class _MirroredReplicaThread(threading.Thread):
"""A thread that runs() a function on a device."""

View File

@ -25,10 +25,10 @@ import numpy as np
from tensorflow.contrib.distribute.python import mirrored_strategy
from tensorflow.contrib.distribute.python import multi_worker_test_base
from tensorflow.contrib.distribute.python import strategy_test_lib
from tensorflow.contrib.distribute.python import values
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import values
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import function

View File

@ -20,7 +20,7 @@ from __future__ import print_function
import six
from tensorflow.contrib.distribute.python import values
from tensorflow.python.distribute import values
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops

View File

@ -18,10 +18,10 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_ops_lib
from tensorflow.contrib.distribute.python import mirrored_strategy
from tensorflow.contrib.distribute.python import values
from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
from tensorflow.python.distribute import multi_worker_util
from tensorflow.python.distribute import values
from tensorflow.python.eager import context
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import ops
@ -107,8 +107,8 @@ class ParameterServerExtended(distribute_lib.DistributionStrategyExtended):
self._initialize_local(num_gpus_per_worker)
# We typically don't need to do all-reduce in this strategy.
self._cross_tower_ops = (
cross_tower_ops_lib.ReductionToOneDeviceCrossDeviceOps(
self._cross_device_ops = (
cross_device_ops_lib.ReductionToOneDeviceCrossDeviceOps(
reduce_to_device=_LOCAL_CPU))
def _initialize_multi_worker(self, num_gpus_per_worker, cluster_spec,
@ -256,9 +256,9 @@ class ParameterServerExtended(distribute_lib.DistributionStrategyExtended):
True)
def _broadcast_to(self, tensor, destinations):
if not cross_tower_ops_lib.check_destinations(destinations):
if not cross_device_ops_lib.check_destinations(destinations):
destinations = self._compute_devices
return self._cross_tower_ops.broadcast(tensor, destinations)
return self._cross_device_ops.broadcast(tensor, destinations)
def _allow_variable_partition(self):
return not context.executing_eagerly()
@ -330,7 +330,7 @@ class ParameterServerExtended(distribute_lib.DistributionStrategyExtended):
return
if destinations is None:
return
for d in cross_tower_ops_lib.get_devices_from(destinations):
for d in cross_device_ops_lib.get_devices_from(destinations):
d_spec = tf_device.DeviceSpec.from_string(d)
if d_spec.job == self._task_type and d_spec.task != self._task_id:
raise ValueError(
@ -343,14 +343,14 @@ class ParameterServerExtended(distribute_lib.DistributionStrategyExtended):
# pylint: disable=protected-access
return mirrored_strategy._reduce_non_distributed_value(
self, reduce_op, value, destinations)
return self._cross_tower_ops.reduce(
return self._cross_device_ops.reduce(
reduce_op, value, destinations=destinations)
def _batch_reduce_to(self, reduce_op, value_destination_pairs):
for _, destinations in value_destination_pairs:
self._verify_destinations_not_different_worker(destinations)
return self._cross_tower_ops.batch_reduce(reduce_op,
value_destination_pairs)
return self._cross_device_ops.batch_reduce(reduce_op,
value_destination_pairs)
def _select_single_value(self, structured):
"""Select any single values in `structured`."""

View File

@ -26,10 +26,10 @@ from tensorflow.contrib.distribute.python import combinations
from tensorflow.contrib.distribute.python import multi_worker_test_base
from tensorflow.contrib.distribute.python import parameter_server_strategy
from tensorflow.contrib.distribute.python import strategy_test_lib
from tensorflow.contrib.distribute.python import values
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.distribute import multi_worker_util
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import values
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.estimator import run_config

View File

@ -23,15 +23,15 @@ from __future__ import print_function
import functools
from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_ops_lib
from tensorflow.contrib.distribute.python import values
from tensorflow.contrib.tpu.python.ops import tpu_ops
from tensorflow.contrib.tpu.python.tpu import tpu
from tensorflow.contrib.tpu.python.tpu import tpu_system_metadata as tpu_system_metadata_lib
from tensorflow.contrib.tpu.python.tpu import training_loop
from tensorflow.python.data.experimental.ops import batching
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import values
from tensorflow.python.eager import context
from tensorflow.python.eager import tape
from tensorflow.python.framework import constant_op
@ -467,7 +467,7 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended):
# Validate that the destination is same as the host device
# Note we don't do this when in replicate context as the reduction is
# performed on the TPU device itself.
devices = cross_tower_ops_lib.get_devices_from(destinations)
devices = cross_device_ops_lib.get_devices_from(destinations)
if len(devices) == 1:
assert device_util.canonicalize(devices[0]) == device_util.canonicalize(
self._host_device)

View File

@ -22,9 +22,9 @@ import os
from tensorflow.contrib.distribute.python import mirrored_strategy
from tensorflow.contrib.distribute.python import multi_worker_test_base
from tensorflow.contrib.distribute.python import values
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import values
from tensorflow.python.eager import context
from tensorflow.python.eager import test
from tensorflow.python.estimator import model_fn as model_fn_lib

View File

@ -8,6 +8,7 @@ exports_files(["LICENSE"])
load("//tensorflow:tensorflow.bzl", "py_test")
load("//tensorflow:tensorflow.bzl", "tf_py_test")
load("//tensorflow:tensorflow.bzl", "cuda_py_test")
py_library(
name = "all_reduce",
@ -44,13 +45,41 @@ tf_py_test(
)
py_library(
name = "distribute",
name = "cross_device_ops",
srcs = ["cross_device_ops.py"],
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
deps = [
":distribute_config",
":distribute_coordinator",
":distribute_coordinator_context",
":cross_device_utils",
":reduce_util",
":values",
"//tensorflow/python:array_ops",
"//tensorflow/python:device_lib",
"//tensorflow/python:framework_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform",
"//tensorflow/python:resource_variable_ops",
"//tensorflow/python:training",
"//tensorflow/python:variable_scope",
"//tensorflow/python/eager:context",
"@six_archive//:six",
],
)
py_library(
name = "cross_device_utils",
srcs = ["cross_device_utils.py"],
srcs_version = "PY2AND3",
deps = [
":all_reduce",
":values",
"//tensorflow/python:array_ops",
"//tensorflow/python:collective_ops",
"//tensorflow/python:device",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:gradients",
"//tensorflow/python:math_ops",
"//tensorflow/python:nccl_ops",
],
)
@ -123,6 +152,34 @@ py_library(
],
)
py_library(
name = "input_ops",
srcs = ["input_ops.py"],
deps = [
"//tensorflow/python:framework_ops",
"//tensorflow/python/data/util:nest",
],
)
cuda_py_test(
name = "input_ops_test",
srcs = ["input_ops_test.py"],
additional_deps = [
":input_ops",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python:errors",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_ops",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:io_ops",
"//tensorflow/python/data/ops:readers",
"//tensorflow/python:util",
],
tags = [
"no_pip",
],
)
py_test(
name = "multi_worker_util_test",
srcs = ["multi_worker_util_test.py"],
@ -158,8 +215,43 @@ py_library(
py_library(
name = "reduce_util",
srcs = [
"reduce_util.py",
],
srcs = ["reduce_util.py"],
deps = [],
)
py_library(
name = "shared_variable_creator",
srcs = ["shared_variable_creator.py"],
)
py_test(
name = "shared_variable_creator_test",
srcs = ["shared_variable_creator_test.py"],
srcs_version = "PY2AND3",
deps = [
":shared_variable_creator",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:variable_scope",
"//tensorflow/python/eager:test",
],
)
py_library(
name = "values",
srcs = ["values.py"],
deps = [
":input_ops",
"//tensorflow/python:array_ops",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:device_util",
"//tensorflow/python:distribute",
"//tensorflow/python:framework_ops",
"//tensorflow/python:resource_variable_ops",
"//tensorflow/python:training",
"//tensorflow/python:util",
"//tensorflow/python/data/ops:multi_device_iterator_ops",
"//tensorflow/python/eager:context",
"//tensorflow/python/training/checkpointable:base",
"@six_archive//:six",
],
)

View File

@ -21,10 +21,10 @@ from __future__ import print_function
import collections
import six
from tensorflow.contrib.distribute.python import cross_tower_utils
from tensorflow.contrib.distribute.python import values as value_lib
from tensorflow.python.client import device_lib
from tensorflow.python.distribute import cross_device_utils
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import values as value_lib
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
@ -144,7 +144,7 @@ def _simple_broadcast(value, destinations):
index = {}
devices = get_devices_from(destinations)
for d in devices:
index[d] = cross_tower_utils.copy_tensor_or_indexed_slices_to_device(
index[d] = cross_device_utils.copy_tensor_or_indexed_slices_to_device(
value, d)
return value_lib.Mirrored(index)
@ -162,10 +162,10 @@ def _simple_reduce(per_replica_value, reduce_to_device, accumulation_fn,
with ops.device(reduce_to_device):
with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT):
reduced = cross_tower_utils.aggregate_tensors_or_indexed_slices(
reduced = cross_device_utils.aggregate_tensors_or_indexed_slices(
all_values, accumulation_fn)
if reduce_op == reduce_util.ReduceOp.MEAN:
reduced = cross_tower_utils.divide_by_n_tensors_or_indexed_slices(
reduced = cross_device_utils.divide_by_n_tensors_or_indexed_slices(
reduced, count)
elif reduce_op != reduce_util.ReduceOp.SUM:
raise ValueError("`reduce_op` must be Reduce.SUM or Reduce.MEAN.")
@ -332,7 +332,7 @@ def _ungroup_and_make_mirrored(grouped_reduced,
Args:
grouped_reduced: a list of lists, each sublist has components for each
device, paired with a None. It is the result from
cross_tower_utils.aggregate_gradients_using*.
cross_device_utils.aggregate_gradients_using*.
destinations: a list of device strings for returned Mirrored objects.
reduce_op: Indicates how values will be aggregated. Accepted values
are `tf.distribute.ReduceOp.SUM`, `tf.distribute.ReduceOp.MEAN`.
@ -485,7 +485,7 @@ class AggregateSmallTensorPacker(object):
"""Aggregate small tensors."""
if (self.agg_small_grads_max_bytes > 0 and
self.agg_small_grads_max_group > 0):
device_grads, self.packing = cross_tower_utils.pack_small_tensors(
device_grads, self.packing = cross_device_utils.pack_small_tensors(
grouped_grads_and_vars,
max_bytes=self.agg_small_grads_max_bytes,
max_group=self.agg_small_grads_max_group)
@ -493,8 +493,8 @@ class AggregateSmallTensorPacker(object):
def unpack(self, summed_device_grad_packs):
"""Reverse the aggregation process."""
return cross_tower_utils.unpack_small_tensors(summed_device_grad_packs,
self.packing)
return cross_device_utils.unpack_small_tensors(summed_device_grad_packs,
self.packing)
def _pack_tensors(device_grads,
@ -557,7 +557,7 @@ class AllReduceCrossDeviceOps(CrossDeviceOps):
super(AllReduceCrossDeviceOps, self).__init__()
def _reduce(self, reduce_op, per_replica_value, destinations):
contains_indexed_slices = cross_tower_utils.contains_indexed_slices(
contains_indexed_slices = cross_device_utils.contains_indexed_slices(
per_replica_value)
if (_devices_match(per_replica_value, destinations)
and not context.executing_eagerly()
@ -580,7 +580,7 @@ class AllReduceCrossDeviceOps(CrossDeviceOps):
def _batch_reduce(self, reduce_op, value_destination_pairs):
all_devices_match = _all_devices_match(value_destination_pairs)
contains_indexed_slices = cross_tower_utils.contains_indexed_slices(
contains_indexed_slices = cross_device_utils.contains_indexed_slices(
value_destination_pairs)
if (all_devices_match and not context.executing_eagerly()
and not contains_indexed_slices):
@ -618,13 +618,13 @@ class AllReduceCrossDeviceOps(CrossDeviceOps):
# the balance on num_splits.
if self._all_reduce_alg == "nccl":
# TODO(yuefengz): merge this into the all-reduce library.
reduced = cross_tower_utils.aggregate_gradients_using_nccl(
reduced = cross_device_utils.aggregate_gradients_using_nccl(
device_grad_packs)
else:
# TODO(yuefengz): check that gpu ids in `destinations` are in ascending
# order.
reduced = (
cross_tower_utils.aggregate_gradients_using_hierarchical_copy(
cross_device_utils.aggregate_gradients_using_hierarchical_copy(
destinations, device_grad_packs))
reduced = _unpack_tensors(reduced, tensor_packer)
@ -740,13 +740,13 @@ class MultiWorkerAllReduce(AllReduceCrossDeviceOps):
this_grads = remaining_grads
remaining_grads = []
else:
(this_grads, remaining_grads) = cross_tower_utils.split_grads_by_size(
(this_grads, remaining_grads) = cross_device_utils.split_grads_by_size(
spec_tuple.limit, remaining_grads)
if this_grads:
device_grad_packs, tensor_packer = _pack_tensors(
this_grads, self._num_packs, self._agg_small_grads_max_bytes,
self._agg_small_grads_max_group)
range_agg_grads = cross_tower_utils.sum_gradients_all_reduce(
range_agg_grads = cross_device_utils.sum_gradients_all_reduce(
self._worker_devices, device_grad_packs, len(self._worker_devices),
spec_tuple.alg, spec_tuple.shards, range(self._num_gpus_per_worker))
range_agg_grads = _unpack_tensors(range_agg_grads, tensor_packer)
@ -789,13 +789,13 @@ class CollectiveAllReduce(CrossDeviceOps):
self._num_workers = num_workers
self._num_gpus_per_worker = num_gpus_per_worker
self._all_reduce_merge_scope = all_reduce_merge_scope
self._collective_keys = collective_keys or cross_tower_utils.CollectiveKeys(
)
self._collective_keys = (collective_keys or
cross_device_utils.CollectiveKeys())
super(CollectiveAllReduce, self).__init__()
# TODO(yuefengz, tucker): is indexed slices supported by collective ops?
def _reduce(self, reduce_op, per_replica_value, destinations):
if cross_tower_utils.contains_indexed_slices(per_replica_value):
if cross_device_utils.contains_indexed_slices(per_replica_value):
raise ValueError(
"`IndexSlices` is not supported for Collective All-Reduce.")
if context.executing_eagerly():
@ -819,7 +819,7 @@ class CollectiveAllReduce(CrossDeviceOps):
return value_lib.Mirrored(index)
def _batch_reduce(self, reduce_op, value_destination_pairs):
if cross_tower_utils.contains_indexed_slices(value_destination_pairs):
if cross_device_utils.contains_indexed_slices(value_destination_pairs):
raise ValueError(
"`IndexSlices` is not supported for Collective All-Reduce.")
if context.executing_eagerly():
@ -870,7 +870,7 @@ class CollectiveAllReduce(CrossDeviceOps):
with ops.name_scope("allreduce"):
for grad_and_vars in chunk:
scaled_grads = [g for g, _ in grad_and_vars]
collective_reduced = cross_tower_utils.build_collective_reduce(
collective_reduced = cross_device_utils.build_collective_reduce(
scaled_grads, self._num_workers, self._collective_keys, "Add",
"Id")
result = []

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for cross_tower_ops."""
"""Utilities for cross_device_ops."""
from __future__ import absolute_import
from __future__ import division
@ -21,8 +21,8 @@ from __future__ import print_function
import collections as pycoll
import threading
from tensorflow.contrib.distribute.python import values as value_lib
from tensorflow.python.distribute import all_reduce
from tensorflow.python.distribute import values as value_lib
from tensorflow.python.framework import device as pydev
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops

View File

@ -20,9 +20,9 @@ from __future__ import print_function
import os
from tensorflow.contrib.distribute.python import input_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import readers
from tensorflow.python.distribute import input_ops
from tensorflow.python.framework import errors
from tensorflow.python.lib.io import python_io
from tensorflow.python.platform import test

View File

@ -18,7 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.distribute.python import shared_variable_creator
from tensorflow.python.distribute import shared_variable_creator
from tensorflow.python.eager import test
from tensorflow.python.framework import test_util
from tensorflow.python.ops import variable_scope

View File

@ -27,9 +27,9 @@ import operator
import weakref
import six
from tensorflow.contrib.distribute.python import input_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import multi_device_iterator_ops
from tensorflow.python.distribute import input_ops
from tensorflow.python.distribute import reduce_util
from tensorflow.python.eager import context
from tensorflow.python.eager import tape