Support packed variable in DistributedVariable. Add an option to enable packed variable in TPUStrategy.

PiperOrigin-RevId: 317234665
Change-Id: I09e806cb8261815cd87a6d98817556dd8f7e8ed7
This commit is contained in:
Yujing Zhang 2020-06-18 20:03:31 -07:00 committed by TensorFlower Gardener
parent 4d54ef3139
commit 7e6e549c46
15 changed files with 454 additions and 270 deletions

View File

@ -654,6 +654,7 @@ tpu_py_test(
"//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib",
"//tensorflow/python/eager:remote",
"//tensorflow/python/eager:test",
"@absl_py//absl/testing:parameterized",
],
)
@ -787,6 +788,7 @@ py_library(
name = "tpu_values",
srcs = ["tpu_values.py"],
deps = [
":packed_distributed_variable",
":values",
"//tensorflow/python:framework_ops",
"//tensorflow/python:math_ops",
@ -1602,7 +1604,7 @@ distribute_py_test(
srcs = ["saved_model_save_load_test.py"],
full_precision = True,
main = "saved_model_save_load_test.py",
shard_count = 5,
shard_count = 7,
tags = [
"multi_and_single_gpu",
"no_rocm",
@ -1635,7 +1637,7 @@ distribute_py_test(
srcs = ["saved_model_mixed_api_test.py"],
full_precision = True,
main = "saved_model_mixed_api_test.py",
shard_count = 5,
shard_count = 7,
tags = [
"multi_and_single_gpu",
"no_rocm",

View File

@ -103,6 +103,7 @@ class TrainingCheckpointTests(test.TestCase, parameterized.TestCase):
strategy_combinations.mirrored_strategy_with_one_cpu,
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
strategy_combinations.tpu_strategy,
strategy_combinations.tpu_strategy_packed_var,
strategy_combinations.central_storage_strategy_with_two_gpus,
],
mode=["eager"]))
@ -138,6 +139,7 @@ class TrainingCheckpointTests(test.TestCase, parameterized.TestCase):
strategy_combinations.mirrored_strategy_with_one_cpu,
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
strategy_combinations.tpu_strategy,
strategy_combinations.tpu_strategy_packed_var,
strategy_combinations.central_storage_strategy_with_two_gpus,
],
mode=["eager"]))

View File

@ -197,7 +197,8 @@ class InputIterationTest(test.TestCase, parameterized.TestCase,
combinations.combine(
distribution=[
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
strategy_combinations.tpu_strategy
strategy_combinations.tpu_strategy,
strategy_combinations.tpu_strategy_packed_var,
],
mode=["eager"]))
def testNestedOutput(self, distribution):
@ -748,6 +749,10 @@ class InputIterationTest(test.TestCase, parameterized.TestCase,
mode=["eager"]
))
def testMultiDeviceDataCapturedFunction(self, distribution):
if getattr(distribution, "_enable_packed_variable_in_eager_mode", False):
self.skipTest(
"Dataset captured function doesn't support packed tensors yet "
"(b/145922293).")
inputs = constant_op.constant([2., 3.])
dataset = lambda _: dataset_ops.Dataset.from_tensor_slices(inputs).repeat(5)
input_iterator = iter(

View File

@ -148,6 +148,9 @@ def select_replica_mirrored(replica_id, structured):
raise TypeError(
"Expected value to be mirrored across replicas: %s in %s." %
(x, structured))
packed_var = getattr(x, "_packed_variable", None)
if packed_var is not None:
return packed_var
return x.values[replica_id]
else:
return x

View File

@ -42,7 +42,7 @@ class PackedDistributedVariable(resource_variable_ops.BaseResourceVariable):
name: Optional name for the variable. Defaults to `'Variable'` and gets
uniquified automatically.
"""
if not context.executing_eagerly():
if not ops.executing_eagerly_outside_functions():
raise ValueError(
"PackedDistributedVariable should be created in eager mode.")
if not distributed_variables:
@ -84,6 +84,9 @@ class PackedDistributedVariable(resource_variable_ops.BaseResourceVariable):
def devices(self):
return self._devices
def on_device(self, device):
return PackedVarAndDevice(self, device)
def get_var_on_device(self, device):
for i, d in enumerate(self._devices):
if d == device:
@ -100,7 +103,10 @@ class PackedDistributedVariable(resource_variable_ops.BaseResourceVariable):
@property
def handle(self):
return self._handle
if context.executing_eagerly():
return self.get_var_on_current_device().handle
else:
return self._handle
def _read_variable_op(self):
if context.executing_eagerly():
@ -269,7 +275,8 @@ class PackedVarAndDevice(object):
@property
def handle(self):
return self._var.handle
with ops.device(self._device):
return self._var.handle
@property
def op(self):

View File

@ -46,7 +46,7 @@ class PackedDistributedVariableTest(test.TestCase):
v1 = resource_variable_ops.ResourceVariable(2.0, name='var1')
packed_var = packed_distributed_variable.PackedDistributedVariable([v0, v1])
self.assertTrue(packed_var.handle.is_packed)
self.assertFalse(packed_var.handle.is_packed)
self.assertTrue(packed_var.is_initialized)
with ops.device('/cpu:0'):
@ -61,6 +61,7 @@ class PackedDistributedVariableTest(test.TestCase):
@def_function.function
def update_var():
self.assertTrue(packed_var.handle.is_packed)
with ops.device('/cpu:0'):
packed_var.assign_add(3.0).assign_sub(1.0)
read0 = packed_var.value()
@ -85,7 +86,7 @@ class PackedDistributedVariableTest(test.TestCase):
packed_var0 = packed_distributed_variable.PackedVarAndDevice(
packed_var, device0)
self.assertTrue(packed_var0.handle.is_packed)
self.assertFalse(packed_var0.handle.is_packed)
self.assertAllEqual(math_ops.mul(packed_var0, 2.0), 2.0)
packed_var1 = packed_distributed_variable.PackedVarAndDevice(
@ -94,6 +95,7 @@ class PackedDistributedVariableTest(test.TestCase):
@def_function.function
def func():
self.assertTrue(packed_var.handle.is_packed)
var0 = packed_distributed_variable.PackedVarAndDevice(packed_var, device0)
var0.assign_add(3.0)
var1 = packed_distributed_variable.PackedVarAndDevice(packed_var, device1)

View File

@ -58,6 +58,7 @@ strategies = [
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
strategy_combinations.mirrored_strategy_with_two_gpus,
strategy_combinations.tpu_strategy,
strategy_combinations.tpu_strategy_packed_var,
strategy_combinations.central_storage_strategy_with_two_gpus,
]

View File

@ -53,7 +53,11 @@ _did_connect_to_cluster = False
# pylint: disable=missing-docstring
def _get_tpu_strategy_creator(steps_per_run, use_single_core=False, **kwargs):
def _get_tpu_strategy_creator(steps_per_run,
use_single_core=False,
enable_packed_variable=False,
**kwargs):
def _create_tpu_strategy():
global _did_connect_to_cluster
@ -87,10 +91,13 @@ def _get_tpu_strategy_creator(steps_per_run, use_single_core=False, **kwargs):
# Steps per run is only supported in TF 1.x
if tf2.enabled():
return tpu_lib.TPUStrategy(resolver, device_assignment, **kwargs)
strategy = tpu_lib.TPUStrategy(resolver, device_assignment, **kwargs)
else:
return tpu_lib.TPUStrategyV1(resolver, steps_per_run,
device_assignment, **kwargs)
strategy = tpu_lib.TPUStrategyV1(resolver, steps_per_run,
device_assignment, **kwargs)
strategy._enable_packed_variable_in_eager_mode = enable_packed_variable # pylint: disable=protected-access
return strategy
return _create_tpu_strategy
@ -117,6 +124,10 @@ one_device_strategy_gpu_on_worker_1 = combinations.NamedDistribution(
required_gpus=1)
tpu_strategy = combinations.NamedDistribution(
"TPU", _get_tpu_strategy_creator(steps_per_run=2), required_tpu=True)
tpu_strategy_packed_var = combinations.NamedDistribution(
"TPUPackedVar",
_get_tpu_strategy_creator(steps_per_run=2, enable_packed_variable=True),
required_tpu=True)
tpu_strategy_one_step = combinations.NamedDistribution(
"TPUOneStep", _get_tpu_strategy_creator(steps_per_run=1), required_tpu=True)
tpu_strategy_one_core = combinations.NamedDistribution(
@ -286,6 +297,7 @@ strategies_minus_default_and_tpu = [
tpu_strategies = [
tpu_strategy, # steps_per_run=2
tpu_strategy_one_step,
tpu_strategy_packed_var,
cloud_tpu_strategy,
]

View File

@ -141,6 +141,10 @@ class TPUStrategy(distribute_lib.Strategy):
"num_workers").set(self.extended.num_hosts)
distribute_lib.distribution_strategy_replica_gauge.get_cell(
"num_replicas_per_worker").set(self.extended.num_replicas_per_host)
# Packed variable is used to reduce the overhead of function execution.
# For a DistributedVariable, only one variable handle is captured into a
# function graph. It's only supported in eager mode.
self._enable_packed_variable_in_eager_mode = False
# TODO(cjfj): Modify `_call_for_each_replica` in `TPUExtended` such that this
# can use the default implementation.
@ -185,6 +189,10 @@ class TPUStrategyV1(distribute_lib.StrategyV1):
"num_workers").set(self.extended.num_hosts)
distribute_lib.distribution_strategy_replica_gauge.get_cell(
"num_replicas_per_worker").set(self.extended.num_replicas_per_host)
# Packed variable is used to reduce the overhead of function execution.
# For a DistributedVariable, only one variable handle is captured into a
# function graph. It's only supported in eager mode.
self._enable_packed_variable_in_eager_mode = False
@property
def steps_per_run(self):
@ -671,20 +679,29 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
return cross_device_ops_lib.reduce_non_distributed_value(
reduce_op, value, destinations, self._num_replicas_in_sync)
value_list = value.values
# pylint: disable=protected-access
if isinstance(
value,
values.DistributedVariable) and value._packed_variable is not None:
value_list = tuple(
value._packed_variable.on_device(d)
for d in value._packed_variable.devices)
# pylint: enable=protected-access
# Currently XLA op by op mode has a limit for the number of inputs for a
# single op, thus we break one `add_n` op into a group of `add_n` ops to
# work around the constraint.
# TODO(cjfj): Detect when it is possible to use `cross_replica_sum`.
if len(value.values) <= _XLA_OP_BY_OP_INPUTS_LIMIT:
output = math_ops.add_n(value.values)
output = math_ops.add_n(value_list)
else:
output = array_ops.zeros_like(
value.values[0], dtype=value.values[0].dtype)
for i in range(0, len(value.values), _XLA_OP_BY_OP_INPUTS_LIMIT):
output += math_ops.add_n(value.values[i:i + _XLA_OP_BY_OP_INPUTS_LIMIT])
output = array_ops.zeros_like(value_list[0], dtype=value_list[0].dtype)
for i in range(0, len(value_list), _XLA_OP_BY_OP_INPUTS_LIMIT):
output += math_ops.add_n(value_list[i:i + _XLA_OP_BY_OP_INPUTS_LIMIT])
if reduce_op == reduce_util.ReduceOp.MEAN:
output *= (1. / len(value.values))
output *= (1. / len(value_list))
devices = cross_device_ops_lib.get_devices_from(destinations)
@ -710,17 +727,28 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
else:
return (fn(var, *args, **kwargs),)
# Otherwise, we revert to MirroredStrategy behavior and update each variable
# directly.
# Otherwise, we revert to MirroredStrategy behavior and update the variable
# on each replica directly.
updates = []
for i, v in enumerate(var.values):
values_and_devices = []
packed_var = var._packed_variable # pylint: disable=protected-access
if packed_var is not None:
for device in packed_var.devices:
values_and_devices.append((packed_var, device))
else:
for value in var.values:
values_and_devices.append((value, value.device))
for i, value_and_device in enumerate(values_and_devices):
value = value_and_device[0]
device = value_and_device[1]
name = "update_%d" % i
with ops.device(v.device), \
with ops.device(device), \
distribute_lib.UpdateContext(i), \
ops.name_scope(name):
# If args and kwargs are not mirrored, the value is returned as is.
updates.append(
fn(v, *distribute_utils.select_replica_mirrored(i, args),
fn(value, *distribute_utils.select_replica_mirrored(i, args),
**distribute_utils.select_replica_mirrored(i, kwargs)))
return distribute_utils.update_regroup(self, updates, group)

View File

@ -18,6 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribution_strategy_context
@ -64,14 +66,17 @@ def get_tpu_cluster_resolver():
return resolver
def get_tpu_strategy():
def get_tpu_strategy(enable_packed_var=False):
resolver = get_tpu_cluster_resolver()
remote.connect_to_cluster(resolver)
tpu_strategy_util.initialize_tpu_system(resolver)
return tpu_lib.TPUStrategy(resolver)
strategy = tpu_lib.TPUStrategy(resolver)
strategy._enable_packed_variable_in_eager_mode = enable_packed_var
return strategy
class TPUStrategyTest(test.TestCase):
# TPU tests which don't use TPUStrategy.
class TPUTest(test.TestCase):
def test_multiple_initialize_system(self):
resolver = get_tpu_cluster_resolver()
@ -82,177 +87,6 @@ class TPUStrategyTest(test.TestCase):
tpu_strategy_util.initialize_tpu_system(resolver)
self.assertRegex(str(mock_log.call_args), "already been initialized")
def test_sequential_experimental_runs(self):
resolver = get_tpu_cluster_resolver()
remote.connect_to_cluster(resolver)
topology = tpu_strategy_util.initialize_tpu_system(resolver)
# Computation replicated to all cores.
device_assignment = device_assignment_lib.DeviceAssignment.build(
topology, num_replicas=2)
strategy = tpu_lib.TPUStrategy(
resolver, device_assignment=device_assignment)
# Computation on the 1st core.
device_assignment2 = device_assignment_lib.DeviceAssignment.build(
topology, num_replicas=1)
strategy2 = tpu_lib.TPUStrategy(
resolver, device_assignment=device_assignment2)
def computation(x):
return math_ops.square(x)
@def_function.function
def train_step():
outputs = strategy.experimental_local_results(
strategy.run(computation, args=([2., 2.],)))
outputs2 = strategy2.run(
computation, args=([outputs[0]],))
return outputs2
self.assertAllEqual([[16., 16.]], train_step())
def test_device_switch_case(self):
strategy = get_tpu_strategy()
with strategy.scope():
a = variables.Variable(1)
inference_iteration = variables.Variable(-1)
def inference_fn(x, i):
return a + x + i
@def_function.function
def run_inference(x):
def do_inference(device, inference_fn, i):
with ops.device(device):
return inference_fn(x, i)
branch_fns = {
0: (lambda: do_inference("/device:TPU:0", inference_fn, 0)),
1: (lambda: do_inference("/device:TPU:1", inference_fn, 1)),
}
branch_index = inference_iteration.assign_add(1, use_locking=True) % 2
return control_flow_ops.switch_case(branch_index, branch_fns)
self.assertAllEqual(2., run_inference(1)) # Use TPU core 0.
self.assertAllEqual(3., run_inference(1)) # Use TPU core 1.
def test_recover_from_compilation_failures(self):
# TODO(b/148150981): Stop skipping this test once recovery works
# for non-local TPU.
if FLAGS.tpu:
self.skipTest("Recovery fails for non-local TPU, see b/148150981")
# Disable automatic outside compilation.
config.set_soft_device_placement(False)
strategy = get_tpu_strategy()
@def_function.function
def compilation_failure_run():
def computation():
return random_ops.random_gamma([10], [0.5, 1.5])
return strategy.run(computation)
with self.assertRaisesRegexp(errors.InvalidArgumentError,
"TPU compilation failed"):
compilation_failure_run()
@def_function.function
def good_run():
def computation():
return random_ops.random_normal([10])
return strategy.run(computation)
good_run()
def test_dynamic_shape_with_outside_compilation_failure(self):
# Enable automatic outside compilation.
config.set_soft_device_placement(True)
strategy = get_tpu_strategy()
dataset = dataset_ops.Dataset.from_tensors(("string", 1.0)).repeat().batch(
2, drop_remainder=False)
dataset = strategy.experimental_distribute_dataset(dataset)
iterator = iter(dataset)
@def_function.function
def train_fn(iterator):
def step_fn(inputs):
_, inputs = inputs
return math_ops.reduce_sum(inputs)
return strategy.experimental_local_results(
strategy.run(step_fn, args=(next(iterator),)))
with self.assertRaisesRegex(errors.InternalError, "Compilation failure"):
logging.info(train_fn(iterator))
def test_computation_on_subset_cores(self):
resolver = get_tpu_cluster_resolver()
remote.connect_to_cluster(resolver)
topology = tpu_strategy_util.initialize_tpu_system(resolver)
all_core_strategy = tpu_lib.TPUStrategy(resolver)
with all_core_strategy.scope():
v = variables.Variable(0.0,
aggregation=variables.VariableAggregation.MEAN)
# Computation on the 1st core.
device_assignment = device_assignment_lib.DeviceAssignment.build(
topology, num_replicas=1)
first_core_strategy = tpu_lib.TPUStrategy(
resolver, device_assignment=device_assignment)
# Computation on the 2nd core.
device_assignment2 = device_assignment_lib.DeviceAssignment(
topology, [[[0, 0, 0, 1]]])
second_core_strategy = tpu_lib.TPUStrategy(
resolver, device_assignment=device_assignment2)
@def_function.function
def train_step():
def step_fn():
return v + 1.0
all_core_strategy.run(step_fn)
r1 = first_core_strategy.run(step_fn)
r2 = second_core_strategy.run(step_fn)
return r1 + r2
train_step()
self.assertAllEqual(2., train_step())
def test_worker_devices_on_subset_cores(self):
resolver = get_tpu_cluster_resolver()
remote.connect_to_cluster(resolver)
topology = tpu_strategy_util.initialize_tpu_system(resolver)
# Strategy for the 1st core.
device_assignment = device_assignment_lib.DeviceAssignment.build(
topology, num_replicas=1)
first_core_strategy = tpu_lib.TPUStrategy(
resolver, device_assignment=device_assignment)
# Strategy for the 2nd core.
device_assignment2 = device_assignment_lib.DeviceAssignment(
topology, [[[0, 0, 0, 1]]])
second_core_strategy = tpu_lib.TPUStrategy(
resolver, device_assignment=device_assignment2)
self.assertLen(first_core_strategy.extended.worker_devices, 1)
self.assertEndsWith(first_core_strategy.extended.worker_devices[0],
"device:TPU:0")
self.assertLen(second_core_strategy.extended.worker_devices, 1)
self.assertEndsWith(second_core_strategy.extended.worker_devices[0],
"device:TPU:1")
def test_tpu_tf_function_same_device(self):
with ops.device("/device:TPU:0"):
a = variables.Variable(1)
@ -288,8 +122,194 @@ class TPUStrategyTest(test.TestCase):
result = bar() + 1
self.assertAllEqual(result, 2)
def test_control_output_in_while_body_fn(self):
strategy = get_tpu_strategy()
@parameterized.named_parameters([("PackedVar", True), ("", False)])
class TPUStrategyTest(test.TestCase, parameterized.TestCase):
def test_sequential_experimental_runs(self, enable_packed_var):
resolver = get_tpu_cluster_resolver()
remote.connect_to_cluster(resolver)
topology = tpu_strategy_util.initialize_tpu_system(resolver)
# Computation replicated to all cores.
device_assignment = device_assignment_lib.DeviceAssignment.build(
topology, num_replicas=2)
strategy = tpu_lib.TPUStrategy(
resolver, device_assignment=device_assignment)
strategy._enable_packed_variable_in_eager_mode = enable_packed_var
# Computation on the 1st core.
device_assignment2 = device_assignment_lib.DeviceAssignment.build(
topology, num_replicas=1)
strategy2 = tpu_lib.TPUStrategy(
resolver, device_assignment=device_assignment2)
def computation(x):
return math_ops.square(x)
@def_function.function
def train_step():
outputs = strategy.experimental_local_results(
strategy.run(computation, args=([2., 2.],)))
outputs2 = strategy2.run(
computation, args=([outputs[0]],))
return outputs2
self.assertAllEqual([[16., 16.]], train_step())
def test_device_switch_case(self, enable_packed_var):
strategy = get_tpu_strategy(enable_packed_var)
with strategy.scope():
a = variables.Variable(1)
inference_iteration = variables.Variable(-1)
def inference_fn(x, i):
return a + x + i
@def_function.function
def run_inference(x):
def do_inference(device, inference_fn, i):
with ops.device(device):
return inference_fn(x, i)
branch_fns = {
0: (lambda: do_inference("/device:TPU:0", inference_fn, 0)),
1: (lambda: do_inference("/device:TPU:1", inference_fn, 1)),
}
branch_index = inference_iteration.assign_add(1, use_locking=True) % 2
return control_flow_ops.switch_case(branch_index, branch_fns)
self.assertAllEqual(2., run_inference(1)) # Use TPU core 0.
self.assertAllEqual(3., run_inference(1)) # Use TPU core 1.
def test_recover_from_compilation_failures(self, enable_packed_var):
# TODO(b/148150981): Stop skipping this test once recovery works
# for non-local TPU.
if FLAGS.tpu:
self.skipTest("Recovery fails for non-local TPU, see b/148150981")
# Disable automatic outside compilation.
config.set_soft_device_placement(False)
strategy = get_tpu_strategy(enable_packed_var)
@def_function.function
def compilation_failure_run():
def computation():
return random_ops.random_gamma([10], [0.5, 1.5])
return strategy.run(computation)
with self.assertRaisesRegex(errors.InvalidArgumentError,
"TPU compilation failed"):
compilation_failure_run()
@def_function.function
def good_run():
def computation():
return random_ops.random_normal([10])
return strategy.run(computation)
good_run()
def test_dynamic_shape_with_outside_compilation_failure(
self, enable_packed_var):
# Enable automatic outside compilation.
config.set_soft_device_placement(True)
strategy = get_tpu_strategy(enable_packed_var)
dataset = dataset_ops.Dataset.from_tensors(("string", 1.0)).repeat().batch(
2, drop_remainder=False)
dataset = strategy.experimental_distribute_dataset(dataset)
iterator = iter(dataset)
@def_function.function
def train_fn(iterator):
def step_fn(inputs):
_, inputs = inputs
return math_ops.reduce_sum(inputs)
return strategy.experimental_local_results(
strategy.run(step_fn, args=(next(iterator),)))
with self.assertRaisesRegex(errors.InternalError, "Compilation failure"):
logging.info(train_fn(iterator))
def test_computation_on_subset_cores(self, enable_packed_var):
resolver = get_tpu_cluster_resolver()
remote.connect_to_cluster(resolver)
topology = tpu_strategy_util.initialize_tpu_system(resolver)
all_core_strategy = tpu_lib.TPUStrategy(resolver)
all_core_strategy._enable_packed_variable_in_eager_mode = enable_packed_var
with all_core_strategy.scope():
v = variables.Variable(0.0,
aggregation=variables.VariableAggregation.MEAN)
# Computation on the 1st core.
device_assignment = device_assignment_lib.DeviceAssignment.build(
topology, num_replicas=1)
first_core_strategy = tpu_lib.TPUStrategy(
resolver, device_assignment=device_assignment)
first_core_strategy._enable_packed_variable_in_eager_mode = (
enable_packed_var)
# Computation on the 2nd core.
device_assignment2 = device_assignment_lib.DeviceAssignment(
topology, [[[0, 0, 0, 1]]])
second_core_strategy = tpu_lib.TPUStrategy(
resolver, device_assignment=device_assignment2)
second_core_strategy._enable_packed_variable_in_eager_mode = (
enable_packed_var)
@def_function.function
def train_step():
def step_fn():
return v + 1.0
all_core_strategy.run(step_fn)
r1 = first_core_strategy.run(step_fn)
r2 = second_core_strategy.run(step_fn)
return r1 + r2
train_step()
self.assertAllEqual(2., train_step())
def test_worker_devices_on_subset_cores(self, enable_packed_var):
resolver = get_tpu_cluster_resolver()
remote.connect_to_cluster(resolver)
topology = tpu_strategy_util.initialize_tpu_system(resolver)
# Strategy for the 1st core.
device_assignment = device_assignment_lib.DeviceAssignment.build(
topology, num_replicas=1)
first_core_strategy = tpu_lib.TPUStrategy(
resolver, device_assignment=device_assignment)
first_core_strategy._enable_packed_variable_in_eager_mode = (
enable_packed_var)
# Strategy for the 2nd core.
device_assignment2 = device_assignment_lib.DeviceAssignment(
topology, [[[0, 0, 0, 1]]])
second_core_strategy = tpu_lib.TPUStrategy(
resolver, device_assignment=device_assignment2)
second_core_strategy._enable_packed_variable_in_eager_mode = (
enable_packed_var)
self.assertLen(first_core_strategy.extended.worker_devices, 1)
self.assertEndsWith(first_core_strategy.extended.worker_devices[0],
"device:TPU:0")
self.assertLen(second_core_strategy.extended.worker_devices, 1)
self.assertEndsWith(second_core_strategy.extended.worker_devices[0],
"device:TPU:1")
def test_control_output_in_while_body_fn(self, enable_packed_var):
strategy = get_tpu_strategy(enable_packed_var)
with strategy.scope():
v = variables.Variable(
@ -307,8 +327,8 @@ class TPUStrategyTest(test.TestCase):
train_step()
self.assertEqual(2.0, v.numpy())
def test_cluster_in_graph_and_while_body_fn(self):
strategy = get_tpu_strategy()
def test_cluster_in_graph_and_while_body_fn(self, enable_packed_var):
strategy = get_tpu_strategy(enable_packed_var)
@def_function.function
def train_step():
@ -328,8 +348,8 @@ class TPUStrategyTest(test.TestCase):
sum_val = train_step().numpy().astype(float)
self.assertEqual(sum_val, strategy.num_replicas_in_sync * 10)
def test_two_clusters_with_same_fn(self):
strategy = get_tpu_strategy()
def test_two_clusters_with_same_fn(self, enable_packed_var):
strategy = get_tpu_strategy(enable_packed_var)
@def_function.function
def foo(x):
@ -342,8 +362,8 @@ class TPUStrategyTest(test.TestCase):
bar(1)
def test_using_external_variable_inside_tf_function(self):
strategy = get_tpu_strategy()
def test_using_external_variable_inside_tf_function(self, enable_packed_var):
strategy = get_tpu_strategy(enable_packed_var)
dataset = dataset_ops.Dataset.range(
strategy.num_replicas_in_sync * 2,
output_type=dtypes.float32).batch(strategy.num_replicas_in_sync)
@ -364,8 +384,8 @@ class TPUStrategyTest(test.TestCase):
strategy.experimental_local_results(train_step(next(input_iterator))))
# TODO(b/145574622): Remove this test once it is re-enabled in values_test.py.
def test_all_reduce_on_sync_on_read_variable(self):
strategy = get_tpu_strategy()
def test_all_reduce_on_sync_on_read_variable(self, enable_packed_var):
strategy = get_tpu_strategy(enable_packed_var)
dataset = dataset_ops.Dataset.range(
strategy.num_replicas_in_sync, output_type=dtypes.float32).batch(
strategy.num_replicas_in_sync, drop_remainder=True)
@ -404,8 +424,8 @@ class TPUStrategyTest(test.TestCase):
self.assertAllEqual((0.,), w.read_value())
# TODO(b/140633529): Re-enable the test.
def disable_test_experimental_run_output_on_device(self):
strategy = get_tpu_strategy()
def disable_test_experimental_run_output_on_device(self, enable_packed_var):
strategy = get_tpu_strategy(enable_packed_var)
def computation(x):
return math_ops.square(x)
@ -423,8 +443,8 @@ class TPUStrategyTest(test.TestCase):
self.assertAllEqual("/job:localhost/replica:0/task:0/device:TPU:1",
results[1].backing_device)
def test_composite_input(self):
strategy = get_tpu_strategy()
def test_composite_input(self, enable_packed_var):
strategy = get_tpu_strategy(enable_packed_var)
if strategy.num_replicas_in_sync != 2:
self.skipTest("Test assumes two replicas.")
@ -463,8 +483,9 @@ class TPUStrategyTest(test.TestCase):
self.assertAllEqual(result,
[[[0.0, 1.0], [3.0, 8.0]], [[0.0, 1.0], [3.0, 8.0]]])
def test_composite_input_dynamic_shapes_outside_compilation(self):
strategy = get_tpu_strategy()
def test_composite_input_dynamic_shapes_outside_compilation(
self, enable_packed_var):
strategy = get_tpu_strategy(enable_packed_var)
if strategy.num_replicas_in_sync != 2:
self.skipTest("Test assumes two replicas.")
@ -506,11 +527,11 @@ class TPUStrategyTest(test.TestCase):
result = sparse_lookup(dataset)
self.assertAllEqual(result, [[0.0, 2.0], [1.5, 5.0]])
def test_per_device_tracing_of_mirrored_variables(self):
def test_per_device_tracing_of_mirrored_variables(self, enable_packed_var):
# Define trace_count as a list to avoid python scoping error
trace_count = [0]
strategy = get_tpu_strategy()
strategy = get_tpu_strategy(enable_packed_var)
with strategy.scope():
variable = variables.Variable(0.0)
@ -527,7 +548,10 @@ class TPUStrategyTest(test.TestCase):
with strategy.scope():
update_variable.get_concrete_function()
self.assertEqual(trace_count[0], len(strategy.extended.worker_devices))
self.assertLen(strategy.extended.worker_devices, trace_count[0])
class TPUStrategyDataPrefetchTest(test.TestCase):
def test_prefetch_to_device_default(self):
strategy = get_tpu_strategy()

View File

@ -24,6 +24,7 @@ from __future__ import print_function
import contextlib
from tensorflow.python.distribute import packed_distributed_variable as packed
from tensorflow.python.distribute import values
from tensorflow.python.eager import context
from tensorflow.python.eager import tape
@ -46,15 +47,27 @@ def _maybe_enter_graph(tensor):
yield
@contextlib.contextmanager
def _maybe_on_device(var):
# Add a device scope for packed variables.
if isinstance(var, packed.PackedVarAndDevice):
with ops.device(var.device):
yield
else:
yield
def _make_raw_assign_fn(raw_assign_fn): # pylint: disable=missing-docstring
def assign_fn(var, value, use_locking=False, name=None, read_value=True): # pylint: disable=missing-docstring
del use_locking # Unused.
with _maybe_enter_graph(var.handle):
handle = var.handle
with _maybe_enter_graph(handle), _maybe_on_device(var):
op = raw_assign_fn(
var.handle, ops.convert_to_tensor(value, dtype=var.dtype), name=name)
handle,
ops.convert_to_tensor(value, dtype=var.dtype),
name=name)
with ops.control_dependencies([op]):
return var._read_variable_op() if read_value else op # pylint: disable=protected-access
@ -97,23 +110,37 @@ class TPUVariableMixin(object):
@property
def handle(self):
"""The handle by which this variable can be accessed."""
# If we're in a tpu.rewrite(), return the replicated handle.
tpu_context = enclosing_tpu_context()
if tpu_context is None or context.executing_eagerly():
return self._get_on_device_or_primary().handle
else:
return tpu_context.get_replicated_var_handle(self._handle_id,
self._values,
self._is_mirrored())
is_packed = self._packed_var is not None
val = self._values
if is_packed:
val = [self._packed_var]
return tpu_context.get_replicated_var_handle(self._handle_id, val,
self._is_mirrored(),
is_packed)
@property
def device(self):
return self.handle.device
def _read_variable_op(self):
"""Reads the value of this variable."""
if self.trainable:
tape.variable_accessed(self)
return gen_resource_variable_ops.read_variable_op(self.handle, self.dtype)
handle = self.handle
if getattr(handle, "is_packed", False):
# Add a device scope for a packed variable handle.
with ops.device(self._get_on_device_or_primary().device):
return gen_resource_variable_ops.read_variable_op(handle, self.dtype)
else:
return gen_resource_variable_ops.read_variable_op(handle, self.dtype)
def read_value(self):
if enclosing_tpu_context() is None:

View File

@ -472,6 +472,12 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable,
# variable.
self._var_policy = var_policy
@property
def _devices(self):
if self._packed_var is not None:
return tuple(d for d in self._packed_var.devices)
return tuple(v.device for v in self._values)
def is_initialized(self, name=None):
"""Identifies if all the component variables are initialized.
@ -482,6 +488,8 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable,
The op that evaluates to True or False depending on if all the
component variables are initialized.
"""
if self._packed_var is not None:
return self._packed_var.is_initialized()
result = self._primary.is_initialized()
# We iterate through the list of values except the last one to allow us to
# name the final `logical_and` op the same name that is passed by the user
@ -552,6 +560,10 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable,
def aggregation(self):
return self._aggregation
@property
def _packed_variable(self):
return self._packed_var
@property
def handle(self):
replica_id = values_util.get_current_replica_id_as_int()
@ -559,6 +571,8 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable,
raise ValueError("`handle` is not available outside the replica context"
" or a `tf.distribute.Strategy.update()` call.")
else:
if self._packed_var is not None:
return self._packed_var.handle
return self._values[replica_id].handle
def eval(self, session=None):
@ -607,6 +621,33 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable,
def _in_graph_mode(self):
return self._primary._in_graph_mode # pylint: disable=protected-access
def _get_replica(self, replica_id):
"""Returns the value on a device with the given replica_id."""
if self._packed_var is not None:
return self._packed_var.on_device(self._devices[replica_id])
return self._values[replica_id]
def _get(self):
"""Returns the value for the current device or raises a ValueError."""
replica_id = values_util.get_current_replica_id_as_int()
if replica_id is None:
return self._get_cross_replica()
else:
return self._get_replica(replica_id)
def _get_on_device_or_primary(self):
"""Returns value in same replica or device if possible, else the _primary."""
replica_id = values_util.get_current_replica_id_as_int()
if replica_id is None:
# Try to find a value on the current device.
current_device = device_util.canonicalize(device_util.current())
for i, value in enumerate(self._values):
if device_util.canonicalize(value.device) == current_device:
return self._get_replica(i)
return self._get_replica(0)
else:
return self._get_replica(replica_id)
def read_value(self):
with ds_context.enter_or_assert_strategy(self._distribute_strategy):
return array_ops.identity(self._get())
@ -778,7 +819,8 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable,
if ds_context.in_cross_replica_context():
update_replica_id = distribute_lib.get_update_replica_id()
if update_replica_id is not None:
return update_fn(self._values[update_replica_id], value, **kwargs)
replica_value = self._get_replica(update_replica_id)
return update_fn(replica_value, value, **kwargs)
return self._update_cross_replica(update_fn, value, **kwargs)
else:
values_util.assert_replica_context(self.distribute_strategy)
@ -802,6 +844,7 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable,
obj_map[v] = new_obj
resource_map[v.handle] = new_obj.handle
obj_map[self] = new_obj
resource_map[self.handle] = new_obj.handle
resource_map[self] = new_obj.handle
return obj_map, resource_map
@ -835,6 +878,12 @@ class _MirroredSaveable(saveable_object_util.ResourceVariableSaveable):
def restore(self, restored_tensors, restored_shapes):
"""Restore the same value into all variables."""
tensor, = restored_tensors
packed_var = self._mirrored_variable._packed_variable # pylint: disable=protected-access
if packed_var is not None:
return control_flow_ops.group(
tuple(
values_util.assign_on_device(d, packed_var, tensor)
for d in packed_var.devices))
return control_flow_ops.group(
tuple(
values_util.assign_on_device(v.device, v, tensor)
@ -1013,7 +1062,7 @@ class SyncOnReadVariable(DistributedVariable):
def _get_cross_replica(self):
if self._aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA:
return self._primary
return self._get_replica(0)
with ds_context.enter_or_assert_strategy(self._distribute_strategy):
return self._distribute_strategy.reduce(

View File

@ -42,7 +42,6 @@ from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import device
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import ops
@ -234,11 +233,11 @@ class DistributedValuesTest(test.TestCase, parameterized.TestCase):
distribution=[
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
strategy_combinations.tpu_strategy,
strategy_combinations.tpu_strategy_packed_var,
# TODO(b/137795644): support CentralStroageStrategy
# strategy_combinations.central_storage_strategy_with_two_gpus,
],
mode=["eager"]
))
mode=["eager"]))
def testMakeDistributedValueDefaultDevicePlacement(self, distribution):
if not tf2.enabled():
self.skipTest("Only V2 is supported.")
@ -259,11 +258,11 @@ class DistributedValuesTest(test.TestCase, parameterized.TestCase):
distribution=[
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
strategy_combinations.tpu_strategy,
strategy_combinations.tpu_strategy_packed_var,
# TODO(b/137795644): support CentralStroageStrategy
# strategy_combinations.central_storage_strategy_with_two_gpus,
],
mode=["eager"]
))
mode=["eager"]))
def testMakeDistributedValueExplicitDevicePlacement(self, distribution):
if not tf2.enabled():
self.skipTest("Only V2 is supported.")
@ -384,6 +383,16 @@ def _make_mirrored():
return mirrored
def mirrored_and_tpu_strategy_combinations():
return combinations.combine(
distribution=[
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
strategy_combinations.tpu_strategy,
strategy_combinations.tpu_strategy_packed_var,
],
mode=["graph", "eager"])
class RegroupAndSelectDeviceTest(test.TestCase, parameterized.TestCase):
def _is_per_replica(self, result, expected, klass=values.PerReplica):
@ -563,6 +572,7 @@ class RegroupAndSelectDeviceTest(test.TestCase, parameterized.TestCase):
strategy_combinations.mirrored_strategy_with_one_cpu,
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
strategy_combinations.tpu_strategy,
strategy_combinations.tpu_strategy_packed_var,
strategy_combinations.central_storage_strategy_with_two_gpus,
],
synchronization=[
@ -708,29 +718,40 @@ class DistributedVariableTest(test.TestCase, parameterized.TestCase):
self.evaluate(
distribution.experimental_local_results(distribution.run(assign)))
def testPackedVariable(self, distribution, synchronization, aggregation):
@combinations.generate(
combinations.combine(
distribution=[
strategy_combinations.mirrored_strategy_with_one_cpu,
strategy_combinations.tpu_strategy,
],
mode=["eager"]))
class PackedDistributedVariableTest(test.TestCase, parameterized.TestCase):
def testPackedVariable(self, distribution):
with distribution.scope():
v0 = variables_lib.Variable(
0., synchronization=synchronization, aggregation=aggregation)
if not isinstance(v0, values.DistributedVariable):
self.skipTest("This test doesn't apply to non DistributedVariables")
self.assertEqual(v0._packed_var, None)
device_type = device.DeviceSpec.from_string(v0._devices[0]).device_type
for d in v0._devices:
if device.DeviceSpec.from_string(d).device_type != device_type:
self.skipTest("Packing variables on devices of different types "
"is not supported yet.")
v0 = variables_lib.Variable(0.)
self.assertIsNone(v0._packed_var)
distribution._enable_packed_variable_in_eager_mode = True
with distribution.scope():
v1 = variables_lib.Variable(
0., synchronization=synchronization, aggregation=aggregation)
if ops.executing_eagerly_outside_functions():
v1 = variables_lib.Variable(0)
self.assertIsInstance(v1._packed_var, packed.PackedDistributedVariable)
else:
self.assertEqual(v1._packed_var, None)
devices = v1._devices
for i in range(1, len(devices)):
with distribute_lib.ReplicaContext(distribution, i):
v1.assign(i)
val = v1._get()
self.assertIsInstance(val, packed.PackedVarAndDevice)
self.assertEqual(val.device, devices[0])
self.assertEqual(self.evaluate(val.read_value()), 0)
for i in range(0, len(devices)):
with distribute_lib.ReplicaContext(distribution, i):
val = v1._get()
self.assertIsInstance(val, packed.PackedVarAndDevice)
self.assertEqual(val.device, devices[i])
self.assertEqual(self.evaluate(val.read_value()), i)
class MirroredVariableTest(test.TestCase, parameterized.TestCase):
@ -920,6 +941,7 @@ class MirroredVariableTest(test.TestCase, parameterized.TestCase):
distribution=[
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
strategy_combinations.tpu_strategy,
strategy_combinations.tpu_strategy_packed_var,
],
mode=["eager"]))
def testAssignValueInReplicaContextWithoutAggregation(self, distribution):
@ -943,6 +965,7 @@ class MirroredVariableTest(test.TestCase, parameterized.TestCase):
strategy_combinations.mirrored_strategy_with_one_cpu,
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
strategy_combinations.tpu_strategy,
strategy_combinations.tpu_strategy_packed_var,
],
mode=["graph", "eager"]))
def testValueInReplicaContext(self, distribution):
@ -968,6 +991,7 @@ class MirroredVariableTest(test.TestCase, parameterized.TestCase):
strategy_combinations.mirrored_strategy_with_one_cpu,
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
strategy_combinations.tpu_strategy,
strategy_combinations.tpu_strategy_packed_var,
],
mode=["graph", "eager"]))
def testAssignOutOfScope(self, distribution):
@ -1041,6 +1065,7 @@ class MirroredVariableTest(test.TestCase, parameterized.TestCase):
distribution=[
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
strategy_combinations.tpu_strategy,
strategy_combinations.tpu_strategy_packed_var,
],
mode=["eager"]))
def testInitializedToSameValueInsideEagerRun(self, distribution):
@ -1066,6 +1091,7 @@ class MirroredVariableTest(test.TestCase, parameterized.TestCase):
strategy_combinations.mirrored_strategy_with_one_cpu,
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
strategy_combinations.tpu_strategy,
strategy_combinations.tpu_strategy_packed_var,
],
mode=["graph", "eager"]))
def testAggregationOnlyFirstReplica(self, distribution):
@ -1093,6 +1119,7 @@ class MirroredVariableTest(test.TestCase, parameterized.TestCase):
distribution=[
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
strategy_combinations.tpu_strategy,
strategy_combinations.tpu_strategy_packed_var,
],
mode=["eager"]))
def testInitScope(self, distribution):
@ -1143,13 +1170,7 @@ class MirroredVariableTest(test.TestCase, parameterized.TestCase):
distribution.experimental_local_results(distribution.run(add)))
self.assertAllEqual([2, 2], per_replica_results)
@combinations.generate(
combinations.combine(
distribution=[
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
strategy_combinations.tpu_strategy,
],
mode=["graph", "eager"]))
@combinations.generate(mirrored_and_tpu_strategy_combinations())
def testAssignAdd(self, distribution):
with distribution.scope():
v = variable_scope.variable(
@ -1456,15 +1477,6 @@ class SyncOnReadVariablePropertiesTest(test.TestCase):
self.assertEqual(2., self.evaluate(add1(replica_local)))
def mirrored_and_tpu_strategy_combinations():
return combinations.combine(
distribution=[
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
strategy_combinations.tpu_strategy,
],
mode=["graph", "eager"])
# TODO(b/144432582): Add variable aggregation type to combinations to simplify
# tests.
def strategy_and_run_tf_function_combinations():
@ -1478,6 +1490,7 @@ def strategy_and_run_tf_function_combinations():
experimental_run_tf_function=[True, False]) + combinations.combine(
distribution=[
strategy_combinations.tpu_strategy,
strategy_combinations.tpu_strategy_packed_var,
],
mode=["graph", "eager"],
experimental_run_tf_function=[True])

View File

@ -61,8 +61,14 @@ def on_write_assign_sub(var, value, use_locking=False, name=None,
def assign_on_each_device(var, assign_func, value, read_value):
update = control_flow_ops.group(
tuple(assign_func(v.device, v, value) for v in var._values)) # pylint: disable=protected-access
"""Update the variable on each replica with the given assign_func and value."""
if var._packed_variable is not None: # pylint: disable=protected-access
update = control_flow_ops.group(
tuple(
assign_func(d, var._packed_variable, value) for d in var._devices)) # pylint: disable=protected-access
else:
update = control_flow_ops.group(
tuple(assign_func(v.device, v, value) for v in var._values)) # pylint: disable=protected-access
if not read_value:
return update
with ops.control_dependencies([update] if update else []):
@ -104,7 +110,7 @@ def on_read_assign_cross_replica(var, value, read_value=True):
# TODO(anjs): Should this be over all the replicas in sync since we
# call `reduce` on the variable during read?
if var.aggregation == vs.VariableAggregation.SUM:
tensor = math_ops.cast(tensor / len(var._values), var.dtype) # pylint: disable=protected-access
tensor = math_ops.cast(tensor / len(var._devices), var.dtype) # pylint: disable=protected-access
return assign_on_each_device(var, assign_on_device, tensor,
read_value)

View File

@ -298,7 +298,8 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
self._pivot = pivot
self._replicated_vars = {}
def get_replicated_var_handle(self, name, vars_, is_mirrored=False):
def get_replicated_var_handle(self, name, vars_, is_mirrored=False,
is_packed=False):
"""Returns a variable handle for replicated TPU variable 'var'.
This is a method used by an experimental replicated variable implementation
@ -309,6 +310,7 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
vars_: The replicated TPU variables.
is_mirrored: Whether the variables are mirrored, which guarantees the
values in each replica are always the same.
is_packed: Whether the replicated variables are packed into one variable.
Returns:
The handle of the TPU replicated input node.
@ -320,7 +322,7 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
if handle is not None:
return handle
if device_assignment is not None:
if device_assignment is not None and not is_packed:
# Find a variable copy for each replica in the device assignment.
# Note that the order of devices for replicas for the variable and the
# device assignment might not match.
@ -356,7 +358,8 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
graph._set_control_flow_context(self.outer_context)
handle = tpu_ops.tpu_replicated_input([v.handle for v in replicated_vars],
name=name + "/handle",
is_mirrored_variable=is_mirrored)
is_mirrored_variable=is_mirrored,
is_packed=is_packed)
graph._set_control_flow_context(saved_context)
# pylint: enable=protected-access
self._replicated_vars[name] = handle