Support packed variables in DistributedVariable. Add an option to enable packed variables in TPUStrategy.

PiperOrigin-RevId: 317234665
Change-Id: I09e806cb8261815cd87a6d98817556dd8f7e8ed7
Yujing Zhang 2020-06-18 20:03:31 -07:00 committed by TensorFlower Gardener
parent 4d54ef3139
commit 7e6e549c46
15 changed files with 454 additions and 270 deletions
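Note: the option is wired up in the diffs below as a private attribute on the strategy object rather than a public API. A minimal sketch of how the tests in this change turn it on, assuming a reachable TPU (the empty resolver argument is a placeholder):

import tensorflow as tf

resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="")  # placeholder target
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)

strategy = tf.distribute.experimental.TPUStrategy(resolver)
# Private switch added by this commit; packing only takes effect in eager mode.
strategy._enable_packed_variable_in_eager_mode = True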


@@ -654,6 +654,7 @@ tpu_py_test(
         "//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib",
         "//tensorflow/python/eager:remote",
         "//tensorflow/python/eager:test",
+        "@absl_py//absl/testing:parameterized",
     ],
 )
@@ -787,6 +788,7 @@ py_library(
     name = "tpu_values",
     srcs = ["tpu_values.py"],
     deps = [
+        ":packed_distributed_variable",
         ":values",
         "//tensorflow/python:framework_ops",
         "//tensorflow/python:math_ops",
@@ -1602,7 +1604,7 @@ distribute_py_test(
     srcs = ["saved_model_save_load_test.py"],
     full_precision = True,
     main = "saved_model_save_load_test.py",
-    shard_count = 5,
+    shard_count = 7,
     tags = [
         "multi_and_single_gpu",
         "no_rocm",
@@ -1635,7 +1637,7 @@ distribute_py_test(
     srcs = ["saved_model_mixed_api_test.py"],
     full_precision = True,
     main = "saved_model_mixed_api_test.py",
-    shard_count = 5,
+    shard_count = 7,
     tags = [
         "multi_and_single_gpu",
         "no_rocm",


@@ -103,6 +103,7 @@ class TrainingCheckpointTests(test.TestCase, parameterized.TestCase):
              strategy_combinations.mirrored_strategy_with_one_cpu,
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations.tpu_strategy,
+              strategy_combinations.tpu_strategy_packed_var,
              strategy_combinations.central_storage_strategy_with_two_gpus,
          ],
          mode=["eager"]))
@@ -138,6 +139,7 @@ class TrainingCheckpointTests(test.TestCase, parameterized.TestCase):
              strategy_combinations.mirrored_strategy_with_one_cpu,
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations.tpu_strategy,
+              strategy_combinations.tpu_strategy_packed_var,
              strategy_combinations.central_storage_strategy_with_two_gpus,
          ],
          mode=["eager"]))


@@ -197,7 +197,8 @@ class InputIterationTest(test.TestCase, parameterized.TestCase,
      combinations.combine(
          distribution=[
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
-              strategy_combinations.tpu_strategy
+              strategy_combinations.tpu_strategy,
+              strategy_combinations.tpu_strategy_packed_var,
          ],
          mode=["eager"]))
  def testNestedOutput(self, distribution):
@@ -748,6 +749,10 @@ class InputIterationTest(test.TestCase, parameterized.TestCase,
          mode=["eager"]
      ))
  def testMultiDeviceDataCapturedFunction(self, distribution):
+    if getattr(distribution, "_enable_packed_variable_in_eager_mode", False):
+      self.skipTest(
+          "Dataset captured function doesn't support packed tensors yet "
+          "(b/145922293).")
    inputs = constant_op.constant([2., 3.])
    dataset = lambda _: dataset_ops.Dataset.from_tensor_slices(inputs).repeat(5)
    input_iterator = iter(


@@ -148,6 +148,9 @@ def select_replica_mirrored(replica_id, structured):
        raise TypeError(
            "Expected value to be mirrored across replicas: %s in %s." %
            (x, structured))
+      packed_var = getattr(x, "_packed_variable", None)
+      if packed_var is not None:
+        return packed_var
      return x.values[replica_id]
    else:
      return x


@@ -42,7 +42,7 @@ class PackedDistributedVariable(resource_variable_ops.BaseResourceVariable):
      name: Optional name for the variable. Defaults to `'Variable'` and gets
        uniquified automatically.
    """
-    if not context.executing_eagerly():
+    if not ops.executing_eagerly_outside_functions():
      raise ValueError(
          "PackedDistributedVariable should be created in eager mode.")
    if not distributed_variables:
@@ -84,6 +84,9 @@ class PackedDistributedVariable(resource_variable_ops.BaseResourceVariable):
  def devices(self):
    return self._devices

+  def on_device(self, device):
+    return PackedVarAndDevice(self, device)
+
  def get_var_on_device(self, device):
    for i, d in enumerate(self._devices):
      if d == device:
@@ -100,6 +103,9 @@ class PackedDistributedVariable(resource_variable_ops.BaseResourceVariable):

  @property
  def handle(self):
-    return self._handle
+    if context.executing_eagerly():
+      return self.get_var_on_current_device().handle
+    else:
+      return self._handle

  def _read_variable_op(self):
@@ -269,6 +275,7 @@ class PackedVarAndDevice(object):

  @property
  def handle(self):
-    return self._var.handle
+    with ops.device(self._device):
+      return self._var.handle

  @property
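The test file below exercises these additions directly. For orientation, a condensed, hedged sketch of the new surface (on_device() and the eager/function split in handle); it is illustrative, not verbatim from the tests, and assumes two logical CPU devices as the test's setup configures:

import tensorflow as tf
from tensorflow.python.distribute import packed_distributed_variable as packed
from tensorflow.python.eager import def_function
from tensorflow.python.framework import ops
from tensorflow.python.ops import resource_variable_ops

# Two logical CPU devices, mirroring the test environment.
cpus = tf.config.experimental.list_physical_devices('CPU')
tf.config.experimental.set_virtual_device_configuration(
    cpus[0], [tf.config.experimental.VirtualDeviceConfiguration()] * 2)

with ops.device('/cpu:0'):
  v0 = resource_variable_ops.ResourceVariable(1.0, name='var0')
with ops.device('/cpu:1'):
  v1 = resource_variable_ops.ResourceVariable(2.0, name='var1')
packed_var = packed.PackedDistributedVariable([v0, v1])

# Eagerly, `handle` now resolves to the component on the current device, so it
# is not a packed handle; inside a tf.function it stays packed.
assert not packed_var.handle.is_packed

@def_function.function
def update():
  assert packed_var.handle.is_packed  # checked at trace time
  with ops.device(v0.device):
    packed_var.assign_add(3.0)
  # on_device() is the new helper returning a PackedVarAndDevice view.
  packed_var.on_device(v1.device).assign_add(7.0)

update()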


@@ -46,7 +46,7 @@ class PackedDistributedVariableTest(test.TestCase):
    v1 = resource_variable_ops.ResourceVariable(2.0, name='var1')
    packed_var = packed_distributed_variable.PackedDistributedVariable([v0, v1])
-    self.assertTrue(packed_var.handle.is_packed)
+    self.assertFalse(packed_var.handle.is_packed)
    self.assertTrue(packed_var.is_initialized)

    with ops.device('/cpu:0'):
@@ -61,6 +61,7 @@ class PackedDistributedVariableTest(test.TestCase):

    @def_function.function
    def update_var():
+      self.assertTrue(packed_var.handle.is_packed)
      with ops.device('/cpu:0'):
        packed_var.assign_add(3.0).assign_sub(1.0)
        read0 = packed_var.value()
@@ -85,7 +86,7 @@ class PackedDistributedVariableTest(test.TestCase):
    packed_var0 = packed_distributed_variable.PackedVarAndDevice(
        packed_var, device0)
-    self.assertTrue(packed_var0.handle.is_packed)
+    self.assertFalse(packed_var0.handle.is_packed)
    self.assertAllEqual(math_ops.mul(packed_var0, 2.0), 2.0)

    packed_var1 = packed_distributed_variable.PackedVarAndDevice(
@@ -94,6 +95,7 @@ class PackedDistributedVariableTest(test.TestCase):

    @def_function.function
    def func():
+      self.assertTrue(packed_var.handle.is_packed)
      var0 = packed_distributed_variable.PackedVarAndDevice(packed_var, device0)
      var0.assign_add(3.0)
      var1 = packed_distributed_variable.PackedVarAndDevice(packed_var, device1)


@@ -58,6 +58,7 @@ strategies = [
    strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
    strategy_combinations.mirrored_strategy_with_two_gpus,
    strategy_combinations.tpu_strategy,
+    strategy_combinations.tpu_strategy_packed_var,
    strategy_combinations.central_storage_strategy_with_two_gpus,
]


@@ -53,7 +53,11 @@ _did_connect_to_cluster = False

 # pylint: disable=missing-docstring
-def _get_tpu_strategy_creator(steps_per_run, use_single_core=False, **kwargs):
+def _get_tpu_strategy_creator(steps_per_run,
+                              use_single_core=False,
+                              enable_packed_variable=False,
+                              **kwargs):

  def _create_tpu_strategy():
    global _did_connect_to_cluster
@@ -87,10 +91,13 @@ def _get_tpu_strategy_creator(steps_per_run, use_single_core=False, **kwargs):

    # Steps per run is only supported in TF 1.x
    if tf2.enabled():
-      return tpu_lib.TPUStrategy(resolver, device_assignment, **kwargs)
+      strategy = tpu_lib.TPUStrategy(resolver, device_assignment, **kwargs)
    else:
-      return tpu_lib.TPUStrategyV1(resolver, steps_per_run,
-                                   device_assignment, **kwargs)
+      strategy = tpu_lib.TPUStrategyV1(resolver, steps_per_run,
+                                       device_assignment, **kwargs)
+    strategy._enable_packed_variable_in_eager_mode = enable_packed_variable  # pylint: disable=protected-access
+    return strategy

  return _create_tpu_strategy
@@ -117,6 +124,10 @@ one_device_strategy_gpu_on_worker_1 = combinations.NamedDistribution(
    required_gpus=1)
 tpu_strategy = combinations.NamedDistribution(
    "TPU", _get_tpu_strategy_creator(steps_per_run=2), required_tpu=True)
+tpu_strategy_packed_var = combinations.NamedDistribution(
+    "TPUPackedVar",
+    _get_tpu_strategy_creator(steps_per_run=2, enable_packed_variable=True),
+    required_tpu=True)
 tpu_strategy_one_step = combinations.NamedDistribution(
    "TPUOneStep", _get_tpu_strategy_creator(steps_per_run=1), required_tpu=True)
 tpu_strategy_one_core = combinations.NamedDistribution(
@@ -286,6 +297,7 @@ strategies_minus_default_and_tpu = [
 tpu_strategies = [
    tpu_strategy,  # steps_per_run=2
    tpu_strategy_one_step,
+    tpu_strategy_packed_var,
    cloud_tpu_strategy,
 ]


@@ -141,6 +141,10 @@ class TPUStrategy(distribute_lib.Strategy):
        "num_workers").set(self.extended.num_hosts)
    distribute_lib.distribution_strategy_replica_gauge.get_cell(
        "num_replicas_per_worker").set(self.extended.num_replicas_per_host)
+    # Packed variable is used to reduce the overhead of function execution.
+    # For a DistributedVariable, only one variable handle is captured into a
+    # function graph. It's only supported in eager mode.
+    self._enable_packed_variable_in_eager_mode = False

  # TODO(cjfj): Modify `_call_for_each_replica` in `TPUExtended` such that this
  # can use the default implementation.
@@ -185,6 +189,10 @@ class TPUStrategyV1(distribute_lib.StrategyV1):
        "num_workers").set(self.extended.num_hosts)
    distribute_lib.distribution_strategy_replica_gauge.get_cell(
        "num_replicas_per_worker").set(self.extended.num_replicas_per_host)
+    # Packed variable is used to reduce the overhead of function execution.
+    # For a DistributedVariable, only one variable handle is captured into a
+    # function graph. It's only supported in eager mode.
+    self._enable_packed_variable_in_eager_mode = False

  @property
  def steps_per_run(self):
@@ -671,20 +679,29 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
      return cross_device_ops_lib.reduce_non_distributed_value(
          reduce_op, value, destinations, self._num_replicas_in_sync)

+    value_list = value.values
+    # pylint: disable=protected-access
+    if isinstance(
+        value,
+        values.DistributedVariable) and value._packed_variable is not None:
+      value_list = tuple(
+          value._packed_variable.on_device(d)
+          for d in value._packed_variable.devices)
+    # pylint: enable=protected-access
+
    # Currently XLA op by op mode has a limit for the number of inputs for a
    # single op, thus we break one `add_n` op into a group of `add_n` ops to
    # work around the constraint.
    # TODO(cjfj): Detect when it is possible to use `cross_replica_sum`.
    if len(value.values) <= _XLA_OP_BY_OP_INPUTS_LIMIT:
-      output = math_ops.add_n(value.values)
+      output = math_ops.add_n(value_list)
    else:
-      output = array_ops.zeros_like(
-          value.values[0], dtype=value.values[0].dtype)
-      for i in range(0, len(value.values), _XLA_OP_BY_OP_INPUTS_LIMIT):
-        output += math_ops.add_n(value.values[i:i + _XLA_OP_BY_OP_INPUTS_LIMIT])
+      output = array_ops.zeros_like(value_list[0], dtype=value_list[0].dtype)
+      for i in range(0, len(value_list), _XLA_OP_BY_OP_INPUTS_LIMIT):
+        output += math_ops.add_n(value_list[i:i + _XLA_OP_BY_OP_INPUTS_LIMIT])

    if reduce_op == reduce_util.ReduceOp.MEAN:
-      output *= (1. / len(value.values))
+      output *= (1. / len(value_list))

    devices = cross_device_ops_lib.get_devices_from(destinations)
@@ -710,17 +727,28 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
      else:
        return (fn(var, *args, **kwargs),)

-    # Otherwise, we revert to MirroredStrategy behavior and update each variable
-    # directly.
+    # Otherwise, we revert to MirroredStrategy behavior and update the variable
+    # on each replica directly.
    updates = []
-    for i, v in enumerate(var.values):
+    values_and_devices = []
+    packed_var = var._packed_variable  # pylint: disable=protected-access
+    if packed_var is not None:
+      for device in packed_var.devices:
+        values_and_devices.append((packed_var, device))
+    else:
+      for value in var.values:
+        values_and_devices.append((value, value.device))
+
+    for i, value_and_device in enumerate(values_and_devices):
+      value = value_and_device[0]
+      device = value_and_device[1]
      name = "update_%d" % i
-      with ops.device(v.device), \
+      with ops.device(device), \
          distribute_lib.UpdateContext(i), \
          ops.name_scope(name):
        # If args and kwargs are not mirrored, the value is returned as is.
        updates.append(
-            fn(v, *distribute_utils.select_replica_mirrored(i, args),
+            fn(value, *distribute_utils.select_replica_mirrored(i, args),
               **distribute_utils.select_replica_mirrored(i, kwargs)))
    return distribute_utils.update_regroup(self, updates, group)
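The comment added above is the core motivation: with packing enabled, a DistributedVariable contributes one packed handle to a traced function instead of one handle per replica. An illustrative, hedged sketch of the usage this is meant to speed up; it assumes a reachable TPU (placeholder resolver argument), and the user-visible behavior of strategy.run is unchanged:

import tensorflow as tf

resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="")  # placeholder target
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
strategy = tf.distribute.experimental.TPUStrategy(resolver)
strategy._enable_packed_variable_in_eager_mode = True  # private switch from this commit

with strategy.scope():
  v = tf.Variable(
      1.0, aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA)

@tf.function
def train_step():
  def step_fn():
    # The read/update below captures a single packed handle when packing is on.
    v.assign_add(1.0)
    return v.read_value()
  return strategy.run(step_fn)

train_step()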


@@ -18,6 +18,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+from absl.testing import parameterized
+
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.distribute import distribute_lib
 from tensorflow.python.distribute import distribution_strategy_context
@@ -64,14 +66,17 @@ def get_tpu_cluster_resolver():
  return resolver


-def get_tpu_strategy():
+def get_tpu_strategy(enable_packed_var=False):
  resolver = get_tpu_cluster_resolver()
  remote.connect_to_cluster(resolver)
  tpu_strategy_util.initialize_tpu_system(resolver)
-  return tpu_lib.TPUStrategy(resolver)
+  strategy = tpu_lib.TPUStrategy(resolver)
+  strategy._enable_packed_variable_in_eager_mode = enable_packed_var
+  return strategy


-class TPUStrategyTest(test.TestCase):
+# TPU tests which don't use TPUStrategy.
+class TPUTest(test.TestCase):

  def test_multiple_initialize_system(self):
    resolver = get_tpu_cluster_resolver()
@@ -82,177 +87,6 @@ class TPUStrategyTest(test.TestCase):
      tpu_strategy_util.initialize_tpu_system(resolver)
      self.assertRegex(str(mock_log.call_args), "already been initialized")

-  def test_sequential_experimental_runs(self):
-    resolver = get_tpu_cluster_resolver()
-    remote.connect_to_cluster(resolver)
-    topology = tpu_strategy_util.initialize_tpu_system(resolver)
-    # Computation replicated to all cores.
-    device_assignment = device_assignment_lib.DeviceAssignment.build(
-        topology, num_replicas=2)
-    strategy = tpu_lib.TPUStrategy(
-        resolver, device_assignment=device_assignment)
-
-    # Computation on the 1st core.
-    device_assignment2 = device_assignment_lib.DeviceAssignment.build(
-        topology, num_replicas=1)
-    strategy2 = tpu_lib.TPUStrategy(
-        resolver, device_assignment=device_assignment2)
-
-    def computation(x):
-      return math_ops.square(x)
-
-    @def_function.function
-    def train_step():
-      outputs = strategy.experimental_local_results(
-          strategy.run(computation, args=([2., 2.],)))
-      outputs2 = strategy2.run(
-          computation, args=([outputs[0]],))
-      return outputs2
-
-    self.assertAllEqual([[16., 16.]], train_step())
-
-  def test_device_switch_case(self):
-    strategy = get_tpu_strategy()
-    with strategy.scope():
-      a = variables.Variable(1)
-
-    inference_iteration = variables.Variable(-1)
-
-    def inference_fn(x, i):
-      return a + x + i
-
-    @def_function.function
-    def run_inference(x):
-
-      def do_inference(device, inference_fn, i):
-        with ops.device(device):
-          return inference_fn(x, i)
-
-      branch_fns = {
-          0: (lambda: do_inference("/device:TPU:0", inference_fn, 0)),
-          1: (lambda: do_inference("/device:TPU:1", inference_fn, 1)),
-      }
-      branch_index = inference_iteration.assign_add(1, use_locking=True) % 2
-      return control_flow_ops.switch_case(branch_index, branch_fns)
-
-    self.assertAllEqual(2., run_inference(1))  # Use TPU core 0.
-    self.assertAllEqual(3., run_inference(1))  # Use TPU core 1.
-
-  def test_recover_from_compilation_failures(self):
-    # TODO(b/148150981): Stop skipping this test once recovery works
-    # for non-local TPU.
-    if FLAGS.tpu:
-      self.skipTest("Recovery fails for non-local TPU, see b/148150981")
-
-    # Disable automatic outside compilation.
-    config.set_soft_device_placement(False)
-    strategy = get_tpu_strategy()
-
-    @def_function.function
-    def compilation_failure_run():
-
-      def computation():
-        return random_ops.random_gamma([10], [0.5, 1.5])
-
-      return strategy.run(computation)
-
-    with self.assertRaisesRegexp(errors.InvalidArgumentError,
-                                 "TPU compilation failed"):
-      compilation_failure_run()
-
-    @def_function.function
-    def good_run():
-
-      def computation():
-        return random_ops.random_normal([10])
-
-      return strategy.run(computation)
-
-    good_run()
-
-  def test_dynamic_shape_with_outside_compilation_failure(self):
-    # Enable automatic outside compilation.
-    config.set_soft_device_placement(True)
-    strategy = get_tpu_strategy()
-    dataset = dataset_ops.Dataset.from_tensors(("string", 1.0)).repeat().batch(
-        2, drop_remainder=False)
-    dataset = strategy.experimental_distribute_dataset(dataset)
-    iterator = iter(dataset)
-
-    @def_function.function
-    def train_fn(iterator):
-
-      def step_fn(inputs):
-        _, inputs = inputs
-        return math_ops.reduce_sum(inputs)
-
-      return strategy.experimental_local_results(
-          strategy.run(step_fn, args=(next(iterator),)))
-
-    with self.assertRaisesRegex(errors.InternalError, "Compilation failure"):
-      logging.info(train_fn(iterator))
-
-  def test_computation_on_subset_cores(self):
-    resolver = get_tpu_cluster_resolver()
-    remote.connect_to_cluster(resolver)
-    topology = tpu_strategy_util.initialize_tpu_system(resolver)
-    all_core_strategy = tpu_lib.TPUStrategy(resolver)
-
-    with all_core_strategy.scope():
-      v = variables.Variable(0.0,
-                             aggregation=variables.VariableAggregation.MEAN)
-
-    # Computation on the 1st core.
-    device_assignment = device_assignment_lib.DeviceAssignment.build(
-        topology, num_replicas=1)
-    first_core_strategy = tpu_lib.TPUStrategy(
-        resolver, device_assignment=device_assignment)
-
-    # Computation on the 2nd core.
-    device_assignment2 = device_assignment_lib.DeviceAssignment(
-        topology, [[[0, 0, 0, 1]]])
-    second_core_strategy = tpu_lib.TPUStrategy(
-        resolver, device_assignment=device_assignment2)
-
-    @def_function.function
-    def train_step():
-
-      def step_fn():
-        return v + 1.0
-
-      all_core_strategy.run(step_fn)
-      r1 = first_core_strategy.run(step_fn)
-      r2 = second_core_strategy.run(step_fn)
-      return r1 + r2
-
-    train_step()
-    self.assertAllEqual(2., train_step())
-
-  def test_worker_devices_on_subset_cores(self):
-    resolver = get_tpu_cluster_resolver()
-    remote.connect_to_cluster(resolver)
-    topology = tpu_strategy_util.initialize_tpu_system(resolver)
-
-    # Strategy for the 1st core.
-    device_assignment = device_assignment_lib.DeviceAssignment.build(
-        topology, num_replicas=1)
-    first_core_strategy = tpu_lib.TPUStrategy(
-        resolver, device_assignment=device_assignment)
-
-    # Strategy for the 2nd core.
-    device_assignment2 = device_assignment_lib.DeviceAssignment(
-        topology, [[[0, 0, 0, 1]]])
-    second_core_strategy = tpu_lib.TPUStrategy(
-        resolver, device_assignment=device_assignment2)
-
-    self.assertLen(first_core_strategy.extended.worker_devices, 1)
-    self.assertEndsWith(first_core_strategy.extended.worker_devices[0],
-                        "device:TPU:0")
-
-    self.assertLen(second_core_strategy.extended.worker_devices, 1)
-    self.assertEndsWith(second_core_strategy.extended.worker_devices[0],
-                        "device:TPU:1")
-
  def test_tpu_tf_function_same_device(self):
    with ops.device("/device:TPU:0"):
      a = variables.Variable(1)
@@ -288,8 +122,194 @@ class TPUStrategyTest(test.TestCase):
    result = bar() + 1
    self.assertAllEqual(result, 2)

-  def test_control_output_in_while_body_fn(self):
-    strategy = get_tpu_strategy()
+
+@parameterized.named_parameters([("PackedVar", True), ("", False)])
+class TPUStrategyTest(test.TestCase, parameterized.TestCase):
+
+  def test_sequential_experimental_runs(self, enable_packed_var):
+    resolver = get_tpu_cluster_resolver()
+    remote.connect_to_cluster(resolver)
+    topology = tpu_strategy_util.initialize_tpu_system(resolver)
+    # Computation replicated to all cores.
+    device_assignment = device_assignment_lib.DeviceAssignment.build(
+        topology, num_replicas=2)
+    strategy = tpu_lib.TPUStrategy(
+        resolver, device_assignment=device_assignment)
+    strategy._enable_packed_variable_in_eager_mode = enable_packed_var
+
+    # Computation on the 1st core.
+    device_assignment2 = device_assignment_lib.DeviceAssignment.build(
+        topology, num_replicas=1)
+    strategy2 = tpu_lib.TPUStrategy(
+        resolver, device_assignment=device_assignment2)
+
+    def computation(x):
+      return math_ops.square(x)
+
+    @def_function.function
+    def train_step():
+      outputs = strategy.experimental_local_results(
+          strategy.run(computation, args=([2., 2.],)))
+      outputs2 = strategy2.run(
+          computation, args=([outputs[0]],))
+      return outputs2
+
+    self.assertAllEqual([[16., 16.]], train_step())
+
+  def test_device_switch_case(self, enable_packed_var):
+    strategy = get_tpu_strategy(enable_packed_var)
+    with strategy.scope():
+      a = variables.Variable(1)
+
+    inference_iteration = variables.Variable(-1)
+
+    def inference_fn(x, i):
+      return a + x + i
+
+    @def_function.function
+    def run_inference(x):
+
+      def do_inference(device, inference_fn, i):
+        with ops.device(device):
+          return inference_fn(x, i)
+
+      branch_fns = {
+          0: (lambda: do_inference("/device:TPU:0", inference_fn, 0)),
+          1: (lambda: do_inference("/device:TPU:1", inference_fn, 1)),
+      }
+      branch_index = inference_iteration.assign_add(1, use_locking=True) % 2
+      return control_flow_ops.switch_case(branch_index, branch_fns)
+
+    self.assertAllEqual(2., run_inference(1))  # Use TPU core 0.
+    self.assertAllEqual(3., run_inference(1))  # Use TPU core 1.
+
+  def test_recover_from_compilation_failures(self, enable_packed_var):
+    # TODO(b/148150981): Stop skipping this test once recovery works
+    # for non-local TPU.
+    if FLAGS.tpu:
+      self.skipTest("Recovery fails for non-local TPU, see b/148150981")
+
+    # Disable automatic outside compilation.
+    config.set_soft_device_placement(False)
+    strategy = get_tpu_strategy(enable_packed_var)
+
+    @def_function.function
+    def compilation_failure_run():
+
+      def computation():
+        return random_ops.random_gamma([10], [0.5, 1.5])
+
+      return strategy.run(computation)
+
+    with self.assertRaisesRegex(errors.InvalidArgumentError,
+                                "TPU compilation failed"):
+      compilation_failure_run()
+
+    @def_function.function
+    def good_run():
+
+      def computation():
+        return random_ops.random_normal([10])
+
+      return strategy.run(computation)
+
+    good_run()
+
+  def test_dynamic_shape_with_outside_compilation_failure(
+      self, enable_packed_var):
+    # Enable automatic outside compilation.
+    config.set_soft_device_placement(True)
+    strategy = get_tpu_strategy(enable_packed_var)
+    dataset = dataset_ops.Dataset.from_tensors(("string", 1.0)).repeat().batch(
+        2, drop_remainder=False)
+    dataset = strategy.experimental_distribute_dataset(dataset)
+    iterator = iter(dataset)
+
+    @def_function.function
+    def train_fn(iterator):
+
+      def step_fn(inputs):
+        _, inputs = inputs
+        return math_ops.reduce_sum(inputs)
+
+      return strategy.experimental_local_results(
+          strategy.run(step_fn, args=(next(iterator),)))
+
+    with self.assertRaisesRegex(errors.InternalError, "Compilation failure"):
+      logging.info(train_fn(iterator))
+
+  def test_computation_on_subset_cores(self, enable_packed_var):
+    resolver = get_tpu_cluster_resolver()
+    remote.connect_to_cluster(resolver)
+    topology = tpu_strategy_util.initialize_tpu_system(resolver)
+    all_core_strategy = tpu_lib.TPUStrategy(resolver)
+    all_core_strategy._enable_packed_variable_in_eager_mode = enable_packed_var
+
+    with all_core_strategy.scope():
+      v = variables.Variable(0.0,
+                             aggregation=variables.VariableAggregation.MEAN)
+
+    # Computation on the 1st core.
+    device_assignment = device_assignment_lib.DeviceAssignment.build(
+        topology, num_replicas=1)
+    first_core_strategy = tpu_lib.TPUStrategy(
+        resolver, device_assignment=device_assignment)
+    first_core_strategy._enable_packed_variable_in_eager_mode = (
+        enable_packed_var)
+
+    # Computation on the 2nd core.
+    device_assignment2 = device_assignment_lib.DeviceAssignment(
+        topology, [[[0, 0, 0, 1]]])
+    second_core_strategy = tpu_lib.TPUStrategy(
+        resolver, device_assignment=device_assignment2)
+    second_core_strategy._enable_packed_variable_in_eager_mode = (
+        enable_packed_var)
+
+    @def_function.function
+    def train_step():
+
+      def step_fn():
+        return v + 1.0
+
+      all_core_strategy.run(step_fn)
+      r1 = first_core_strategy.run(step_fn)
+      r2 = second_core_strategy.run(step_fn)
+      return r1 + r2
+
+    train_step()
+    self.assertAllEqual(2., train_step())
+
+  def test_worker_devices_on_subset_cores(self, enable_packed_var):
+    resolver = get_tpu_cluster_resolver()
+    remote.connect_to_cluster(resolver)
+    topology = tpu_strategy_util.initialize_tpu_system(resolver)
+
+    # Strategy for the 1st core.
+    device_assignment = device_assignment_lib.DeviceAssignment.build(
+        topology, num_replicas=1)
+    first_core_strategy = tpu_lib.TPUStrategy(
+        resolver, device_assignment=device_assignment)
+    first_core_strategy._enable_packed_variable_in_eager_mode = (
+        enable_packed_var)
+
+    # Strategy for the 2nd core.
+    device_assignment2 = device_assignment_lib.DeviceAssignment(
+        topology, [[[0, 0, 0, 1]]])
+    second_core_strategy = tpu_lib.TPUStrategy(
+        resolver, device_assignment=device_assignment2)
+    second_core_strategy._enable_packed_variable_in_eager_mode = (
+        enable_packed_var)
+
+    self.assertLen(first_core_strategy.extended.worker_devices, 1)
+    self.assertEndsWith(first_core_strategy.extended.worker_devices[0],
+                        "device:TPU:0")
+
+    self.assertLen(second_core_strategy.extended.worker_devices, 1)
+    self.assertEndsWith(second_core_strategy.extended.worker_devices[0],
+                        "device:TPU:1")
+
+  def test_control_output_in_while_body_fn(self, enable_packed_var):
+    strategy = get_tpu_strategy(enable_packed_var)
    with strategy.scope():
      v = variables.Variable(
@@ -307,8 +327,8 @@ class TPUStrategyTest(test.TestCase):
    train_step()
    self.assertEqual(2.0, v.numpy())

-  def test_cluster_in_graph_and_while_body_fn(self):
-    strategy = get_tpu_strategy()
+  def test_cluster_in_graph_and_while_body_fn(self, enable_packed_var):
+    strategy = get_tpu_strategy(enable_packed_var)

    @def_function.function
    def train_step():
@@ -328,8 +348,8 @@ class TPUStrategyTest(test.TestCase):
    sum_val = train_step().numpy().astype(float)
    self.assertEqual(sum_val, strategy.num_replicas_in_sync * 10)

-  def test_two_clusters_with_same_fn(self):
-    strategy = get_tpu_strategy()
+  def test_two_clusters_with_same_fn(self, enable_packed_var):
+    strategy = get_tpu_strategy(enable_packed_var)

    @def_function.function
    def foo(x):
@@ -342,8 +362,8 @@ class TPUStrategyTest(test.TestCase):
    bar(1)

-  def test_using_external_variable_inside_tf_function(self):
-    strategy = get_tpu_strategy()
+  def test_using_external_variable_inside_tf_function(self, enable_packed_var):
+    strategy = get_tpu_strategy(enable_packed_var)
    dataset = dataset_ops.Dataset.range(
        strategy.num_replicas_in_sync * 2,
        output_type=dtypes.float32).batch(strategy.num_replicas_in_sync)
@@ -364,8 +384,8 @@ class TPUStrategyTest(test.TestCase):
        strategy.experimental_local_results(train_step(next(input_iterator))))

  # TODO(b/145574622): Remove this test once it is re-enabled in values_test.py.
-  def test_all_reduce_on_sync_on_read_variable(self):
-    strategy = get_tpu_strategy()
+  def test_all_reduce_on_sync_on_read_variable(self, enable_packed_var):
+    strategy = get_tpu_strategy(enable_packed_var)
    dataset = dataset_ops.Dataset.range(
        strategy.num_replicas_in_sync, output_type=dtypes.float32).batch(
            strategy.num_replicas_in_sync, drop_remainder=True)
@@ -404,8 +424,8 @@ class TPUStrategyTest(test.TestCase):
    self.assertAllEqual((0.,), w.read_value())

  # TODO(b/140633529): Re-enable the test.
-  def disable_test_experimental_run_output_on_device(self):
-    strategy = get_tpu_strategy()
+  def disable_test_experimental_run_output_on_device(self, enable_packed_var):
+    strategy = get_tpu_strategy(enable_packed_var)

    def computation(x):
      return math_ops.square(x)
@@ -423,8 +443,8 @@ class TPUStrategyTest(test.TestCase):
    self.assertAllEqual("/job:localhost/replica:0/task:0/device:TPU:1",
                        results[1].backing_device)

-  def test_composite_input(self):
-    strategy = get_tpu_strategy()
+  def test_composite_input(self, enable_packed_var):
+    strategy = get_tpu_strategy(enable_packed_var)
    if strategy.num_replicas_in_sync != 2:
      self.skipTest("Test assumes two replicas.")
@@ -463,8 +483,9 @@ class TPUStrategyTest(test.TestCase):
    self.assertAllEqual(result,
                        [[[0.0, 1.0], [3.0, 8.0]], [[0.0, 1.0], [3.0, 8.0]]])

-  def test_composite_input_dynamic_shapes_outside_compilation(self):
-    strategy = get_tpu_strategy()
+  def test_composite_input_dynamic_shapes_outside_compilation(
+      self, enable_packed_var):
+    strategy = get_tpu_strategy(enable_packed_var)
    if strategy.num_replicas_in_sync != 2:
      self.skipTest("Test assumes two replicas.")
@@ -506,11 +527,11 @@ class TPUStrategyTest(test.TestCase):
    result = sparse_lookup(dataset)
    self.assertAllEqual(result, [[0.0, 2.0], [1.5, 5.0]])

-  def test_per_device_tracing_of_mirrored_variables(self):
+  def test_per_device_tracing_of_mirrored_variables(self, enable_packed_var):
    # Define trace_count as a list to avoid python scoping error
    trace_count = [0]

-    strategy = get_tpu_strategy()
+    strategy = get_tpu_strategy(enable_packed_var)
    with strategy.scope():
      variable = variables.Variable(0.0)
@@ -527,7 +548,10 @@ class TPUStrategyTest(test.TestCase):
    with strategy.scope():
      update_variable.get_concrete_function()
-    self.assertEqual(trace_count[0], len(strategy.extended.worker_devices))
+    self.assertLen(strategy.extended.worker_devices, trace_count[0])
+
+
+class TPUStrategyDataPrefetchTest(test.TestCase):

  def test_prefetch_to_device_default(self):
    strategy = get_tpu_strategy()


@@ -24,6 +24,7 @@ from __future__ import print_function

 import contextlib

+from tensorflow.python.distribute import packed_distributed_variable as packed
 from tensorflow.python.distribute import values
 from tensorflow.python.eager import context
 from tensorflow.python.eager import tape
@@ -46,15 +47,27 @@ def _maybe_enter_graph(tensor):
    yield


+@contextlib.contextmanager
+def _maybe_on_device(var):
+  # Add a device scope for packed variables.
+  if isinstance(var, packed.PackedVarAndDevice):
+    with ops.device(var.device):
+      yield
+  else:
+    yield
+
+
 def _make_raw_assign_fn(raw_assign_fn):  # pylint: disable=missing-docstring

  def assign_fn(var, value, use_locking=False, name=None, read_value=True):  # pylint: disable=missing-docstring
    del use_locking  # Unused.

-    with _maybe_enter_graph(var.handle):
+    handle = var.handle
+    with _maybe_enter_graph(handle), _maybe_on_device(var):
      op = raw_assign_fn(
-          var.handle, ops.convert_to_tensor(value, dtype=var.dtype), name=name)
+          handle,
+          ops.convert_to_tensor(value, dtype=var.dtype),
+          name=name)

      with ops.control_dependencies([op]):
        return var._read_variable_op() if read_value else op  # pylint: disable=protected-access
@@ -97,23 +110,37 @@ class TPUVariableMixin(object):

  @property
  def handle(self):
+    """The handle by which this variable can be accessed."""
    # If we're in a tpu.rewrite(), return the replicated handle.
    tpu_context = enclosing_tpu_context()
    if tpu_context is None or context.executing_eagerly():
      return self._get_on_device_or_primary().handle
    else:
-      return tpu_context.get_replicated_var_handle(self._handle_id,
-                                                   self._values,
-                                                   self._is_mirrored())
+      is_packed = self._packed_var is not None
+      val = self._values
+      if is_packed:
+        val = [self._packed_var]
+
+      return tpu_context.get_replicated_var_handle(self._handle_id, val,
+                                                   self._is_mirrored(),
+                                                   is_packed)

  @property
  def device(self):
    return self.handle.device

  def _read_variable_op(self):
+    """Reads the value of this variable."""
    if self.trainable:
      tape.variable_accessed(self)
-    return gen_resource_variable_ops.read_variable_op(self.handle, self.dtype)
+
+    handle = self.handle
+    if getattr(handle, "is_packed", False):
+      # Add a device scope for a packed variable handle.
+      with ops.device(self._get_on_device_or_primary().device):
+        return gen_resource_variable_ops.read_variable_op(handle, self.dtype)
+    else:
+      return gen_resource_variable_ops.read_variable_op(handle, self.dtype)

  def read_value(self):
    if enclosing_tpu_context() is None:

@@ -472,6 +472,12 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable,
    # variable.
    self._var_policy = var_policy

+  @property
+  def _devices(self):
+    if self._packed_var is not None:
+      return tuple(d for d in self._packed_var.devices)
+    return tuple(v.device for v in self._values)
+
  def is_initialized(self, name=None):
    """Identifies if all the component variables are initialized.
@@ -482,6 +488,8 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable,
      The op that evaluates to True or False depending on if all the
      component variables are initialized.
    """
+    if self._packed_var is not None:
+      return self._packed_var.is_initialized()
    result = self._primary.is_initialized()
    # We iterate through the list of values except the last one to allow us to
    # name the final `logical_and` op the same name that is passed by the user
@@ -552,6 +560,10 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable,
  def aggregation(self):
    return self._aggregation

+  @property
+  def _packed_variable(self):
+    return self._packed_var
+
  @property
  def handle(self):
    replica_id = values_util.get_current_replica_id_as_int()
@@ -559,6 +571,8 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable,
      raise ValueError("`handle` is not available outside the replica context"
                       " or a `tf.distribute.Strategy.update()` call.")
    else:
+      if self._packed_var is not None:
+        return self._packed_var.handle
      return self._values[replica_id].handle

  def eval(self, session=None):
@@ -607,6 +621,33 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable,
  def _in_graph_mode(self):
    return self._primary._in_graph_mode  # pylint: disable=protected-access

+  def _get_replica(self, replica_id):
+    """Returns the value on a device with the given replica_id."""
+    if self._packed_var is not None:
+      return self._packed_var.on_device(self._devices[replica_id])
+    return self._values[replica_id]
+
+  def _get(self):
+    """Returns the value for the current device or raises a ValueError."""
+    replica_id = values_util.get_current_replica_id_as_int()
+    if replica_id is None:
+      return self._get_cross_replica()
+    else:
+      return self._get_replica(replica_id)
+
+  def _get_on_device_or_primary(self):
+    """Returns value in same replica or device if possible, else the _primary."""
+    replica_id = values_util.get_current_replica_id_as_int()
+    if replica_id is None:
+      # Try to find a value on the current device.
+      current_device = device_util.canonicalize(device_util.current())
+      for i, value in enumerate(self._values):
+        if device_util.canonicalize(value.device) == current_device:
+          return self._get_replica(i)
+      return self._get_replica(0)
+    else:
+      return self._get_replica(replica_id)
+
  def read_value(self):
    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
      return array_ops.identity(self._get())
@@ -778,7 +819,8 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable,
    if ds_context.in_cross_replica_context():
      update_replica_id = distribute_lib.get_update_replica_id()
      if update_replica_id is not None:
-        return update_fn(self._values[update_replica_id], value, **kwargs)
+        replica_value = self._get_replica(update_replica_id)
+        return update_fn(replica_value, value, **kwargs)
      return self._update_cross_replica(update_fn, value, **kwargs)
    else:
      values_util.assert_replica_context(self.distribute_strategy)
@@ -802,6 +844,7 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable,
      obj_map[v] = new_obj
      resource_map[v.handle] = new_obj.handle
    obj_map[self] = new_obj
+    resource_map[self.handle] = new_obj.handle
    resource_map[self] = new_obj.handle
    return obj_map, resource_map
@@ -835,6 +878,12 @@ class _MirroredSaveable(saveable_object_util.ResourceVariableSaveable):
  def restore(self, restored_tensors, restored_shapes):
    """Restore the same value into all variables."""
    tensor, = restored_tensors
+    packed_var = self._mirrored_variable._packed_variable  # pylint: disable=protected-access
+    if packed_var is not None:
+      return control_flow_ops.group(
+          tuple(
+              values_util.assign_on_device(d, packed_var, tensor)
+              for d in packed_var.devices))
    return control_flow_ops.group(
        tuple(
            values_util.assign_on_device(v.device, v, tensor)
@@ -1013,7 +1062,7 @@ class SyncOnReadVariable(DistributedVariable):

  def _get_cross_replica(self):
    if self._aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA:
-      return self._primary
+      return self._get_replica(0)

    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
      return self._distribute_strategy.reduce(

@@ -42,7 +42,6 @@ from tensorflow.python.eager import context
 from tensorflow.python.eager import def_function
 from tensorflow.python.eager import test
 from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import device
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import indexed_slices
 from tensorflow.python.framework import ops
@@ -234,11 +233,11 @@ class DistributedValuesTest(test.TestCase, parameterized.TestCase):
          distribution=[
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations.tpu_strategy,
+              strategy_combinations.tpu_strategy_packed_var,
              # TODO(b/137795644): support CentralStroageStrategy
              # strategy_combinations.central_storage_strategy_with_two_gpus,
          ],
-          mode=["eager"]
-      ))
+          mode=["eager"]))
  def testMakeDistributedValueDefaultDevicePlacement(self, distribution):
    if not tf2.enabled():
      self.skipTest("Only V2 is supported.")
@@ -259,11 +258,11 @@ class DistributedValuesTest(test.TestCase, parameterized.TestCase):
          distribution=[
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations.tpu_strategy,
+              strategy_combinations.tpu_strategy_packed_var,
              # TODO(b/137795644): support CentralStroageStrategy
              # strategy_combinations.central_storage_strategy_with_two_gpus,
          ],
-          mode=["eager"]
-      ))
+          mode=["eager"]))
  def testMakeDistributedValueExplicitDevicePlacement(self, distribution):
    if not tf2.enabled():
      self.skipTest("Only V2 is supported.")
@@ -384,6 +383,16 @@ def _make_mirrored():
  return mirrored


+def mirrored_and_tpu_strategy_combinations():
+  return combinations.combine(
+      distribution=[
+          strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
+          strategy_combinations.tpu_strategy,
+          strategy_combinations.tpu_strategy_packed_var,
+      ],
+      mode=["graph", "eager"])
+
+
 class RegroupAndSelectDeviceTest(test.TestCase, parameterized.TestCase):

  def _is_per_replica(self, result, expected, klass=values.PerReplica):
@@ -563,6 +572,7 @@ class RegroupAndSelectDeviceTest(test.TestCase, parameterized.TestCase):
              strategy_combinations.mirrored_strategy_with_one_cpu,
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations.tpu_strategy,
+              strategy_combinations.tpu_strategy_packed_var,
              strategy_combinations.central_storage_strategy_with_two_gpus,
          ],
          synchronization=[
@@ -708,29 +718,40 @@ class DistributedVariableTest(test.TestCase, parameterized.TestCase):
    self.evaluate(
        distribution.experimental_local_results(distribution.run(assign)))

-  def testPackedVariable(self, distribution, synchronization, aggregation):
-    with distribution.scope():
-      v0 = variables_lib.Variable(
-          0., synchronization=synchronization, aggregation=aggregation)
-    if not isinstance(v0, values.DistributedVariable):
-      self.skipTest("This test doesn't apply to non DistributedVariables")
-
-    self.assertEqual(v0._packed_var, None)
-
-    device_type = device.DeviceSpec.from_string(v0._devices[0]).device_type
-    for d in v0._devices:
-      if device.DeviceSpec.from_string(d).device_type != device_type:
-        self.skipTest("Packing variables on devices of different types "
-                      "is not supported yet.")
-
-    distribution._enable_packed_variable_in_eager_mode = True
-    with distribution.scope():
-      v1 = variables_lib.Variable(
-          0., synchronization=synchronization, aggregation=aggregation)
-
-    if ops.executing_eagerly_outside_functions():
-      self.assertIsInstance(v1._packed_var, packed.PackedDistributedVariable)
-    else:
-      self.assertEqual(v1._packed_var, None)
+
+@combinations.generate(
+    combinations.combine(
+        distribution=[
+            strategy_combinations.mirrored_strategy_with_one_cpu,
+            strategy_combinations.tpu_strategy,
+        ],
+        mode=["eager"]))
+class PackedDistributedVariableTest(test.TestCase, parameterized.TestCase):
+
+  def testPackedVariable(self, distribution):
+    with distribution.scope():
+      v0 = variables_lib.Variable(0.)
+    self.assertIsNone(v0._packed_var)
+
+    distribution._enable_packed_variable_in_eager_mode = True
+    with distribution.scope():
+      v1 = variables_lib.Variable(0)
+    self.assertIsInstance(v1._packed_var, packed.PackedDistributedVariable)
+
+    devices = v1._devices
+    for i in range(1, len(devices)):
+      with distribute_lib.ReplicaContext(distribution, i):
+        v1.assign(i)
+    val = v1._get()
+    self.assertIsInstance(val, packed.PackedVarAndDevice)
+    self.assertEqual(val.device, devices[0])
+    self.assertEqual(self.evaluate(val.read_value()), 0)
+    for i in range(0, len(devices)):
+      with distribute_lib.ReplicaContext(distribution, i):
+        val = v1._get()
+        self.assertIsInstance(val, packed.PackedVarAndDevice)
+        self.assertEqual(val.device, devices[i])
+        self.assertEqual(self.evaluate(val.read_value()), i)


 class MirroredVariableTest(test.TestCase, parameterized.TestCase):
@@ -920,6 +941,7 @@ class MirroredVariableTest(test.TestCase, parameterized.TestCase):
          distribution=[
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations.tpu_strategy,
+              strategy_combinations.tpu_strategy_packed_var,
          ],
          mode=["eager"]))
  def testAssignValueInReplicaContextWithoutAggregation(self, distribution):
@@ -943,6 +965,7 @@ class MirroredVariableTest(test.TestCase, parameterized.TestCase):
              strategy_combinations.mirrored_strategy_with_one_cpu,
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations.tpu_strategy,
+              strategy_combinations.tpu_strategy_packed_var,
          ],
          mode=["graph", "eager"]))
  def testValueInReplicaContext(self, distribution):
@@ -968,6 +991,7 @@ class MirroredVariableTest(test.TestCase, parameterized.TestCase):
              strategy_combinations.mirrored_strategy_with_one_cpu,
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations.tpu_strategy,
+              strategy_combinations.tpu_strategy_packed_var,
          ],
          mode=["graph", "eager"]))
  def testAssignOutOfScope(self, distribution):
@@ -1041,6 +1065,7 @@ class MirroredVariableTest(test.TestCase, parameterized.TestCase):
          distribution=[
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations.tpu_strategy,
+              strategy_combinations.tpu_strategy_packed_var,
          ],
          mode=["eager"]))
  def testInitializedToSameValueInsideEagerRun(self, distribution):
@@ -1066,6 +1091,7 @@ class MirroredVariableTest(test.TestCase, parameterized.TestCase):
              strategy_combinations.mirrored_strategy_with_one_cpu,
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations.tpu_strategy,
+              strategy_combinations.tpu_strategy_packed_var,
          ],
          mode=["graph", "eager"]))
  def testAggregationOnlyFirstReplica(self, distribution):
@@ -1093,6 +1119,7 @@ class MirroredVariableTest(test.TestCase, parameterized.TestCase):
          distribution=[
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations.tpu_strategy,
+              strategy_combinations.tpu_strategy_packed_var,
          ],
          mode=["eager"]))
  def testInitScope(self, distribution):
@@ -1143,13 +1170,7 @@ class MirroredVariableTest(test.TestCase, parameterized.TestCase):
        distribution.experimental_local_results(distribution.run(add)))
    self.assertAllEqual([2, 2], per_replica_results)

-  @combinations.generate(
-      combinations.combine(
-          distribution=[
-              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
-              strategy_combinations.tpu_strategy,
-          ],
-          mode=["graph", "eager"]))
+  @combinations.generate(mirrored_and_tpu_strategy_combinations())
  def testAssignAdd(self, distribution):
    with distribution.scope():
      v = variable_scope.variable(
@@ -1456,15 +1477,6 @@ class SyncOnReadVariablePropertiesTest(test.TestCase):
      self.assertEqual(2., self.evaluate(add1(replica_local)))


-def mirrored_and_tpu_strategy_combinations():
-  return combinations.combine(
-      distribution=[
-          strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
-          strategy_combinations.tpu_strategy,
-      ],
-      mode=["graph", "eager"])
-
-
 # TODO(b/144432582): Add variable aggregation type to combinations to simplify
 # tests.
 def strategy_and_run_tf_function_combinations():
@@ -1478,6 +1490,7 @@ def strategy_and_run_tf_function_combinations():
      experimental_run_tf_function=[True, False]) + combinations.combine(
          distribution=[
              strategy_combinations.tpu_strategy,
+              strategy_combinations.tpu_strategy_packed_var,
          ],
          mode=["graph", "eager"],
          experimental_run_tf_function=[True])


@@ -61,6 +61,12 @@ def on_write_assign_sub(var, value, use_locking=False, name=None,

 def assign_on_each_device(var, assign_func, value, read_value):
-  update = control_flow_ops.group(
-      tuple(assign_func(v.device, v, value) for v in var._values))  # pylint: disable=protected-access
+  """Update the variable on each replica with the given assign_func and value."""
+  if var._packed_variable is not None:  # pylint: disable=protected-access
+    update = control_flow_ops.group(
+        tuple(
+            assign_func(d, var._packed_variable, value) for d in var._devices))  # pylint: disable=protected-access
+  else:
+    update = control_flow_ops.group(
+        tuple(assign_func(v.device, v, value) for v in var._values))  # pylint: disable=protected-access
  if not read_value:
@@ -104,7 +110,7 @@ def on_read_assign_cross_replica(var, value, read_value=True):
      # TODO(anjs): Should this be over all the replicas in sync since we
      # call `reduce` on the variable during read?
      if var.aggregation == vs.VariableAggregation.SUM:
-        tensor = math_ops.cast(tensor / len(var._values), var.dtype)  # pylint: disable=protected-access
+        tensor = math_ops.cast(tensor / len(var._devices), var.dtype)  # pylint: disable=protected-access
      return assign_on_each_device(var, assign_on_device, tensor,
                                   read_value)


@@ -298,7 +298,8 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
    self._pivot = pivot
    self._replicated_vars = {}

-  def get_replicated_var_handle(self, name, vars_, is_mirrored=False):
+  def get_replicated_var_handle(self, name, vars_, is_mirrored=False,
+                                is_packed=False):
    """Returns a variable handle for replicated TPU variable 'var'.

    This is a method used by an experimental replicated variable implementation
@@ -309,6 +310,7 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
      vars_: The replicated TPU variables.
      is_mirrored: Whether the variables are mirrored, which guarantees the
        values in each replica are always the same.
+      is_packed: Whether the replicated variables are packed into one variable.

    Returns:
      The handle of the TPU replicated input node.
@@ -320,7 +322,7 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
    if handle is not None:
      return handle

-    if device_assignment is not None:
+    if device_assignment is not None and not is_packed:
      # Find a variable copy for each replica in the device assignment.
      # Note that the order of devices for replicas for the variable and the
      # device assignment might not match.
@@ -356,7 +358,8 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
      graph._set_control_flow_context(self.outer_context)
      handle = tpu_ops.tpu_replicated_input([v.handle for v in replicated_vars],
                                            name=name + "/handle",
-                                            is_mirrored_variable=is_mirrored)
+                                            is_mirrored_variable=is_mirrored,
+                                            is_packed=is_packed)
      graph._set_control_flow_context(saved_context)
      # pylint: enable=protected-access
    self._replicated_vars[name] = handle