Delete Map/Reduce support from OneDevice and Mirrored DistributionStrategies.
PiperOrigin-RevId: 221265603
commit 994ee1d866 (parent a326fdb402)
@@ -155,17 +155,7 @@ def _simple_reduce(per_replica_value, reduce_to_device, accumulation_fn,
   all_values = []
   count = 0
   for v in per_replica_value._index.values():  # pylint: disable=protected-access
-    if isinstance(v, value_lib.MapOutput):
-      v_list = v.get()
-      if not v_list:
-        continue
-      count += len(v_list)
-      # Sum within each device before aggregating across devices.
-      # TODO(yuefengz): Check whether it helps to use accumulation_fn here.
-      v = cross_tower_utils.aggregate_tensors_or_indexed_slices(
-          v_list, math_ops.add_n)
-    else:
-      count += 1
+    count += 1
     all_values.append(v)
   if not all_values:
     raise ValueError("`per_replica_value` must be non-empty")
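Note: the deleted branch above only fired when a per-replica entry was a MapOutput (a per-device list of outputs); every other value went straight through with count += 1. Below is a standalone sketch of that flatten-and-count behaviour, using plain Python stand-ins (MapOutputStub, built-in sum) rather than the real value_lib and cross_tower_utils helpers.

# Standalone sketch of the deleted MapOutput handling in _simple_reduce.
# MapOutputStub and sum() are stand-ins, not the real value_lib.MapOutput
# or cross_tower_utils.aggregate_tensors_or_indexed_slices.
class MapOutputStub(object):
  def __init__(self, l):
    self._l = l

  def get(self):
    return self._l


def flatten_and_count(per_device_values):
  """Old behaviour: pre-sum each MapOutput per device, count every element."""
  all_values, count = [], 0
  for v in per_device_values:
    if isinstance(v, MapOutputStub):
      v_list = v.get()
      if not v_list:
        continue
      count += len(v_list)
      v = sum(v_list)  # stand-in for the per-device aggregation
    else:
      count += 1
    all_values.append(v)
  return all_values, count


print(flatten_and_count([MapOutputStub([1, 2, 3]), 4]))  # ([6, 4], 4)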
@@ -667,7 +667,5 @@ def contains_indexed_slices(value):
     return any(contains_indexed_slices(v) for v in value)
   elif isinstance(value, value_lib.DistributedValues):
     return contains_indexed_slices(list(value._index.values()))  # pylint: disable=protected-access
-  elif isinstance(value, value_lib.MapOutput):
-    return contains_indexed_slices(value.get())
   else:
     return False
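Note: this hunk drops the only MapOutput case from contains_indexed_slices; the helper still recurses through lists/tuples and DistributedValues. A minimal self-contained sketch of that recursion, with a stand-in class in place of ops.IndexedSlices and the value_lib types:

# Sketch of the recursion in contains_indexed_slices with stand-in types;
# the real helper checks ops.IndexedSlices and value_lib classes.
class IndexedSlicesStub(object):
  pass


def contains_indexed_slices(value):
  if isinstance(value, IndexedSlicesStub):
    return True
  elif isinstance(value, (list, tuple)) and value:
    return any(contains_indexed_slices(v) for v in value)
  # The deleted branch additionally unwrapped value_lib.MapOutput via .get().
  else:
    return False


print(contains_indexed_slices([0, (IndexedSlicesStub(),)]))  # True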
@@ -106,17 +106,6 @@ class IndexedSlicesUtilsTest(test.TestCase, parameterized.TestCase):
     per_replica = value_lib.PerReplica({"/gpu:0": t0, "/cpu:0": t1})
     self.assertTrue(cross_tower_utils.contains_indexed_slices(per_replica))
 
-  @test_util.run_in_graph_and_eager_modes
-  def testContainsIndexedSlices_PerReplicaMapOutput(self):
-    t0 = math_ops._as_indexed_slices(
-        constant_op.constant([[1., 2.], [0, 0], [3., 4.]]))
-    t1 = math_ops._as_indexed_slices(
-        constant_op.constant([[0., 0.], [5, 6], [7., 8.]]))
-    per_replica = value_lib.PerReplica({
-        "/gpu:0": value_lib.MapOutput([t0]),
-        "/cpu:0": value_lib.MapOutput([t1])})
-    self.assertTrue(cross_tower_utils.contains_indexed_slices(per_replica))
-
   @combinations.generate(combinations.combine(
       mode=["graph", "eager"],
       required_gpus=1))
@@ -554,21 +554,6 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
   def _call_for_each_replica(self, fn, args, kwargs):
     return _call_for_each_replica(self, fn, args, kwargs)
 
-  def map(self, map_over, fn, *args, **kwargs):
-    # TODO(josh11b): In eager mode, use one thread per device.
-    index = {}
-    for i, m in enumerate(map_over):
-      d = self._devices[i % len(self._devices)]
-      with ops.device(d):
-        l = index.get(d, [])
-        l.append(fn(m,
-                    *values.select_device_mirrored(d, args),
-                    **values.select_device_mirrored(d, kwargs)))
-        index[d] = l
-    # TODO(josh11b): Need a values.regroup equivalent that handles MapOutput
-    # in addition to PerReplica data.
-    return values.PerReplica({k: values.MapOutput(v) for k, v in index.items()})
-
   def configure(self,
                 session_config=None,
                 cluster_spec=None,
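Note: the deleted MirroredStrategy.map spread the elements of map_over round-robin across the strategy's devices and wrapped each device's outputs in a MapOutput. A self-contained sketch of just that device-assignment logic, with plain strings and a dict standing in for device scopes and the PerReplica-of-MapOutput result:

# Sketch of the round-robin assignment the deleted map() performed; devices
# and fn are placeholders, and the dict plays the role of the PerReplica of
# MapOutput values that the real method returned.
def round_robin_map(devices, map_over, fn):
  index = {}
  for i, m in enumerate(map_over):
    d = devices[i % len(devices)]
    index.setdefault(d, []).append(fn(m))
  return index


print(round_robin_map(["/gpu:0", "/gpu:1"], range(5), lambda x: 2 * x))
# {'/gpu:0': [0, 4, 8], '/gpu:1': [2, 6]}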
@@ -78,11 +78,6 @@ class MirroredTwoDeviceDistributionTest(strategy_test_lib.DistributionTestBase):
     self._test_minimize_loss_graph(
         self._get_distribution_strategy(), soft_placement=soft_placement)
 
-  def testMapReduce(self):
-    if not GPU_TEST:
-      self.skipTest("Not GPU test")
-    self._test_map_reduce(self._get_distribution_strategy())
-
   def testDeviceIndex(self):
     if not GPU_TEST:
       self.skipTest("Not GPU test")
@@ -40,9 +40,6 @@ class MirroredOneCPUDistributionTest(strategy_test_lib.DistributionTestBase):
   def testMinimizeLossGraph(self):
     self._test_minimize_loss_graph(self._get_distribution_strategy())
 
-  def testMapReduce(self):
-    self._test_map_reduce(self._get_distribution_strategy())
-
   def testDeviceIndex(self):
     self._test_device_index(self._get_distribution_strategy())
 
@@ -25,8 +25,6 @@ from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import variable_scope as vs
 from tensorflow.python.training import distribute as distribute_lib
 from tensorflow.python.util import nest
 
@@ -119,23 +117,9 @@ class OneDeviceStrategy(distribute_lib.DistributionStrategy):
     with ops.device(self._device), _OneDeviceReplicaContext(self):
       return fn(*args, **kwargs)
 
-  def map(self, map_over, fn, *args, **kwargs):
-    with ops.device(self._device):
-      return values.MapOutput([fn(m, *args, **kwargs) for m in map_over])
-
   def _reduce(self, aggregation, value, destinations):
-    del destinations
-    if not isinstance(value, values.MapOutput):
-      return value
-    l = value.get()
-    assert l
-    with ops.device(self._device):
-      if aggregation == vs.VariableAggregation.SUM:
-        return math_ops.add_n(l)
-      elif aggregation == vs.VariableAggregation.MEAN:
-        return math_ops.add_n(l) / len(l)
-      else:
-        assert False
+    del aggregation, destinations
+    return value
 
   def _update(self, var, options, fn, *args, **kwargs):
     # The implementations of _update() and _update_non_slot() are identical
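Note: with map() gone, OneDeviceStrategy._reduce no longer receives MapOutput lists to sum or average, so it collapses to an identity on the single device. A rough sketch of the old vs. new behaviour with plain numbers in place of tensors (SUM/MEAN stand in for vs.VariableAggregation, lists for MapOutput):

# Old vs. new single-device reduce, sketched with plain Python values.
SUM, MEAN = "sum", "mean"


def old_reduce(aggregation, value):
  if not isinstance(value, list):  # non-MapOutput values passed through
    return value
  assert value
  if aggregation == SUM:
    return sum(value)
  elif aggregation == MEAN:
    return sum(value) / len(value)
  raise AssertionError(aggregation)


def new_reduce(aggregation, value):
  del aggregation
  return value  # single device: nothing left to aggregate


print(old_reduce(MEAN, [2.0, 4.0]), new_reduce(MEAN, 3.0))  # 3.0 3.0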
@@ -35,9 +35,6 @@ class OneDeviceStrategyTest(strategy_test_lib.DistributionTestBase):
   def testMinimizeLossGraph(self):
     self._test_minimize_loss_graph(self._get_distribution_strategy())
 
-  def testMapReduce(self):
-    self._test_map_reduce(self._get_distribution_strategy())
-
   def testDeviceIndex(self):
     self._test_device_index(self._get_distribution_strategy())
 
@@ -189,15 +189,6 @@ class DistributionTestBase(test.TestCase):
       # Error should go down
       self.assertLess(error_after, error_before)
 
-  def _test_map_reduce(self, d, in_graph=None):
-    with d.scope():
-      map_in = [constant_op.constant(i) for i in range(10)]
-      map_out = d.map(map_in, lambda x, y: x * y, 2)
-      observed = d.reduce(variable_scope.VariableAggregation.SUM, map_out,
-                          "/device:CPU:0")
-      expected = 90  # 2 * (0 + 1 + ... + 9)
-      self.assertEqual(expected, observed.numpy())
-
   def _test_device_index(self, d):
     with d.scope():
       expected_devices = [False] * len(d.worker_devices)
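Note: the deleted _test_map_reduce is the clearest record of the removed user-facing pattern, d.map(inputs, fn, *args) followed by d.reduce(SUM, ...). It doubles each of 0..9 and sums, so the expected value is 2 * (0 + 1 + ... + 9) = 2 * 45 = 90. The same arithmetic without any strategy:

# Plain-Python equivalent of the computation the deleted test ran through
# d.map / d.reduce; no DistributionStrategy or TF ops involved.
assert sum(2 * i for i in range(10)) == 90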
@@ -1287,16 +1287,6 @@ class MultiWorkerDataset(object):
     return MultiWorkerDataIterator(iterators, self._worker_device_pairs)
 
 
-class MapOutput(object):
-  """Map can result in multiple outputs per device."""
-
-  def __init__(self, l):
-    self._l = l
-
-  def get(self):
-    return self._l
-
-
 class MultiStepContext(object):
   """A context object that can be used to capture things when running steps.
 
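Note: MapOutput was only a thin list wrapper, and nothing in values.py replaces it in this commit. A hypothetical pattern for callers that still want to group several outputs per device is a plain device-to-list mapping; the device names and numbers below are made up for illustration.

# Hypothetical stand-in for the deleted MapOutput wrapper: keep a plain
# device -> list-of-outputs dict and aggregate it directly.
per_device_outputs = {"/gpu:0": [1.0, 2.0], "/cpu:0": [3.0]}
totals = {d: sum(vals) for d, vals in per_device_outputs.items()}
print(totals)  # {'/gpu:0': 3.0, '/cpu:0': 3.0}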