Delete Map/Reduce support from OneDevice and Mirrored DistributionStrategies.

PiperOrigin-RevId: 221265603
This commit is contained in:
A. Unique TensorFlower 2018-11-13 07:47:28 -08:00 committed by TensorFlower Gardener
parent a326fdb402
commit 994ee1d866
10 changed files with 3 additions and 87 deletions

View File

@ -155,17 +155,7 @@ def _simple_reduce(per_replica_value, reduce_to_device, accumulation_fn,
all_values = []
count = 0
for v in per_replica_value._index.values(): # pylint: disable=protected-access
if isinstance(v, value_lib.MapOutput):
v_list = v.get()
if not v_list:
continue
count += len(v_list)
# Sum within each device before aggregating across devices.
# TODO(yuefengz): Check whether it helps to use accumulation_fn here.
v = cross_tower_utils.aggregate_tensors_or_indexed_slices(
v_list, math_ops.add_n)
else:
count += 1
count += 1
all_values.append(v)
if not all_values:
raise ValueError("`per_replica_value` must be non-empty")

View File

@ -667,7 +667,5 @@ def contains_indexed_slices(value):
return any(contains_indexed_slices(v) for v in value)
elif isinstance(value, value_lib.DistributedValues):
return contains_indexed_slices(list(value._index.values())) # pylint: disable=protected-access
elif isinstance(value, value_lib.MapOutput):
return contains_indexed_slices(value.get())
else:
return False

View File

@ -106,17 +106,6 @@ class IndexedSlicesUtilsTest(test.TestCase, parameterized.TestCase):
per_replica = value_lib.PerReplica({"/gpu:0": t0, "/cpu:0": t1})
self.assertTrue(cross_tower_utils.contains_indexed_slices(per_replica))
@test_util.run_in_graph_and_eager_modes
def testContainsIndexedSlices_PerReplicaMapOutput(self):
t0 = math_ops._as_indexed_slices(
constant_op.constant([[1., 2.], [0, 0], [3., 4.]]))
t1 = math_ops._as_indexed_slices(
constant_op.constant([[0., 0.], [5, 6], [7., 8.]]))
per_replica = value_lib.PerReplica({
"/gpu:0": value_lib.MapOutput([t0]),
"/cpu:0": value_lib.MapOutput([t1])})
self.assertTrue(cross_tower_utils.contains_indexed_slices(per_replica))
@combinations.generate(combinations.combine(
mode=["graph", "eager"],
required_gpus=1))

View File

@ -554,21 +554,6 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
def _call_for_each_replica(self, fn, args, kwargs):
return _call_for_each_replica(self, fn, args, kwargs)
def map(self, map_over, fn, *args, **kwargs):
# TODO(josh11b): In eager mode, use one thread per device.
index = {}
for i, m in enumerate(map_over):
d = self._devices[i % len(self._devices)]
with ops.device(d):
l = index.get(d, [])
l.append(fn(m,
*values.select_device_mirrored(d, args),
**values.select_device_mirrored(d, kwargs)))
index[d] = l
# TODO(josh11b): Need a values.regroup equivalent that handles MapOutput
# in addition to PerReplica data.
return values.PerReplica({k: values.MapOutput(v) for k, v in index.items()})
def configure(self,
session_config=None,
cluster_spec=None,

View File

@ -78,11 +78,6 @@ class MirroredTwoDeviceDistributionTest(strategy_test_lib.DistributionTestBase):
self._test_minimize_loss_graph(
self._get_distribution_strategy(), soft_placement=soft_placement)
def testMapReduce(self):
if not GPU_TEST:
self.skipTest("Not GPU test")
self._test_map_reduce(self._get_distribution_strategy())
def testDeviceIndex(self):
if not GPU_TEST:
self.skipTest("Not GPU test")

View File

@ -40,9 +40,6 @@ class MirroredOneCPUDistributionTest(strategy_test_lib.DistributionTestBase):
def testMinimizeLossGraph(self):
self._test_minimize_loss_graph(self._get_distribution_strategy())
def testMapReduce(self):
self._test_map_reduce(self._get_distribution_strategy())
def testDeviceIndex(self):
self._test_device_index(self._get_distribution_strategy())

View File

@ -25,8 +25,6 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.training import distribute as distribute_lib
from tensorflow.python.util import nest
@ -119,23 +117,9 @@ class OneDeviceStrategy(distribute_lib.DistributionStrategy):
with ops.device(self._device), _OneDeviceReplicaContext(self):
return fn(*args, **kwargs)
def map(self, map_over, fn, *args, **kwargs):
with ops.device(self._device):
return values.MapOutput([fn(m, *args, **kwargs) for m in map_over])
def _reduce(self, aggregation, value, destinations):
del destinations
if not isinstance(value, values.MapOutput):
return value
l = value.get()
assert l
with ops.device(self._device):
if aggregation == vs.VariableAggregation.SUM:
return math_ops.add_n(l)
elif aggregation == vs.VariableAggregation.MEAN:
return math_ops.add_n(l) / len(l)
else:
assert False
del aggregation, destinations
return value
def _update(self, var, options, fn, *args, **kwargs):
# The implementations of _update() and _update_non_slot() are identical

View File

@ -35,9 +35,6 @@ class OneDeviceStrategyTest(strategy_test_lib.DistributionTestBase):
def testMinimizeLossGraph(self):
self._test_minimize_loss_graph(self._get_distribution_strategy())
def testMapReduce(self):
self._test_map_reduce(self._get_distribution_strategy())
def testDeviceIndex(self):
self._test_device_index(self._get_distribution_strategy())

View File

@ -189,15 +189,6 @@ class DistributionTestBase(test.TestCase):
# Error should go down
self.assertLess(error_after, error_before)
def _test_map_reduce(self, d, in_graph=None):
with d.scope():
map_in = [constant_op.constant(i) for i in range(10)]
map_out = d.map(map_in, lambda x, y: x * y, 2)
observed = d.reduce(variable_scope.VariableAggregation.SUM, map_out,
"/device:CPU:0")
expected = 90 # 2 * (0 + 1 + ... + 9)
self.assertEqual(expected, observed.numpy())
def _test_device_index(self, d):
with d.scope():
expected_devices = [False] * len(d.worker_devices)

View File

@ -1287,16 +1287,6 @@ class MultiWorkerDataset(object):
return MultiWorkerDataIterator(iterators, self._worker_device_pairs)
class MapOutput(object):
"""Map can result in multiple outputs per device."""
def __init__(self, l):
self._l = l
def get(self):
return self._l
class MultiStepContext(object):
"""A context object that can be used to capture things when running steps.