STT-tensorflow/tensorflow/python/distribute/all_reduce_test.py
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for all_reduce."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import time
import numpy as np
from tensorflow.core.framework import types_pb2
from tensorflow.python.distribute import all_reduce as ar
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging


class AllReduceTest(test_util.TensorFlowTestCase):
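  """Tests for ring, shuffle and recursive halving-doubling all-reduce."""
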
@test_util.run_deprecated_v1
def testFlattenTensorsShapesDefined(self):
x = array_ops.placeholder(types_pb2.DT_FLOAT, [None])
with self.assertRaisesRegex(ValueError, "must have statically known shape"):
ar._flatten_tensors([x, x])
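
  # _ring_permutations(num_workers, num_subchunks, gpu_perm) maps
  # [subchunk][device] to each device's ring predecessor (pred_by_c_d) and to
  # its rank within that subchunk's ring (rank_by_c_d).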
def testRingPermutations(self):
# 0 devices
pred_by_c_d, rank_by_c_d = ar._ring_permutations(1, 0, [])
self.assertEqual(pred_by_c_d, [])
self.assertEqual(rank_by_c_d, [])
# 1 worker, 1 subchunk cases
pred_by_c_d, rank_by_c_d = ar._ring_permutations(1, 1, [0])
self.assertEqual(pred_by_c_d, [[0]])
self.assertEqual(rank_by_c_d, [[0]])
pred_by_c_d, rank_by_c_d = ar._ring_permutations(1, 1, [0, 1, 2])
self.assertEqual(pred_by_c_d, [[2, 0, 1]])
self.assertEqual(rank_by_c_d, [[0, 1, 2]])
# multiple workers, 1 subchunk cases
pred_by_c_d, rank_by_c_d = ar._ring_permutations(2, 1, [0, 1, 2])
self.assertEqual(pred_by_c_d, [[5, 0, 1, 2, 3, 4]])
self.assertEqual(rank_by_c_d, [[0, 1, 2, 3, 4, 5]])
pred_by_c_d, rank_by_c_d = ar._ring_permutations(3, 1, [0, 1, 2])
self.assertEqual(pred_by_c_d, [[8, 0, 1, 2, 3, 4, 5, 6, 7]])
self.assertEqual(rank_by_c_d, [[0, 1, 2, 3, 4, 5, 6, 7, 8]])
pred_by_c_d, rank_by_c_d = ar._ring_permutations(2, 1, [2, 1, 0])
self.assertEqual(pred_by_c_d, [[1, 2, 3, 4, 5, 0]])
self.assertEqual(rank_by_c_d, [[2, 1, 0, 5, 4, 3]])
# 1 worker, multiple subchunk cases
pred_by_c_d, rank_by_c_d = ar._ring_permutations(1, 2, [0, 1, 2, 3])
self.assertEqual(pred_by_c_d, [[3, 0, 1, 2], [3, 0, 1, 2]])
self.assertEqual(rank_by_c_d, [[0, 1, 2, 3], [2, 3, 0, 1]])
pred_by_c_d, rank_by_c_d = ar._ring_permutations(1, 4, [0, 1, 2, 3])
self.assertEqual(pred_by_c_d, [[3, 0, 1, 2], [3, 0, 1, 2],
[3, 0, 1, 2], [3, 0, 1, 2]])
self.assertEqual(rank_by_c_d, [[0, 1, 2, 3], [3, 0, 1, 2],
[2, 3, 0, 1], [1, 2, 3, 0]])
    # multiple workers, multiple subchunk cases
pred_by_c_d, rank_by_c_d = ar._ring_permutations(2, 2, [0, 1, 2, 3])
self.assertEqual(pred_by_c_d, [[7, 0, 1, 2, 3, 4, 5, 6],
[3, 0, 5, 2, 7, 4, 1, 6]])
self.assertEqual(rank_by_c_d, [[0, 1, 2, 3, 4, 5, 6, 7],
[2, 3, 0, 1, 6, 7, 4, 5]])
pred_by_c_d, rank_by_c_d = ar._ring_permutations(2, 2, [0, 3, 2, 1])
self.assertEqual(pred_by_c_d, [[5, 2, 3, 0, 1, 6, 7, 4],
[1, 2, 7, 0, 5, 6, 3, 4]])
self.assertEqual(rank_by_c_d, [[0, 3, 2, 1, 4, 7, 6, 5],
[2, 1, 0, 3, 6, 5, 4, 7]])
def _buildInput(self, num_workers, num_gpus):
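    """Replicates a 16-element float tensor onto each GPU device."""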
t8 = constant_op.constant(
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15],
types_pb2.DT_FLOAT)
input_tensors = []
device_names = []
for w in range(0, num_workers):
for d in range(0, num_gpus):
dn = "/replica:0/task:%d/device:GPU:%d" % (w, d % num_gpus)
device_names.append(dn)
with ops.device(dn):
input_tensors.append(array_ops.identity(t8))
return input_tensors, device_names
@test_util.run_deprecated_v1
def testBuildRingGatherPassStructure(self):
# 1 worker, 1 device
input_tensors, device_names = self._buildInput(1, 1)
pred_by_c_d, rank_by_c_d = ar._ring_permutations(1, 1, [0])
output_tensors = ar._build_ring_gather(input_tensors, device_names, 1,
pred_by_c_d, rank_by_c_d,
math_ops.add)
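    # With a single device the gather pass is a no-op: inputs pass through.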
self.assertEqual(output_tensors, input_tensors)
# 1 worker, 4 devices, 2 subchunks
input_tensors, device_names = self._buildInput(1, 4)
pred_by_c_d, rank_by_c_d = ar._ring_permutations(1, 2, [0, 1, 2, 3])
output_tensors, pad_len = ar._build_ring_gather(
input_tensors, device_names, 2, pred_by_c_d, rank_by_c_d, math_ops.add)
self.assertEqual(0, pad_len)
    # Same number of outputs as inputs: one list of chunks per device.
self.assertEqual(len(output_tensors), len(input_tensors))
num_chunks = 2 * len(input_tensors)
tlen = tensor_shape.dimension_value(input_tensors[0].shape[0])
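    # Each output is a list of num_chunks (num_subchunks * num_devices) chunk
    # tensors, each covering 1/num_chunks of the flattened input length.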
for otl in output_tensors:
self.assertEqual(len(otl), num_chunks)
for ot in otl:
self.assertEqual(ot.shape, [tlen//num_chunks])
def _buildInitialVars(self, shape, dev_list):
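    """Creates a variable per device; element i on device d is i + 0.01 * d."""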
values = []
num_devices = len(dev_list)
dim = np.prod(shape, dtype=int) if shape else 1
for d in range(0, num_devices):
with ops.device(dev_list[d]):
npt = np.zeros(shape).astype(np.float32)
alias = np.frombuffer(npt.data, dtype=np.float32)
for i in range(0, dim):
alias[i] = i + 0.01 * d
var = state_ops.variable_op(shape, types_pb2.DT_FLOAT)
state_ops.init_variable(var, npt).op.run()
values.append(var)
return values
# pylint: disable=g-long-lambda
def _buildRing(self, num_workers, num_gpus, subdiv):
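    """Returns a build function for build_ring_all_reduce."""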
gpu_perm = range(0, num_gpus)
return lambda x, un_op: ar.build_ring_all_reduce(
x, num_workers, subdiv, gpu_perm, math_ops.add, un_op)
def _testAllReduce(self, num_workers, num_gpus, shape, build_f):
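    """Checks that summing the all-reduce outputs matches add_n over inputs.

    un_op divides each reduced output by num_devices, so each output is the
    mean of the inputs and add_n over all outputs recovers the plain sum.
    """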
# Use local CPU as device for all inputs.
num_devices = num_workers * num_gpus
dev_list = ["/replica:0/task:0/device:CPU:0"
for _ in range(num_devices)]
with self.cached_session():
input_tensors = self._buildInitialVars(shape, dev_list)
un_op = lambda x: math_ops.div(
x, constant_op.constant(num_devices, dtype=types_pb2.DT_FLOAT))
simple_sum = math_ops.add_n(input_tensors)
simple_sum.op.run()
output_tensors = build_f(input_tensors, un_op)
sum_reduced = math_ops.add_n(output_tensors)
sum_reduced.op.run()
self.assertAllClose(sum_reduced, self.evaluate(simple_sum))
def _testRingAllReduce(self, num_workers, num_gpus, shape, subdiv):
start_time = time.time()
build_f = self._buildRing(num_workers, num_gpus, subdiv)
self._testAllReduce(num_workers, num_gpus, shape, build_f)
elapsed = time.time() - start_time
tf_logging.info("RingAllReduce num_workers=%d num_gpus=%d shape=%s "
"subdiv=%d elapsed=%f" %
(num_workers, num_gpus, shape, subdiv, elapsed))
@test_util.run_deprecated_v1
def testRingAllReduce(self):
self._testRingAllReduce(1, 2, [], 1)
self._testRingAllReduce(1, 2, [8], 1)
self._testRingAllReduce(1, 2, [4, 4], 1)
self._testRingAllReduce(6, 1, [8], 1)
self._testRingAllReduce(1, 8, [32], 1)
self._testRingAllReduce(1, 8, [120], 1)
self._testRingAllReduce(2, 8, [7, 13], 1)
self._testRingAllReduce(2, 8, [8, 8], 2)
self._testRingAllReduce(2, 8, [8, 8], 4)
# TODO(tucker): The following test is surprisingly slow.
# Diagnose and fix before re-enabling.
# self._testRingAllReduce(4, 8, [8, 8, 2], 4)
def _buildShuffle(self, num_workers, num_gpus, num_shards):
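    """Returns a build function for build_shuffle_all_reduce."""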
# Use local CPU for all shuffle shards
gather_devices = ["/replica:0/task:0/device:CPU:0"
for _ in range(num_shards)]
return lambda x, un_op: ar.build_shuffle_all_reduce(
x, gather_devices, math_ops.add_n, un_op)
def _testShuffleAllReduce(self, num_workers, num_gpus, shape, num_shards):
start_time = time.time()
build_f = self._buildShuffle(num_workers, num_gpus, num_shards)
self._testAllReduce(num_workers, num_gpus, shape, build_f)
elapsed = time.time() - start_time
tf_logging.info("ShuffleAllReduce num_workers=%d num_gpus=%d shape=%s "
"elapsed=%f" % (num_workers, num_gpus, shape, elapsed))
@test_util.run_deprecated_v1
def testShuffleAllReduce(self):
self._testShuffleAllReduce(1, 2, [], 1)
self._testShuffleAllReduce(1, 2, [8], 1)
self._testShuffleAllReduce(1, 2, [4, 4], 1)
self._testShuffleAllReduce(1, 8, [32], 1)
self._testShuffleAllReduce(1, 8, [120], 1)
self._testShuffleAllReduce(2, 8, [7, 13], 3)
self._testShuffleAllReduce(2, 8, [8, 8], 2)
self._testShuffleAllReduce(2, 8, [8, 8], 4)
self._testShuffleAllReduce(4, 8, [8, 8, 2], 4)
def _buildRecursiveHD(self, num_workers, num_gpus):
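    """Returns a build function for build_recursive_hd_all_reduce."""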
return lambda x, un_op: ar.build_recursive_hd_all_reduce(
x, math_ops.add, un_op)
# pylint: enable=g-long-lambda
def _testRecursiveHDAllReduce(self, num_workers, num_gpus, shape):
start_time = time.time()
build_f = self._buildRecursiveHD(num_workers, num_gpus)
self._testAllReduce(num_workers, num_gpus, shape, build_f)
elapsed = time.time() - start_time
tf_logging.info("RecursiveHDAllReduce num_workers=%d num_gpus=%d "
"shape=%s elapsed=%f" %
(num_workers, num_gpus, shape, elapsed))
@test_util.run_deprecated_v1
def testRecursiveHDAllReduce(self):
self._testRecursiveHDAllReduce(1, 2, [8])
self._testRecursiveHDAllReduce(1, 2, [4, 4])
self._testRecursiveHDAllReduce(1, 8, [32])
self._testRecursiveHDAllReduce(1, 8, [120])
self._testRecursiveHDAllReduce(2, 8, [8, 8])
self._testRecursiveHDAllReduce(4, 8, [8, 8, 2])


if __name__ == "__main__":
test.main()