Remove values property from DistributedValues.

PiperOrigin-RevId: 295994651
Change-Id: Ic0d003c76e711bee12d5de563de902430e837d5e
This commit is contained in:
Ken Franko 2020-02-19 10:05:28 -08:00 committed by TensorFlower Gardener
parent 8e58b059ef
commit eea4079931
7 changed files with 56 additions and 25 deletions

View File

@ -56,7 +56,6 @@ class MirroredFunctionStrategyTest(test.TestCase):
self.assertLen(f_traces, 1) # Function traced once, not for each replica.
# Returns a per-replica value.
self.assertIsInstance(result1, values.PerReplica)
self.assertAllEqual([1, 2], result1.values)
self.assertAllEqual([1, 2],
self._strategy.experimental_local_results(result1))
@ -64,7 +63,8 @@ class MirroredFunctionStrategyTest(test.TestCase):
result2 = self._strategy.experimental_run_v2(f, args=(result1,))
self.assertLen(f_traces, 1)
self.assertIsInstance(result2, values.PerReplica)
self.assertAllEqual([1, 3], result2.values)
self.assertAllEqual([1, 3],
self._strategy.experimental_local_results(result2))
def testMergeCall(self):
f_traces = []
@ -94,7 +94,8 @@ class MirroredFunctionStrategyTest(test.TestCase):
self.assertLen(g_traces, 1)
# Returns a per-replica value.
self.assertIsInstance(result, values.PerReplica)
self.assertAllEqual([1, 1], result.values)
self.assertAllEqual([1, 1],
self._strategy.experimental_local_results(result))
if __name__ == "__main__":

View File

@ -842,7 +842,7 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
def _local_results(self, val):
if isinstance(val, values.DistributedValues):
return val.values
return val._values # pylint: disable=protected-access
return (val,)
def value_container(self, val):

View File

@ -356,7 +356,9 @@ class MirroredStrategyCallForEachReplicaTest(test.TestCase):
with distribution.scope():
result = distribution.extended.call_for_each_replica(model_fn)
self.assertEqual((0, 1), self.evaluate(result.values))
self.assertEqual(
(0, 1),
self.evaluate(distribution.experimental_local_results(result)))
self.assertLen(traces, distribution.num_replicas_in_sync)
def testFunctionInCallForEachReplicaInsideAnotherFunction(self, distribution):
@ -372,7 +374,9 @@ class MirroredStrategyCallForEachReplicaTest(test.TestCase):
with distribution.scope():
result = step()
self.assertEqual((0, 1), self.evaluate(result.values))
self.assertEqual(
(0, 1),
self.evaluate(distribution.experimental_local_results(result)))
self.assertLen(traces, distribution.num_replicas_in_sync)
def testNestedFunctionInCallForEachReplicaWithMergeCall(self, distribution):
@ -711,8 +715,14 @@ class MirroredVariableUpdateTest(test.TestCase):
mirrored_var_result = self.evaluate(
mirrored_var.assign_add(6.0, read_value=True))
self.assertEqual(7.0, mirrored_var_result)
self.assertEqual(7.0, self.evaluate(mirrored_var.values[0]))
self.assertEqual(7.0, self.evaluate(mirrored_var.values[1]))
self.assertEqual(
7.0,
self.evaluate(
distribution.experimental_local_results(mirrored_var)[0]))
self.assertEqual(
7.0,
self.evaluate(
distribution.experimental_local_results(mirrored_var)[1]))
self.assertEqual(
distribution.extended.worker_devices[0], mirrored_var._devices[0])
self.assertEqual(
@ -720,8 +730,14 @@ class MirroredVariableUpdateTest(test.TestCase):
# read_value == False
self.evaluate(mirrored_var.assign_add(2.0, read_value=False))
self.assertEqual(9.0, self.evaluate(mirrored_var.values[0]))
self.assertEqual(9.0, self.evaluate(mirrored_var.values[1]))
self.assertEqual(
9.0,
self.evaluate(
distribution.experimental_local_results(mirrored_var)[0]))
self.assertEqual(
9.0,
self.evaluate(
distribution.experimental_local_results(mirrored_var)[1]))
self.assertEqual(
distribution.extended.worker_devices[0], mirrored_var._devices[0])
self.assertEqual(
@ -777,8 +793,14 @@ class MirroredVariableUpdateTest(test.TestCase):
self.assertEqual(5.0, self.evaluate(mirrored_var))
mirrored_var_result = self.evaluate(mirrored_var.assign_sub(2.0))
self.assertEqual(3.0, mirrored_var_result)
self.assertEqual(3.0, self.evaluate(mirrored_var.values[0]))
self.assertEqual(3.0, self.evaluate(mirrored_var.values[1]))
self.assertEqual(
3.0,
self.evaluate(
distribution.experimental_local_results(mirrored_var)[0]))
self.assertEqual(
3.0,
self.evaluate(
distribution.experimental_local_results(mirrored_var)[1]))
self.assertEqual(
distribution.extended.worker_devices[0], mirrored_var._devices[0])
self.assertEqual(
@ -994,7 +1016,8 @@ class MirroredStrategyDefunTest(test.TestCase):
distribution.extended.call_for_each_replica(
defun.get_concrete_function, args=[mock_model] + inputs))
for i in range(len(devices)):
graph_function = per_replica_graph_functions.values[i]
graph_function = distribution.experimental_local_results(
per_replica_graph_functions)[i]
# TODO(b/129555712): re-enable an assertion here that the two sets of
# variables are the same.
# self.assertEqual(set(graph_function.graph.variables),

View File

@ -532,8 +532,10 @@ class MirroredVariableCreationTest(test.TestCase):
expected_mean = 0.0
for i, _ in enumerate(distribution.extended.worker_devices):
# Should see different values on different devices.
v_sum_value = self.evaluate(ret_v_sum.values[i].read_value())
v_mean_value = self.evaluate(ret_v_mean.values[i].read_value())
v_sum_value = self.evaluate(
distribution.experimental_local_results(ret_v_sum)[i].read_value())
v_mean_value = self.evaluate(
distribution.experimental_local_results(ret_v_mean)[i].read_value())
expected = i + 3.0
self.assertEqual(expected, v_sum_value)
expected_sum += expected

View File

@ -92,11 +92,6 @@ class DistributedValues(object):
"""Returns a representative component."""
return self._values[0]
# TODO(josh11b): Replace experimental_local_results with this?
@property
def values(self):
return self._values
@property
def _devices(self):
return tuple(v.device for v in self._values)
@ -139,6 +134,11 @@ class DistributedDelegate(DistributedValues):
# __getattr__ and @property. See b/120402273.
return getattr(self._get(), name)
@property
def values(self):
"""Returns the per replica values."""
return self._values
def _get_as_operand(self):
"""Returns the value for operations for the current device.
@ -272,6 +272,11 @@ class PerReplica(DistributedValues, composite_tensor.CompositeTensor):
return PerReplicaSpec(
*(type_spec.type_spec_from_value(v) for v in self._values))
@property
def values(self):
"""Returns the per replica values."""
return self._values
class PerReplicaSpec(type_spec.TypeSpec):
"""Type specification for a `PerReplica`."""
@ -824,7 +829,7 @@ class MirroredVariable(DistributedVariable, Mirrored):
if update_replica_id is not None:
# We are calling an assign function on the mirrored variable in an
# update context.
return f(self.values[update_replica_id], *args, **kwargs)
return f(self._values[update_replica_id], *args, **kwargs)
# We are calling assign on the mirrored variable in cross replica
# context, use `strategy.extended.update()` to update the variable.

View File

@ -715,9 +715,9 @@ class StatefulRandomOpsTest(test.TestCase, parameterized.TestCase):
return t
results = strat.extended.call_for_each_replica(
fn=f, args=gens)
values = results.values
self.assertAllEqual(2, len(values))
self.assertAllDifferent(values)
local_results = strat.experimental_local_results(results)
self.assertAllEqual(2, len(local_results))
self.assertAllDifferent(local_results)
if __name__ == "__main__":

View File

@ -75,7 +75,7 @@ class LossScaleGradientTapeTest(test.TestCase, parameterized.TestCase):
def convert_tensor_to_list(tensor):
if isinstance(tensor, values.DistributedValues):
return tensor.values
return strategy.experimental_local_results(tensor)
else:
return [tensor]
return nest.map_structure(convert_tensor_to_list, results)