Remove values property from DistributedValues.
PiperOrigin-RevId: 295994651 Change-Id: Ic0d003c76e711bee12d5de563de902430e837d5e
This commit is contained in:
parent
8e58b059ef
commit
eea4079931
@ -56,7 +56,6 @@ class MirroredFunctionStrategyTest(test.TestCase):
|
||||
self.assertLen(f_traces, 1) # Function traced once, not for each replica.
|
||||
# Returns a per-replica value.
|
||||
self.assertIsInstance(result1, values.PerReplica)
|
||||
self.assertAllEqual([1, 2], result1.values)
|
||||
self.assertAllEqual([1, 2],
|
||||
self._strategy.experimental_local_results(result1))
|
||||
|
||||
@ -64,7 +63,8 @@ class MirroredFunctionStrategyTest(test.TestCase):
|
||||
result2 = self._strategy.experimental_run_v2(f, args=(result1,))
|
||||
self.assertLen(f_traces, 1)
|
||||
self.assertIsInstance(result2, values.PerReplica)
|
||||
self.assertAllEqual([1, 3], result2.values)
|
||||
self.assertAllEqual([1, 3],
|
||||
self._strategy.experimental_local_results(result2))
|
||||
|
||||
def testMergeCall(self):
|
||||
f_traces = []
|
||||
@ -94,7 +94,8 @@ class MirroredFunctionStrategyTest(test.TestCase):
|
||||
self.assertLen(g_traces, 1)
|
||||
# Returns a per-replica value.
|
||||
self.assertIsInstance(result, values.PerReplica)
|
||||
self.assertAllEqual([1, 1], result.values)
|
||||
self.assertAllEqual([1, 1],
|
||||
self._strategy.experimental_local_results(result))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -842,7 +842,7 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
|
||||
|
||||
def _local_results(self, val):
|
||||
if isinstance(val, values.DistributedValues):
|
||||
return val.values
|
||||
return val._values # pylint: disable=protected-access
|
||||
return (val,)
|
||||
|
||||
def value_container(self, val):
|
||||
|
@ -356,7 +356,9 @@ class MirroredStrategyCallForEachReplicaTest(test.TestCase):
|
||||
|
||||
with distribution.scope():
|
||||
result = distribution.extended.call_for_each_replica(model_fn)
|
||||
self.assertEqual((0, 1), self.evaluate(result.values))
|
||||
self.assertEqual(
|
||||
(0, 1),
|
||||
self.evaluate(distribution.experimental_local_results(result)))
|
||||
self.assertLen(traces, distribution.num_replicas_in_sync)
|
||||
|
||||
def testFunctionInCallForEachReplicaInsideAnotherFunction(self, distribution):
|
||||
@ -372,7 +374,9 @@ class MirroredStrategyCallForEachReplicaTest(test.TestCase):
|
||||
|
||||
with distribution.scope():
|
||||
result = step()
|
||||
self.assertEqual((0, 1), self.evaluate(result.values))
|
||||
self.assertEqual(
|
||||
(0, 1),
|
||||
self.evaluate(distribution.experimental_local_results(result)))
|
||||
self.assertLen(traces, distribution.num_replicas_in_sync)
|
||||
|
||||
def testNestedFunctionInCallForEachReplicaWithMergeCall(self, distribution):
|
||||
@ -711,8 +715,14 @@ class MirroredVariableUpdateTest(test.TestCase):
|
||||
mirrored_var_result = self.evaluate(
|
||||
mirrored_var.assign_add(6.0, read_value=True))
|
||||
self.assertEqual(7.0, mirrored_var_result)
|
||||
self.assertEqual(7.0, self.evaluate(mirrored_var.values[0]))
|
||||
self.assertEqual(7.0, self.evaluate(mirrored_var.values[1]))
|
||||
self.assertEqual(
|
||||
7.0,
|
||||
self.evaluate(
|
||||
distribution.experimental_local_results(mirrored_var)[0]))
|
||||
self.assertEqual(
|
||||
7.0,
|
||||
self.evaluate(
|
||||
distribution.experimental_local_results(mirrored_var)[1]))
|
||||
self.assertEqual(
|
||||
distribution.extended.worker_devices[0], mirrored_var._devices[0])
|
||||
self.assertEqual(
|
||||
@ -720,8 +730,14 @@ class MirroredVariableUpdateTest(test.TestCase):
|
||||
|
||||
# read_value == False
|
||||
self.evaluate(mirrored_var.assign_add(2.0, read_value=False))
|
||||
self.assertEqual(9.0, self.evaluate(mirrored_var.values[0]))
|
||||
self.assertEqual(9.0, self.evaluate(mirrored_var.values[1]))
|
||||
self.assertEqual(
|
||||
9.0,
|
||||
self.evaluate(
|
||||
distribution.experimental_local_results(mirrored_var)[0]))
|
||||
self.assertEqual(
|
||||
9.0,
|
||||
self.evaluate(
|
||||
distribution.experimental_local_results(mirrored_var)[1]))
|
||||
self.assertEqual(
|
||||
distribution.extended.worker_devices[0], mirrored_var._devices[0])
|
||||
self.assertEqual(
|
||||
@ -777,8 +793,14 @@ class MirroredVariableUpdateTest(test.TestCase):
|
||||
self.assertEqual(5.0, self.evaluate(mirrored_var))
|
||||
mirrored_var_result = self.evaluate(mirrored_var.assign_sub(2.0))
|
||||
self.assertEqual(3.0, mirrored_var_result)
|
||||
self.assertEqual(3.0, self.evaluate(mirrored_var.values[0]))
|
||||
self.assertEqual(3.0, self.evaluate(mirrored_var.values[1]))
|
||||
self.assertEqual(
|
||||
3.0,
|
||||
self.evaluate(
|
||||
distribution.experimental_local_results(mirrored_var)[0]))
|
||||
self.assertEqual(
|
||||
3.0,
|
||||
self.evaluate(
|
||||
distribution.experimental_local_results(mirrored_var)[1]))
|
||||
self.assertEqual(
|
||||
distribution.extended.worker_devices[0], mirrored_var._devices[0])
|
||||
self.assertEqual(
|
||||
@ -994,7 +1016,8 @@ class MirroredStrategyDefunTest(test.TestCase):
|
||||
distribution.extended.call_for_each_replica(
|
||||
defun.get_concrete_function, args=[mock_model] + inputs))
|
||||
for i in range(len(devices)):
|
||||
graph_function = per_replica_graph_functions.values[i]
|
||||
graph_function = distribution.experimental_local_results(
|
||||
per_replica_graph_functions)[i]
|
||||
# TODO(b/129555712): re-enable an assertion here that the two sets of
|
||||
# variables are the same.
|
||||
# self.assertEqual(set(graph_function.graph.variables),
|
||||
|
@ -532,8 +532,10 @@ class MirroredVariableCreationTest(test.TestCase):
|
||||
expected_mean = 0.0
|
||||
for i, _ in enumerate(distribution.extended.worker_devices):
|
||||
# Should see different values on different devices.
|
||||
v_sum_value = self.evaluate(ret_v_sum.values[i].read_value())
|
||||
v_mean_value = self.evaluate(ret_v_mean.values[i].read_value())
|
||||
v_sum_value = self.evaluate(
|
||||
distribution.experimental_local_results(ret_v_sum)[i].read_value())
|
||||
v_mean_value = self.evaluate(
|
||||
distribution.experimental_local_results(ret_v_mean)[i].read_value())
|
||||
expected = i + 3.0
|
||||
self.assertEqual(expected, v_sum_value)
|
||||
expected_sum += expected
|
||||
|
@ -92,11 +92,6 @@ class DistributedValues(object):
|
||||
"""Returns a representative component."""
|
||||
return self._values[0]
|
||||
|
||||
# TODO(josh11b): Replace experimental_local_results with this?
|
||||
@property
|
||||
def values(self):
|
||||
return self._values
|
||||
|
||||
@property
|
||||
def _devices(self):
|
||||
return tuple(v.device for v in self._values)
|
||||
@ -139,6 +134,11 @@ class DistributedDelegate(DistributedValues):
|
||||
# __getattr__ and @property. See b/120402273.
|
||||
return getattr(self._get(), name)
|
||||
|
||||
@property
|
||||
def values(self):
|
||||
"""Returns the per replica values."""
|
||||
return self._values
|
||||
|
||||
def _get_as_operand(self):
|
||||
"""Returns the value for operations for the current device.
|
||||
|
||||
@ -272,6 +272,11 @@ class PerReplica(DistributedValues, composite_tensor.CompositeTensor):
|
||||
return PerReplicaSpec(
|
||||
*(type_spec.type_spec_from_value(v) for v in self._values))
|
||||
|
||||
@property
|
||||
def values(self):
|
||||
"""Returns the per replica values."""
|
||||
return self._values
|
||||
|
||||
|
||||
class PerReplicaSpec(type_spec.TypeSpec):
|
||||
"""Type specification for a `PerReplica`."""
|
||||
@ -824,7 +829,7 @@ class MirroredVariable(DistributedVariable, Mirrored):
|
||||
if update_replica_id is not None:
|
||||
# We are calling an assign function on the mirrored variable in an
|
||||
# update context.
|
||||
return f(self.values[update_replica_id], *args, **kwargs)
|
||||
return f(self._values[update_replica_id], *args, **kwargs)
|
||||
|
||||
# We are calling assign on the mirrored variable in cross replica
|
||||
# context, use `strategy.extended.update()` to update the variable.
|
||||
|
@ -715,9 +715,9 @@ class StatefulRandomOpsTest(test.TestCase, parameterized.TestCase):
|
||||
return t
|
||||
results = strat.extended.call_for_each_replica(
|
||||
fn=f, args=gens)
|
||||
values = results.values
|
||||
self.assertAllEqual(2, len(values))
|
||||
self.assertAllDifferent(values)
|
||||
local_results = strat.experimental_local_results(results)
|
||||
self.assertAllEqual(2, len(local_results))
|
||||
self.assertAllDifferent(local_results)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -75,7 +75,7 @@ class LossScaleGradientTapeTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
def convert_tensor_to_list(tensor):
|
||||
if isinstance(tensor, values.DistributedValues):
|
||||
return tensor.values
|
||||
return strategy.experimental_local_results(tensor)
|
||||
else:
|
||||
return [tensor]
|
||||
return nest.map_structure(convert_tensor_to_list, results)
|
||||
|
Loading…
Reference in New Issue
Block a user