Support packed variable in DistributedVariable. Add an option to enable packed variable in TPUStrategy.
PiperOrigin-RevId: 317234665 Change-Id: I09e806cb8261815cd87a6d98817556dd8f7e8ed7
This commit is contained in:
parent
4d54ef3139
commit
7e6e549c46
@ -654,6 +654,7 @@ tpu_py_test(
|
||||
"//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib",
|
||||
"//tensorflow/python/eager:remote",
|
||||
"//tensorflow/python/eager:test",
|
||||
"@absl_py//absl/testing:parameterized",
|
||||
],
|
||||
)
|
||||
|
||||
@ -787,6 +788,7 @@ py_library(
|
||||
name = "tpu_values",
|
||||
srcs = ["tpu_values.py"],
|
||||
deps = [
|
||||
":packed_distributed_variable",
|
||||
":values",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:math_ops",
|
||||
@ -1602,7 +1604,7 @@ distribute_py_test(
|
||||
srcs = ["saved_model_save_load_test.py"],
|
||||
full_precision = True,
|
||||
main = "saved_model_save_load_test.py",
|
||||
shard_count = 5,
|
||||
shard_count = 7,
|
||||
tags = [
|
||||
"multi_and_single_gpu",
|
||||
"no_rocm",
|
||||
@ -1635,7 +1637,7 @@ distribute_py_test(
|
||||
srcs = ["saved_model_mixed_api_test.py"],
|
||||
full_precision = True,
|
||||
main = "saved_model_mixed_api_test.py",
|
||||
shard_count = 5,
|
||||
shard_count = 7,
|
||||
tags = [
|
||||
"multi_and_single_gpu",
|
||||
"no_rocm",
|
||||
|
@ -103,6 +103,7 @@ class TrainingCheckpointTests(test.TestCase, parameterized.TestCase):
|
||||
strategy_combinations.mirrored_strategy_with_one_cpu,
|
||||
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
|
||||
strategy_combinations.tpu_strategy,
|
||||
strategy_combinations.tpu_strategy_packed_var,
|
||||
strategy_combinations.central_storage_strategy_with_two_gpus,
|
||||
],
|
||||
mode=["eager"]))
|
||||
@ -138,6 +139,7 @@ class TrainingCheckpointTests(test.TestCase, parameterized.TestCase):
|
||||
strategy_combinations.mirrored_strategy_with_one_cpu,
|
||||
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
|
||||
strategy_combinations.tpu_strategy,
|
||||
strategy_combinations.tpu_strategy_packed_var,
|
||||
strategy_combinations.central_storage_strategy_with_two_gpus,
|
||||
],
|
||||
mode=["eager"]))
|
||||
|
@ -197,7 +197,8 @@ class InputIterationTest(test.TestCase, parameterized.TestCase,
|
||||
combinations.combine(
|
||||
distribution=[
|
||||
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
|
||||
strategy_combinations.tpu_strategy
|
||||
strategy_combinations.tpu_strategy,
|
||||
strategy_combinations.tpu_strategy_packed_var,
|
||||
],
|
||||
mode=["eager"]))
|
||||
def testNestedOutput(self, distribution):
|
||||
@ -748,6 +749,10 @@ class InputIterationTest(test.TestCase, parameterized.TestCase,
|
||||
mode=["eager"]
|
||||
))
|
||||
def testMultiDeviceDataCapturedFunction(self, distribution):
|
||||
if getattr(distribution, "_enable_packed_variable_in_eager_mode", False):
|
||||
self.skipTest(
|
||||
"Dataset captured function doesn't support packed tensors yet "
|
||||
"(b/145922293).")
|
||||
inputs = constant_op.constant([2., 3.])
|
||||
dataset = lambda _: dataset_ops.Dataset.from_tensor_slices(inputs).repeat(5)
|
||||
input_iterator = iter(
|
||||
|
@ -148,6 +148,9 @@ def select_replica_mirrored(replica_id, structured):
|
||||
raise TypeError(
|
||||
"Expected value to be mirrored across replicas: %s in %s." %
|
||||
(x, structured))
|
||||
packed_var = getattr(x, "_packed_variable", None)
|
||||
if packed_var is not None:
|
||||
return packed_var
|
||||
return x.values[replica_id]
|
||||
else:
|
||||
return x
|
||||
|
@ -42,7 +42,7 @@ class PackedDistributedVariable(resource_variable_ops.BaseResourceVariable):
|
||||
name: Optional name for the variable. Defaults to `'Variable'` and gets
|
||||
uniquified automatically.
|
||||
"""
|
||||
if not context.executing_eagerly():
|
||||
if not ops.executing_eagerly_outside_functions():
|
||||
raise ValueError(
|
||||
"PackedDistributedVariable should be created in eager mode.")
|
||||
if not distributed_variables:
|
||||
@ -84,6 +84,9 @@ class PackedDistributedVariable(resource_variable_ops.BaseResourceVariable):
|
||||
def devices(self):
|
||||
return self._devices
|
||||
|
||||
def on_device(self, device):
|
||||
return PackedVarAndDevice(self, device)
|
||||
|
||||
def get_var_on_device(self, device):
|
||||
for i, d in enumerate(self._devices):
|
||||
if d == device:
|
||||
@ -100,7 +103,10 @@ class PackedDistributedVariable(resource_variable_ops.BaseResourceVariable):
|
||||
|
||||
@property
|
||||
def handle(self):
|
||||
return self._handle
|
||||
if context.executing_eagerly():
|
||||
return self.get_var_on_current_device().handle
|
||||
else:
|
||||
return self._handle
|
||||
|
||||
def _read_variable_op(self):
|
||||
if context.executing_eagerly():
|
||||
@ -269,7 +275,8 @@ class PackedVarAndDevice(object):
|
||||
|
||||
@property
|
||||
def handle(self):
|
||||
return self._var.handle
|
||||
with ops.device(self._device):
|
||||
return self._var.handle
|
||||
|
||||
@property
|
||||
def op(self):
|
||||
|
@ -46,7 +46,7 @@ class PackedDistributedVariableTest(test.TestCase):
|
||||
v1 = resource_variable_ops.ResourceVariable(2.0, name='var1')
|
||||
|
||||
packed_var = packed_distributed_variable.PackedDistributedVariable([v0, v1])
|
||||
self.assertTrue(packed_var.handle.is_packed)
|
||||
self.assertFalse(packed_var.handle.is_packed)
|
||||
self.assertTrue(packed_var.is_initialized)
|
||||
|
||||
with ops.device('/cpu:0'):
|
||||
@ -61,6 +61,7 @@ class PackedDistributedVariableTest(test.TestCase):
|
||||
|
||||
@def_function.function
|
||||
def update_var():
|
||||
self.assertTrue(packed_var.handle.is_packed)
|
||||
with ops.device('/cpu:0'):
|
||||
packed_var.assign_add(3.0).assign_sub(1.0)
|
||||
read0 = packed_var.value()
|
||||
@ -85,7 +86,7 @@ class PackedDistributedVariableTest(test.TestCase):
|
||||
|
||||
packed_var0 = packed_distributed_variable.PackedVarAndDevice(
|
||||
packed_var, device0)
|
||||
self.assertTrue(packed_var0.handle.is_packed)
|
||||
self.assertFalse(packed_var0.handle.is_packed)
|
||||
self.assertAllEqual(math_ops.mul(packed_var0, 2.0), 2.0)
|
||||
|
||||
packed_var1 = packed_distributed_variable.PackedVarAndDevice(
|
||||
@ -94,6 +95,7 @@ class PackedDistributedVariableTest(test.TestCase):
|
||||
|
||||
@def_function.function
|
||||
def func():
|
||||
self.assertTrue(packed_var.handle.is_packed)
|
||||
var0 = packed_distributed_variable.PackedVarAndDevice(packed_var, device0)
|
||||
var0.assign_add(3.0)
|
||||
var1 = packed_distributed_variable.PackedVarAndDevice(packed_var, device1)
|
||||
|
@ -58,6 +58,7 @@ strategies = [
|
||||
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
|
||||
strategy_combinations.mirrored_strategy_with_two_gpus,
|
||||
strategy_combinations.tpu_strategy,
|
||||
strategy_combinations.tpu_strategy_packed_var,
|
||||
strategy_combinations.central_storage_strategy_with_two_gpus,
|
||||
]
|
||||
|
||||
|
@ -53,7 +53,11 @@ _did_connect_to_cluster = False
|
||||
|
||||
|
||||
# pylint: disable=missing-docstring
|
||||
def _get_tpu_strategy_creator(steps_per_run, use_single_core=False, **kwargs):
|
||||
def _get_tpu_strategy_creator(steps_per_run,
|
||||
use_single_core=False,
|
||||
enable_packed_variable=False,
|
||||
**kwargs):
|
||||
|
||||
def _create_tpu_strategy():
|
||||
global _did_connect_to_cluster
|
||||
|
||||
@ -87,10 +91,13 @@ def _get_tpu_strategy_creator(steps_per_run, use_single_core=False, **kwargs):
|
||||
|
||||
# Steps per run is only supported in TF 1.x
|
||||
if tf2.enabled():
|
||||
return tpu_lib.TPUStrategy(resolver, device_assignment, **kwargs)
|
||||
strategy = tpu_lib.TPUStrategy(resolver, device_assignment, **kwargs)
|
||||
else:
|
||||
return tpu_lib.TPUStrategyV1(resolver, steps_per_run,
|
||||
device_assignment, **kwargs)
|
||||
strategy = tpu_lib.TPUStrategyV1(resolver, steps_per_run,
|
||||
device_assignment, **kwargs)
|
||||
strategy._enable_packed_variable_in_eager_mode = enable_packed_variable # pylint: disable=protected-access
|
||||
return strategy
|
||||
|
||||
return _create_tpu_strategy
|
||||
|
||||
|
||||
@ -117,6 +124,10 @@ one_device_strategy_gpu_on_worker_1 = combinations.NamedDistribution(
|
||||
required_gpus=1)
|
||||
tpu_strategy = combinations.NamedDistribution(
|
||||
"TPU", _get_tpu_strategy_creator(steps_per_run=2), required_tpu=True)
|
||||
tpu_strategy_packed_var = combinations.NamedDistribution(
|
||||
"TPUPackedVar",
|
||||
_get_tpu_strategy_creator(steps_per_run=2, enable_packed_variable=True),
|
||||
required_tpu=True)
|
||||
tpu_strategy_one_step = combinations.NamedDistribution(
|
||||
"TPUOneStep", _get_tpu_strategy_creator(steps_per_run=1), required_tpu=True)
|
||||
tpu_strategy_one_core = combinations.NamedDistribution(
|
||||
@ -286,6 +297,7 @@ strategies_minus_default_and_tpu = [
|
||||
tpu_strategies = [
|
||||
tpu_strategy, # steps_per_run=2
|
||||
tpu_strategy_one_step,
|
||||
tpu_strategy_packed_var,
|
||||
cloud_tpu_strategy,
|
||||
]
|
||||
|
||||
|
@ -141,6 +141,10 @@ class TPUStrategy(distribute_lib.Strategy):
|
||||
"num_workers").set(self.extended.num_hosts)
|
||||
distribute_lib.distribution_strategy_replica_gauge.get_cell(
|
||||
"num_replicas_per_worker").set(self.extended.num_replicas_per_host)
|
||||
# Packed variable is used to reduce the overhead of function execution.
|
||||
# For a DistributedVariable, only one variable handle is captured into a
|
||||
# function graph. It's only supported in eager mode.
|
||||
self._enable_packed_variable_in_eager_mode = False
|
||||
|
||||
# TODO(cjfj): Modify `_call_for_each_replica` in `TPUExtended` such that this
|
||||
# can use the default implementation.
|
||||
@ -185,6 +189,10 @@ class TPUStrategyV1(distribute_lib.StrategyV1):
|
||||
"num_workers").set(self.extended.num_hosts)
|
||||
distribute_lib.distribution_strategy_replica_gauge.get_cell(
|
||||
"num_replicas_per_worker").set(self.extended.num_replicas_per_host)
|
||||
# Packed variable is used to reduce the overhead of function execution.
|
||||
# For a DistributedVariable, only one variable handle is captured into a
|
||||
# function graph. It's only supported in eager mode.
|
||||
self._enable_packed_variable_in_eager_mode = False
|
||||
|
||||
@property
|
||||
def steps_per_run(self):
|
||||
@ -671,20 +679,29 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
|
||||
return cross_device_ops_lib.reduce_non_distributed_value(
|
||||
reduce_op, value, destinations, self._num_replicas_in_sync)
|
||||
|
||||
value_list = value.values
|
||||
# pylint: disable=protected-access
|
||||
if isinstance(
|
||||
value,
|
||||
values.DistributedVariable) and value._packed_variable is not None:
|
||||
value_list = tuple(
|
||||
value._packed_variable.on_device(d)
|
||||
for d in value._packed_variable.devices)
|
||||
# pylint: enable=protected-access
|
||||
|
||||
# Currently XLA op by op mode has a limit for the number of inputs for a
|
||||
# single op, thus we break one `add_n` op into a group of `add_n` ops to
|
||||
# work around the constraint.
|
||||
# TODO(cjfj): Detect when it is possible to use `cross_replica_sum`.
|
||||
if len(value.values) <= _XLA_OP_BY_OP_INPUTS_LIMIT:
|
||||
output = math_ops.add_n(value.values)
|
||||
output = math_ops.add_n(value_list)
|
||||
else:
|
||||
output = array_ops.zeros_like(
|
||||
value.values[0], dtype=value.values[0].dtype)
|
||||
for i in range(0, len(value.values), _XLA_OP_BY_OP_INPUTS_LIMIT):
|
||||
output += math_ops.add_n(value.values[i:i + _XLA_OP_BY_OP_INPUTS_LIMIT])
|
||||
output = array_ops.zeros_like(value_list[0], dtype=value_list[0].dtype)
|
||||
for i in range(0, len(value_list), _XLA_OP_BY_OP_INPUTS_LIMIT):
|
||||
output += math_ops.add_n(value_list[i:i + _XLA_OP_BY_OP_INPUTS_LIMIT])
|
||||
|
||||
if reduce_op == reduce_util.ReduceOp.MEAN:
|
||||
output *= (1. / len(value.values))
|
||||
output *= (1. / len(value_list))
|
||||
|
||||
devices = cross_device_ops_lib.get_devices_from(destinations)
|
||||
|
||||
@ -710,17 +727,28 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
|
||||
else:
|
||||
return (fn(var, *args, **kwargs),)
|
||||
|
||||
# Otherwise, we revert to MirroredStrategy behavior and update each variable
|
||||
# directly.
|
||||
# Otherwise, we revert to MirroredStrategy behavior and update the variable
|
||||
# on each replica directly.
|
||||
updates = []
|
||||
for i, v in enumerate(var.values):
|
||||
values_and_devices = []
|
||||
packed_var = var._packed_variable # pylint: disable=protected-access
|
||||
if packed_var is not None:
|
||||
for device in packed_var.devices:
|
||||
values_and_devices.append((packed_var, device))
|
||||
else:
|
||||
for value in var.values:
|
||||
values_and_devices.append((value, value.device))
|
||||
|
||||
for i, value_and_device in enumerate(values_and_devices):
|
||||
value = value_and_device[0]
|
||||
device = value_and_device[1]
|
||||
name = "update_%d" % i
|
||||
with ops.device(v.device), \
|
||||
with ops.device(device), \
|
||||
distribute_lib.UpdateContext(i), \
|
||||
ops.name_scope(name):
|
||||
# If args and kwargs are not mirrored, the value is returned as is.
|
||||
updates.append(
|
||||
fn(v, *distribute_utils.select_replica_mirrored(i, args),
|
||||
fn(value, *distribute_utils.select_replica_mirrored(i, args),
|
||||
**distribute_utils.select_replica_mirrored(i, kwargs)))
|
||||
return distribute_utils.update_regroup(self, updates, group)
|
||||
|
||||
|
@ -18,6 +18,8 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from absl.testing import parameterized
|
||||
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.distribute import distribute_lib
|
||||
from tensorflow.python.distribute import distribution_strategy_context
|
||||
@ -64,14 +66,17 @@ def get_tpu_cluster_resolver():
|
||||
return resolver
|
||||
|
||||
|
||||
def get_tpu_strategy():
|
||||
def get_tpu_strategy(enable_packed_var=False):
|
||||
resolver = get_tpu_cluster_resolver()
|
||||
remote.connect_to_cluster(resolver)
|
||||
tpu_strategy_util.initialize_tpu_system(resolver)
|
||||
return tpu_lib.TPUStrategy(resolver)
|
||||
strategy = tpu_lib.TPUStrategy(resolver)
|
||||
strategy._enable_packed_variable_in_eager_mode = enable_packed_var
|
||||
return strategy
|
||||
|
||||
|
||||
class TPUStrategyTest(test.TestCase):
|
||||
# TPU tests which don't use TPUStrategy.
|
||||
class TPUTest(test.TestCase):
|
||||
|
||||
def test_multiple_initialize_system(self):
|
||||
resolver = get_tpu_cluster_resolver()
|
||||
@ -82,177 +87,6 @@ class TPUStrategyTest(test.TestCase):
|
||||
tpu_strategy_util.initialize_tpu_system(resolver)
|
||||
self.assertRegex(str(mock_log.call_args), "already been initialized")
|
||||
|
||||
def test_sequential_experimental_runs(self):
|
||||
resolver = get_tpu_cluster_resolver()
|
||||
remote.connect_to_cluster(resolver)
|
||||
topology = tpu_strategy_util.initialize_tpu_system(resolver)
|
||||
# Computation replicated to all cores.
|
||||
device_assignment = device_assignment_lib.DeviceAssignment.build(
|
||||
topology, num_replicas=2)
|
||||
strategy = tpu_lib.TPUStrategy(
|
||||
resolver, device_assignment=device_assignment)
|
||||
|
||||
# Computation on the 1st core.
|
||||
device_assignment2 = device_assignment_lib.DeviceAssignment.build(
|
||||
topology, num_replicas=1)
|
||||
strategy2 = tpu_lib.TPUStrategy(
|
||||
resolver, device_assignment=device_assignment2)
|
||||
|
||||
def computation(x):
|
||||
return math_ops.square(x)
|
||||
|
||||
@def_function.function
|
||||
def train_step():
|
||||
outputs = strategy.experimental_local_results(
|
||||
strategy.run(computation, args=([2., 2.],)))
|
||||
outputs2 = strategy2.run(
|
||||
computation, args=([outputs[0]],))
|
||||
return outputs2
|
||||
|
||||
self.assertAllEqual([[16., 16.]], train_step())
|
||||
|
||||
def test_device_switch_case(self):
|
||||
strategy = get_tpu_strategy()
|
||||
with strategy.scope():
|
||||
a = variables.Variable(1)
|
||||
|
||||
inference_iteration = variables.Variable(-1)
|
||||
|
||||
def inference_fn(x, i):
|
||||
return a + x + i
|
||||
|
||||
@def_function.function
|
||||
def run_inference(x):
|
||||
|
||||
def do_inference(device, inference_fn, i):
|
||||
with ops.device(device):
|
||||
return inference_fn(x, i)
|
||||
|
||||
branch_fns = {
|
||||
0: (lambda: do_inference("/device:TPU:0", inference_fn, 0)),
|
||||
1: (lambda: do_inference("/device:TPU:1", inference_fn, 1)),
|
||||
}
|
||||
branch_index = inference_iteration.assign_add(1, use_locking=True) % 2
|
||||
return control_flow_ops.switch_case(branch_index, branch_fns)
|
||||
|
||||
self.assertAllEqual(2., run_inference(1)) # Use TPU core 0.
|
||||
self.assertAllEqual(3., run_inference(1)) # Use TPU core 1.
|
||||
|
||||
def test_recover_from_compilation_failures(self):
|
||||
# TODO(b/148150981): Stop skipping this test once recovery works
|
||||
# for non-local TPU.
|
||||
if FLAGS.tpu:
|
||||
self.skipTest("Recovery fails for non-local TPU, see b/148150981")
|
||||
|
||||
# Disable automatic outside compilation.
|
||||
config.set_soft_device_placement(False)
|
||||
strategy = get_tpu_strategy()
|
||||
|
||||
@def_function.function
|
||||
def compilation_failure_run():
|
||||
|
||||
def computation():
|
||||
return random_ops.random_gamma([10], [0.5, 1.5])
|
||||
|
||||
return strategy.run(computation)
|
||||
|
||||
with self.assertRaisesRegexp(errors.InvalidArgumentError,
|
||||
"TPU compilation failed"):
|
||||
compilation_failure_run()
|
||||
|
||||
@def_function.function
|
||||
def good_run():
|
||||
|
||||
def computation():
|
||||
return random_ops.random_normal([10])
|
||||
|
||||
return strategy.run(computation)
|
||||
|
||||
good_run()
|
||||
|
||||
def test_dynamic_shape_with_outside_compilation_failure(self):
|
||||
# Enable automatic outside compilation.
|
||||
config.set_soft_device_placement(True)
|
||||
strategy = get_tpu_strategy()
|
||||
dataset = dataset_ops.Dataset.from_tensors(("string", 1.0)).repeat().batch(
|
||||
2, drop_remainder=False)
|
||||
dataset = strategy.experimental_distribute_dataset(dataset)
|
||||
iterator = iter(dataset)
|
||||
|
||||
@def_function.function
|
||||
def train_fn(iterator):
|
||||
|
||||
def step_fn(inputs):
|
||||
_, inputs = inputs
|
||||
return math_ops.reduce_sum(inputs)
|
||||
|
||||
return strategy.experimental_local_results(
|
||||
strategy.run(step_fn, args=(next(iterator),)))
|
||||
|
||||
with self.assertRaisesRegex(errors.InternalError, "Compilation failure"):
|
||||
logging.info(train_fn(iterator))
|
||||
|
||||
def test_computation_on_subset_cores(self):
|
||||
resolver = get_tpu_cluster_resolver()
|
||||
remote.connect_to_cluster(resolver)
|
||||
topology = tpu_strategy_util.initialize_tpu_system(resolver)
|
||||
all_core_strategy = tpu_lib.TPUStrategy(resolver)
|
||||
|
||||
with all_core_strategy.scope():
|
||||
v = variables.Variable(0.0,
|
||||
aggregation=variables.VariableAggregation.MEAN)
|
||||
|
||||
# Computation on the 1st core.
|
||||
device_assignment = device_assignment_lib.DeviceAssignment.build(
|
||||
topology, num_replicas=1)
|
||||
first_core_strategy = tpu_lib.TPUStrategy(
|
||||
resolver, device_assignment=device_assignment)
|
||||
|
||||
# Computation on the 2nd core.
|
||||
device_assignment2 = device_assignment_lib.DeviceAssignment(
|
||||
topology, [[[0, 0, 0, 1]]])
|
||||
second_core_strategy = tpu_lib.TPUStrategy(
|
||||
resolver, device_assignment=device_assignment2)
|
||||
|
||||
@def_function.function
|
||||
def train_step():
|
||||
|
||||
def step_fn():
|
||||
return v + 1.0
|
||||
|
||||
all_core_strategy.run(step_fn)
|
||||
r1 = first_core_strategy.run(step_fn)
|
||||
r2 = second_core_strategy.run(step_fn)
|
||||
return r1 + r2
|
||||
|
||||
train_step()
|
||||
self.assertAllEqual(2., train_step())
|
||||
|
||||
def test_worker_devices_on_subset_cores(self):
|
||||
resolver = get_tpu_cluster_resolver()
|
||||
remote.connect_to_cluster(resolver)
|
||||
topology = tpu_strategy_util.initialize_tpu_system(resolver)
|
||||
|
||||
# Strategy for the 1st core.
|
||||
device_assignment = device_assignment_lib.DeviceAssignment.build(
|
||||
topology, num_replicas=1)
|
||||
first_core_strategy = tpu_lib.TPUStrategy(
|
||||
resolver, device_assignment=device_assignment)
|
||||
|
||||
# Strategy for the 2nd core.
|
||||
device_assignment2 = device_assignment_lib.DeviceAssignment(
|
||||
topology, [[[0, 0, 0, 1]]])
|
||||
second_core_strategy = tpu_lib.TPUStrategy(
|
||||
resolver, device_assignment=device_assignment2)
|
||||
|
||||
self.assertLen(first_core_strategy.extended.worker_devices, 1)
|
||||
self.assertEndsWith(first_core_strategy.extended.worker_devices[0],
|
||||
"device:TPU:0")
|
||||
|
||||
self.assertLen(second_core_strategy.extended.worker_devices, 1)
|
||||
self.assertEndsWith(second_core_strategy.extended.worker_devices[0],
|
||||
"device:TPU:1")
|
||||
|
||||
def test_tpu_tf_function_same_device(self):
|
||||
with ops.device("/device:TPU:0"):
|
||||
a = variables.Variable(1)
|
||||
@ -288,8 +122,194 @@ class TPUStrategyTest(test.TestCase):
|
||||
result = bar() + 1
|
||||
self.assertAllEqual(result, 2)
|
||||
|
||||
def test_control_output_in_while_body_fn(self):
|
||||
strategy = get_tpu_strategy()
|
||||
|
||||
@parameterized.named_parameters([("PackedVar", True), ("", False)])
|
||||
class TPUStrategyTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
def test_sequential_experimental_runs(self, enable_packed_var):
|
||||
resolver = get_tpu_cluster_resolver()
|
||||
remote.connect_to_cluster(resolver)
|
||||
topology = tpu_strategy_util.initialize_tpu_system(resolver)
|
||||
# Computation replicated to all cores.
|
||||
device_assignment = device_assignment_lib.DeviceAssignment.build(
|
||||
topology, num_replicas=2)
|
||||
strategy = tpu_lib.TPUStrategy(
|
||||
resolver, device_assignment=device_assignment)
|
||||
strategy._enable_packed_variable_in_eager_mode = enable_packed_var
|
||||
|
||||
# Computation on the 1st core.
|
||||
device_assignment2 = device_assignment_lib.DeviceAssignment.build(
|
||||
topology, num_replicas=1)
|
||||
strategy2 = tpu_lib.TPUStrategy(
|
||||
resolver, device_assignment=device_assignment2)
|
||||
|
||||
def computation(x):
|
||||
return math_ops.square(x)
|
||||
|
||||
@def_function.function
|
||||
def train_step():
|
||||
outputs = strategy.experimental_local_results(
|
||||
strategy.run(computation, args=([2., 2.],)))
|
||||
outputs2 = strategy2.run(
|
||||
computation, args=([outputs[0]],))
|
||||
return outputs2
|
||||
|
||||
self.assertAllEqual([[16., 16.]], train_step())
|
||||
|
||||
def test_device_switch_case(self, enable_packed_var):
|
||||
strategy = get_tpu_strategy(enable_packed_var)
|
||||
with strategy.scope():
|
||||
a = variables.Variable(1)
|
||||
|
||||
inference_iteration = variables.Variable(-1)
|
||||
|
||||
def inference_fn(x, i):
|
||||
return a + x + i
|
||||
|
||||
@def_function.function
|
||||
def run_inference(x):
|
||||
|
||||
def do_inference(device, inference_fn, i):
|
||||
with ops.device(device):
|
||||
return inference_fn(x, i)
|
||||
|
||||
branch_fns = {
|
||||
0: (lambda: do_inference("/device:TPU:0", inference_fn, 0)),
|
||||
1: (lambda: do_inference("/device:TPU:1", inference_fn, 1)),
|
||||
}
|
||||
branch_index = inference_iteration.assign_add(1, use_locking=True) % 2
|
||||
return control_flow_ops.switch_case(branch_index, branch_fns)
|
||||
|
||||
self.assertAllEqual(2., run_inference(1)) # Use TPU core 0.
|
||||
self.assertAllEqual(3., run_inference(1)) # Use TPU core 1.
|
||||
|
||||
def test_recover_from_compilation_failures(self, enable_packed_var):
|
||||
# TODO(b/148150981): Stop skipping this test once recovery works
|
||||
# for non-local TPU.
|
||||
if FLAGS.tpu:
|
||||
self.skipTest("Recovery fails for non-local TPU, see b/148150981")
|
||||
|
||||
# Disable automatic outside compilation.
|
||||
config.set_soft_device_placement(False)
|
||||
strategy = get_tpu_strategy(enable_packed_var)
|
||||
|
||||
@def_function.function
|
||||
def compilation_failure_run():
|
||||
|
||||
def computation():
|
||||
return random_ops.random_gamma([10], [0.5, 1.5])
|
||||
|
||||
return strategy.run(computation)
|
||||
|
||||
with self.assertRaisesRegex(errors.InvalidArgumentError,
|
||||
"TPU compilation failed"):
|
||||
compilation_failure_run()
|
||||
|
||||
@def_function.function
|
||||
def good_run():
|
||||
|
||||
def computation():
|
||||
return random_ops.random_normal([10])
|
||||
|
||||
return strategy.run(computation)
|
||||
|
||||
good_run()
|
||||
|
||||
def test_dynamic_shape_with_outside_compilation_failure(
|
||||
self, enable_packed_var):
|
||||
# Enable automatic outside compilation.
|
||||
config.set_soft_device_placement(True)
|
||||
strategy = get_tpu_strategy(enable_packed_var)
|
||||
dataset = dataset_ops.Dataset.from_tensors(("string", 1.0)).repeat().batch(
|
||||
2, drop_remainder=False)
|
||||
dataset = strategy.experimental_distribute_dataset(dataset)
|
||||
iterator = iter(dataset)
|
||||
|
||||
@def_function.function
|
||||
def train_fn(iterator):
|
||||
|
||||
def step_fn(inputs):
|
||||
_, inputs = inputs
|
||||
return math_ops.reduce_sum(inputs)
|
||||
|
||||
return strategy.experimental_local_results(
|
||||
strategy.run(step_fn, args=(next(iterator),)))
|
||||
|
||||
with self.assertRaisesRegex(errors.InternalError, "Compilation failure"):
|
||||
logging.info(train_fn(iterator))
|
||||
|
||||
def test_computation_on_subset_cores(self, enable_packed_var):
|
||||
resolver = get_tpu_cluster_resolver()
|
||||
remote.connect_to_cluster(resolver)
|
||||
topology = tpu_strategy_util.initialize_tpu_system(resolver)
|
||||
all_core_strategy = tpu_lib.TPUStrategy(resolver)
|
||||
all_core_strategy._enable_packed_variable_in_eager_mode = enable_packed_var
|
||||
|
||||
with all_core_strategy.scope():
|
||||
v = variables.Variable(0.0,
|
||||
aggregation=variables.VariableAggregation.MEAN)
|
||||
|
||||
# Computation on the 1st core.
|
||||
device_assignment = device_assignment_lib.DeviceAssignment.build(
|
||||
topology, num_replicas=1)
|
||||
first_core_strategy = tpu_lib.TPUStrategy(
|
||||
resolver, device_assignment=device_assignment)
|
||||
first_core_strategy._enable_packed_variable_in_eager_mode = (
|
||||
enable_packed_var)
|
||||
|
||||
# Computation on the 2nd core.
|
||||
device_assignment2 = device_assignment_lib.DeviceAssignment(
|
||||
topology, [[[0, 0, 0, 1]]])
|
||||
second_core_strategy = tpu_lib.TPUStrategy(
|
||||
resolver, device_assignment=device_assignment2)
|
||||
second_core_strategy._enable_packed_variable_in_eager_mode = (
|
||||
enable_packed_var)
|
||||
|
||||
@def_function.function
|
||||
def train_step():
|
||||
|
||||
def step_fn():
|
||||
return v + 1.0
|
||||
|
||||
all_core_strategy.run(step_fn)
|
||||
r1 = first_core_strategy.run(step_fn)
|
||||
r2 = second_core_strategy.run(step_fn)
|
||||
return r1 + r2
|
||||
|
||||
train_step()
|
||||
self.assertAllEqual(2., train_step())
|
||||
|
||||
def test_worker_devices_on_subset_cores(self, enable_packed_var):
|
||||
resolver = get_tpu_cluster_resolver()
|
||||
remote.connect_to_cluster(resolver)
|
||||
topology = tpu_strategy_util.initialize_tpu_system(resolver)
|
||||
|
||||
# Strategy for the 1st core.
|
||||
device_assignment = device_assignment_lib.DeviceAssignment.build(
|
||||
topology, num_replicas=1)
|
||||
first_core_strategy = tpu_lib.TPUStrategy(
|
||||
resolver, device_assignment=device_assignment)
|
||||
first_core_strategy._enable_packed_variable_in_eager_mode = (
|
||||
enable_packed_var)
|
||||
|
||||
# Strategy for the 2nd core.
|
||||
device_assignment2 = device_assignment_lib.DeviceAssignment(
|
||||
topology, [[[0, 0, 0, 1]]])
|
||||
second_core_strategy = tpu_lib.TPUStrategy(
|
||||
resolver, device_assignment=device_assignment2)
|
||||
second_core_strategy._enable_packed_variable_in_eager_mode = (
|
||||
enable_packed_var)
|
||||
|
||||
self.assertLen(first_core_strategy.extended.worker_devices, 1)
|
||||
self.assertEndsWith(first_core_strategy.extended.worker_devices[0],
|
||||
"device:TPU:0")
|
||||
|
||||
self.assertLen(second_core_strategy.extended.worker_devices, 1)
|
||||
self.assertEndsWith(second_core_strategy.extended.worker_devices[0],
|
||||
"device:TPU:1")
|
||||
|
||||
def test_control_output_in_while_body_fn(self, enable_packed_var):
|
||||
strategy = get_tpu_strategy(enable_packed_var)
|
||||
|
||||
with strategy.scope():
|
||||
v = variables.Variable(
|
||||
@ -307,8 +327,8 @@ class TPUStrategyTest(test.TestCase):
|
||||
train_step()
|
||||
self.assertEqual(2.0, v.numpy())
|
||||
|
||||
def test_cluster_in_graph_and_while_body_fn(self):
|
||||
strategy = get_tpu_strategy()
|
||||
def test_cluster_in_graph_and_while_body_fn(self, enable_packed_var):
|
||||
strategy = get_tpu_strategy(enable_packed_var)
|
||||
|
||||
@def_function.function
|
||||
def train_step():
|
||||
@ -328,8 +348,8 @@ class TPUStrategyTest(test.TestCase):
|
||||
sum_val = train_step().numpy().astype(float)
|
||||
self.assertEqual(sum_val, strategy.num_replicas_in_sync * 10)
|
||||
|
||||
def test_two_clusters_with_same_fn(self):
|
||||
strategy = get_tpu_strategy()
|
||||
def test_two_clusters_with_same_fn(self, enable_packed_var):
|
||||
strategy = get_tpu_strategy(enable_packed_var)
|
||||
|
||||
@def_function.function
|
||||
def foo(x):
|
||||
@ -342,8 +362,8 @@ class TPUStrategyTest(test.TestCase):
|
||||
|
||||
bar(1)
|
||||
|
||||
def test_using_external_variable_inside_tf_function(self):
|
||||
strategy = get_tpu_strategy()
|
||||
def test_using_external_variable_inside_tf_function(self, enable_packed_var):
|
||||
strategy = get_tpu_strategy(enable_packed_var)
|
||||
dataset = dataset_ops.Dataset.range(
|
||||
strategy.num_replicas_in_sync * 2,
|
||||
output_type=dtypes.float32).batch(strategy.num_replicas_in_sync)
|
||||
@ -364,8 +384,8 @@ class TPUStrategyTest(test.TestCase):
|
||||
strategy.experimental_local_results(train_step(next(input_iterator))))
|
||||
|
||||
# TODO(b/145574622): Remove this test once it is re-enabled in values_test.py.
|
||||
def test_all_reduce_on_sync_on_read_variable(self):
|
||||
strategy = get_tpu_strategy()
|
||||
def test_all_reduce_on_sync_on_read_variable(self, enable_packed_var):
|
||||
strategy = get_tpu_strategy(enable_packed_var)
|
||||
dataset = dataset_ops.Dataset.range(
|
||||
strategy.num_replicas_in_sync, output_type=dtypes.float32).batch(
|
||||
strategy.num_replicas_in_sync, drop_remainder=True)
|
||||
@ -404,8 +424,8 @@ class TPUStrategyTest(test.TestCase):
|
||||
self.assertAllEqual((0.,), w.read_value())
|
||||
|
||||
# TODO(b/140633529): Re-enable the test.
|
||||
def disable_test_experimental_run_output_on_device(self):
|
||||
strategy = get_tpu_strategy()
|
||||
def disable_test_experimental_run_output_on_device(self, enable_packed_var):
|
||||
strategy = get_tpu_strategy(enable_packed_var)
|
||||
|
||||
def computation(x):
|
||||
return math_ops.square(x)
|
||||
@ -423,8 +443,8 @@ class TPUStrategyTest(test.TestCase):
|
||||
self.assertAllEqual("/job:localhost/replica:0/task:0/device:TPU:1",
|
||||
results[1].backing_device)
|
||||
|
||||
def test_composite_input(self):
|
||||
strategy = get_tpu_strategy()
|
||||
def test_composite_input(self, enable_packed_var):
|
||||
strategy = get_tpu_strategy(enable_packed_var)
|
||||
if strategy.num_replicas_in_sync != 2:
|
||||
self.skipTest("Test assumes two replicas.")
|
||||
|
||||
@ -463,8 +483,9 @@ class TPUStrategyTest(test.TestCase):
|
||||
self.assertAllEqual(result,
|
||||
[[[0.0, 1.0], [3.0, 8.0]], [[0.0, 1.0], [3.0, 8.0]]])
|
||||
|
||||
def test_composite_input_dynamic_shapes_outside_compilation(self):
|
||||
strategy = get_tpu_strategy()
|
||||
def test_composite_input_dynamic_shapes_outside_compilation(
|
||||
self, enable_packed_var):
|
||||
strategy = get_tpu_strategy(enable_packed_var)
|
||||
if strategy.num_replicas_in_sync != 2:
|
||||
self.skipTest("Test assumes two replicas.")
|
||||
|
||||
@ -506,11 +527,11 @@ class TPUStrategyTest(test.TestCase):
|
||||
result = sparse_lookup(dataset)
|
||||
self.assertAllEqual(result, [[0.0, 2.0], [1.5, 5.0]])
|
||||
|
||||
def test_per_device_tracing_of_mirrored_variables(self):
|
||||
def test_per_device_tracing_of_mirrored_variables(self, enable_packed_var):
|
||||
# Define trace_count as a list to avoid python scoping error
|
||||
trace_count = [0]
|
||||
|
||||
strategy = get_tpu_strategy()
|
||||
strategy = get_tpu_strategy(enable_packed_var)
|
||||
with strategy.scope():
|
||||
variable = variables.Variable(0.0)
|
||||
|
||||
@ -527,7 +548,10 @@ class TPUStrategyTest(test.TestCase):
|
||||
|
||||
with strategy.scope():
|
||||
update_variable.get_concrete_function()
|
||||
self.assertEqual(trace_count[0], len(strategy.extended.worker_devices))
|
||||
self.assertLen(strategy.extended.worker_devices, trace_count[0])
|
||||
|
||||
|
||||
class TPUStrategyDataPrefetchTest(test.TestCase):
|
||||
|
||||
def test_prefetch_to_device_default(self):
|
||||
strategy = get_tpu_strategy()
|
||||
|
@ -24,6 +24,7 @@ from __future__ import print_function
|
||||
|
||||
import contextlib
|
||||
|
||||
from tensorflow.python.distribute import packed_distributed_variable as packed
|
||||
from tensorflow.python.distribute import values
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import tape
|
||||
@ -46,15 +47,27 @@ def _maybe_enter_graph(tensor):
|
||||
yield
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _maybe_on_device(var):
|
||||
# Add a device scope for packed variables.
|
||||
if isinstance(var, packed.PackedVarAndDevice):
|
||||
with ops.device(var.device):
|
||||
yield
|
||||
else:
|
||||
yield
|
||||
|
||||
|
||||
def _make_raw_assign_fn(raw_assign_fn): # pylint: disable=missing-docstring
|
||||
|
||||
def assign_fn(var, value, use_locking=False, name=None, read_value=True): # pylint: disable=missing-docstring
|
||||
del use_locking # Unused.
|
||||
|
||||
with _maybe_enter_graph(var.handle):
|
||||
handle = var.handle
|
||||
with _maybe_enter_graph(handle), _maybe_on_device(var):
|
||||
op = raw_assign_fn(
|
||||
var.handle, ops.convert_to_tensor(value, dtype=var.dtype), name=name)
|
||||
|
||||
handle,
|
||||
ops.convert_to_tensor(value, dtype=var.dtype),
|
||||
name=name)
|
||||
with ops.control_dependencies([op]):
|
||||
return var._read_variable_op() if read_value else op # pylint: disable=protected-access
|
||||
|
||||
@ -97,23 +110,37 @@ class TPUVariableMixin(object):
|
||||
|
||||
@property
|
||||
def handle(self):
|
||||
"""The handle by which this variable can be accessed."""
|
||||
# If we're in a tpu.rewrite(), return the replicated handle.
|
||||
tpu_context = enclosing_tpu_context()
|
||||
if tpu_context is None or context.executing_eagerly():
|
||||
return self._get_on_device_or_primary().handle
|
||||
else:
|
||||
return tpu_context.get_replicated_var_handle(self._handle_id,
|
||||
self._values,
|
||||
self._is_mirrored())
|
||||
is_packed = self._packed_var is not None
|
||||
val = self._values
|
||||
if is_packed:
|
||||
val = [self._packed_var]
|
||||
|
||||
return tpu_context.get_replicated_var_handle(self._handle_id, val,
|
||||
self._is_mirrored(),
|
||||
is_packed)
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return self.handle.device
|
||||
|
||||
def _read_variable_op(self):
|
||||
"""Reads the value of this variable."""
|
||||
if self.trainable:
|
||||
tape.variable_accessed(self)
|
||||
return gen_resource_variable_ops.read_variable_op(self.handle, self.dtype)
|
||||
|
||||
handle = self.handle
|
||||
if getattr(handle, "is_packed", False):
|
||||
# Add a device scope for a packed variable handle.
|
||||
with ops.device(self._get_on_device_or_primary().device):
|
||||
return gen_resource_variable_ops.read_variable_op(handle, self.dtype)
|
||||
else:
|
||||
return gen_resource_variable_ops.read_variable_op(handle, self.dtype)
|
||||
|
||||
def read_value(self):
|
||||
if enclosing_tpu_context() is None:
|
||||
|
@ -472,6 +472,12 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable,
|
||||
# variable.
|
||||
self._var_policy = var_policy
|
||||
|
||||
@property
|
||||
def _devices(self):
|
||||
if self._packed_var is not None:
|
||||
return tuple(d for d in self._packed_var.devices)
|
||||
return tuple(v.device for v in self._values)
|
||||
|
||||
def is_initialized(self, name=None):
|
||||
"""Identifies if all the component variables are initialized.
|
||||
|
||||
@ -482,6 +488,8 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable,
|
||||
The op that evaluates to True or False depending on if all the
|
||||
component variables are initialized.
|
||||
"""
|
||||
if self._packed_var is not None:
|
||||
return self._packed_var.is_initialized()
|
||||
result = self._primary.is_initialized()
|
||||
# We iterate through the list of values except the last one to allow us to
|
||||
# name the final `logical_and` op the same name that is passed by the user
|
||||
@ -552,6 +560,10 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable,
|
||||
def aggregation(self):
|
||||
return self._aggregation
|
||||
|
||||
@property
|
||||
def _packed_variable(self):
|
||||
return self._packed_var
|
||||
|
||||
@property
|
||||
def handle(self):
|
||||
replica_id = values_util.get_current_replica_id_as_int()
|
||||
@ -559,6 +571,8 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable,
|
||||
raise ValueError("`handle` is not available outside the replica context"
|
||||
" or a `tf.distribute.Strategy.update()` call.")
|
||||
else:
|
||||
if self._packed_var is not None:
|
||||
return self._packed_var.handle
|
||||
return self._values[replica_id].handle
|
||||
|
||||
def eval(self, session=None):
|
||||
@ -607,6 +621,33 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable,
|
||||
def _in_graph_mode(self):
|
||||
return self._primary._in_graph_mode # pylint: disable=protected-access
|
||||
|
||||
def _get_replica(self, replica_id):
|
||||
"""Returns the value on a device with the given replica_id."""
|
||||
if self._packed_var is not None:
|
||||
return self._packed_var.on_device(self._devices[replica_id])
|
||||
return self._values[replica_id]
|
||||
|
||||
def _get(self):
|
||||
"""Returns the value for the current device or raises a ValueError."""
|
||||
replica_id = values_util.get_current_replica_id_as_int()
|
||||
if replica_id is None:
|
||||
return self._get_cross_replica()
|
||||
else:
|
||||
return self._get_replica(replica_id)
|
||||
|
||||
def _get_on_device_or_primary(self):
|
||||
"""Returns value in same replica or device if possible, else the _primary."""
|
||||
replica_id = values_util.get_current_replica_id_as_int()
|
||||
if replica_id is None:
|
||||
# Try to find a value on the current device.
|
||||
current_device = device_util.canonicalize(device_util.current())
|
||||
for i, value in enumerate(self._values):
|
||||
if device_util.canonicalize(value.device) == current_device:
|
||||
return self._get_replica(i)
|
||||
return self._get_replica(0)
|
||||
else:
|
||||
return self._get_replica(replica_id)
|
||||
|
||||
def read_value(self):
|
||||
with ds_context.enter_or_assert_strategy(self._distribute_strategy):
|
||||
return array_ops.identity(self._get())
|
||||
@ -778,7 +819,8 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable,
|
||||
if ds_context.in_cross_replica_context():
|
||||
update_replica_id = distribute_lib.get_update_replica_id()
|
||||
if update_replica_id is not None:
|
||||
return update_fn(self._values[update_replica_id], value, **kwargs)
|
||||
replica_value = self._get_replica(update_replica_id)
|
||||
return update_fn(replica_value, value, **kwargs)
|
||||
return self._update_cross_replica(update_fn, value, **kwargs)
|
||||
else:
|
||||
values_util.assert_replica_context(self.distribute_strategy)
|
||||
@ -802,6 +844,7 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable,
|
||||
obj_map[v] = new_obj
|
||||
resource_map[v.handle] = new_obj.handle
|
||||
obj_map[self] = new_obj
|
||||
resource_map[self.handle] = new_obj.handle
|
||||
resource_map[self] = new_obj.handle
|
||||
return obj_map, resource_map
|
||||
|
||||
@ -835,6 +878,12 @@ class _MirroredSaveable(saveable_object_util.ResourceVariableSaveable):
|
||||
def restore(self, restored_tensors, restored_shapes):
|
||||
"""Restore the same value into all variables."""
|
||||
tensor, = restored_tensors
|
||||
packed_var = self._mirrored_variable._packed_variable # pylint: disable=protected-access
|
||||
if packed_var is not None:
|
||||
return control_flow_ops.group(
|
||||
tuple(
|
||||
values_util.assign_on_device(d, packed_var, tensor)
|
||||
for d in packed_var.devices))
|
||||
return control_flow_ops.group(
|
||||
tuple(
|
||||
values_util.assign_on_device(v.device, v, tensor)
|
||||
@ -1013,7 +1062,7 @@ class SyncOnReadVariable(DistributedVariable):
|
||||
|
||||
def _get_cross_replica(self):
|
||||
if self._aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA:
|
||||
return self._primary
|
||||
return self._get_replica(0)
|
||||
|
||||
with ds_context.enter_or_assert_strategy(self._distribute_strategy):
|
||||
return self._distribute_strategy.reduce(
|
||||
|
@ -42,7 +42,6 @@ from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.eager import test
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import device
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import indexed_slices
|
||||
from tensorflow.python.framework import ops
|
||||
@ -234,11 +233,11 @@ class DistributedValuesTest(test.TestCase, parameterized.TestCase):
|
||||
distribution=[
|
||||
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
|
||||
strategy_combinations.tpu_strategy,
|
||||
strategy_combinations.tpu_strategy_packed_var,
|
||||
# TODO(b/137795644): support CentralStroageStrategy
|
||||
# strategy_combinations.central_storage_strategy_with_two_gpus,
|
||||
],
|
||||
mode=["eager"]
|
||||
))
|
||||
mode=["eager"]))
|
||||
def testMakeDistributedValueDefaultDevicePlacement(self, distribution):
|
||||
if not tf2.enabled():
|
||||
self.skipTest("Only V2 is supported.")
|
||||
@ -259,11 +258,11 @@ class DistributedValuesTest(test.TestCase, parameterized.TestCase):
|
||||
distribution=[
|
||||
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
|
||||
strategy_combinations.tpu_strategy,
|
||||
strategy_combinations.tpu_strategy_packed_var,
|
||||
# TODO(b/137795644): support CentralStroageStrategy
|
||||
# strategy_combinations.central_storage_strategy_with_two_gpus,
|
||||
],
|
||||
mode=["eager"]
|
||||
))
|
||||
mode=["eager"]))
|
||||
def testMakeDistributedValueExplicitDevicePlacement(self, distribution):
|
||||
if not tf2.enabled():
|
||||
self.skipTest("Only V2 is supported.")
|
||||
@ -384,6 +383,16 @@ def _make_mirrored():
|
||||
return mirrored
|
||||
|
||||
|
||||
def mirrored_and_tpu_strategy_combinations():
|
||||
return combinations.combine(
|
||||
distribution=[
|
||||
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
|
||||
strategy_combinations.tpu_strategy,
|
||||
strategy_combinations.tpu_strategy_packed_var,
|
||||
],
|
||||
mode=["graph", "eager"])
|
||||
|
||||
|
||||
class RegroupAndSelectDeviceTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
def _is_per_replica(self, result, expected, klass=values.PerReplica):
|
||||
@ -563,6 +572,7 @@ class RegroupAndSelectDeviceTest(test.TestCase, parameterized.TestCase):
|
||||
strategy_combinations.mirrored_strategy_with_one_cpu,
|
||||
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
|
||||
strategy_combinations.tpu_strategy,
|
||||
strategy_combinations.tpu_strategy_packed_var,
|
||||
strategy_combinations.central_storage_strategy_with_two_gpus,
|
||||
],
|
||||
synchronization=[
|
||||
@ -708,29 +718,40 @@ class DistributedVariableTest(test.TestCase, parameterized.TestCase):
|
||||
self.evaluate(
|
||||
distribution.experimental_local_results(distribution.run(assign)))
|
||||
|
||||
def testPackedVariable(self, distribution, synchronization, aggregation):
|
||||
|
||||
@combinations.generate(
|
||||
combinations.combine(
|
||||
distribution=[
|
||||
strategy_combinations.mirrored_strategy_with_one_cpu,
|
||||
strategy_combinations.tpu_strategy,
|
||||
],
|
||||
mode=["eager"]))
|
||||
class PackedDistributedVariableTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
def testPackedVariable(self, distribution):
|
||||
with distribution.scope():
|
||||
v0 = variables_lib.Variable(
|
||||
0., synchronization=synchronization, aggregation=aggregation)
|
||||
if not isinstance(v0, values.DistributedVariable):
|
||||
self.skipTest("This test doesn't apply to non DistributedVariables")
|
||||
|
||||
self.assertEqual(v0._packed_var, None)
|
||||
|
||||
device_type = device.DeviceSpec.from_string(v0._devices[0]).device_type
|
||||
for d in v0._devices:
|
||||
if device.DeviceSpec.from_string(d).device_type != device_type:
|
||||
self.skipTest("Packing variables on devices of different types "
|
||||
"is not supported yet.")
|
||||
v0 = variables_lib.Variable(0.)
|
||||
self.assertIsNone(v0._packed_var)
|
||||
|
||||
distribution._enable_packed_variable_in_eager_mode = True
|
||||
with distribution.scope():
|
||||
v1 = variables_lib.Variable(
|
||||
0., synchronization=synchronization, aggregation=aggregation)
|
||||
if ops.executing_eagerly_outside_functions():
|
||||
v1 = variables_lib.Variable(0)
|
||||
self.assertIsInstance(v1._packed_var, packed.PackedDistributedVariable)
|
||||
else:
|
||||
self.assertEqual(v1._packed_var, None)
|
||||
|
||||
devices = v1._devices
|
||||
for i in range(1, len(devices)):
|
||||
with distribute_lib.ReplicaContext(distribution, i):
|
||||
v1.assign(i)
|
||||
val = v1._get()
|
||||
self.assertIsInstance(val, packed.PackedVarAndDevice)
|
||||
self.assertEqual(val.device, devices[0])
|
||||
self.assertEqual(self.evaluate(val.read_value()), 0)
|
||||
for i in range(0, len(devices)):
|
||||
with distribute_lib.ReplicaContext(distribution, i):
|
||||
val = v1._get()
|
||||
self.assertIsInstance(val, packed.PackedVarAndDevice)
|
||||
self.assertEqual(val.device, devices[i])
|
||||
self.assertEqual(self.evaluate(val.read_value()), i)
|
||||
|
||||
|
||||
class MirroredVariableTest(test.TestCase, parameterized.TestCase):
|
||||
@ -920,6 +941,7 @@ class MirroredVariableTest(test.TestCase, parameterized.TestCase):
|
||||
distribution=[
|
||||
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
|
||||
strategy_combinations.tpu_strategy,
|
||||
strategy_combinations.tpu_strategy_packed_var,
|
||||
],
|
||||
mode=["eager"]))
|
||||
def testAssignValueInReplicaContextWithoutAggregation(self, distribution):
|
||||
@ -943,6 +965,7 @@ class MirroredVariableTest(test.TestCase, parameterized.TestCase):
|
||||
strategy_combinations.mirrored_strategy_with_one_cpu,
|
||||
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
|
||||
strategy_combinations.tpu_strategy,
|
||||
strategy_combinations.tpu_strategy_packed_var,
|
||||
],
|
||||
mode=["graph", "eager"]))
|
||||
def testValueInReplicaContext(self, distribution):
|
||||
@ -968,6 +991,7 @@ class MirroredVariableTest(test.TestCase, parameterized.TestCase):
|
||||
strategy_combinations.mirrored_strategy_with_one_cpu,
|
||||
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
|
||||
strategy_combinations.tpu_strategy,
|
||||
strategy_combinations.tpu_strategy_packed_var,
|
||||
],
|
||||
mode=["graph", "eager"]))
|
||||
def testAssignOutOfScope(self, distribution):
|
||||
@ -1041,6 +1065,7 @@ class MirroredVariableTest(test.TestCase, parameterized.TestCase):
|
||||
distribution=[
|
||||
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
|
||||
strategy_combinations.tpu_strategy,
|
||||
strategy_combinations.tpu_strategy_packed_var,
|
||||
],
|
||||
mode=["eager"]))
|
||||
def testInitializedToSameValueInsideEagerRun(self, distribution):
|
||||
@ -1066,6 +1091,7 @@ class MirroredVariableTest(test.TestCase, parameterized.TestCase):
|
||||
strategy_combinations.mirrored_strategy_with_one_cpu,
|
||||
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
|
||||
strategy_combinations.tpu_strategy,
|
||||
strategy_combinations.tpu_strategy_packed_var,
|
||||
],
|
||||
mode=["graph", "eager"]))
|
||||
def testAggregationOnlyFirstReplica(self, distribution):
|
||||
@ -1093,6 +1119,7 @@ class MirroredVariableTest(test.TestCase, parameterized.TestCase):
|
||||
distribution=[
|
||||
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
|
||||
strategy_combinations.tpu_strategy,
|
||||
strategy_combinations.tpu_strategy_packed_var,
|
||||
],
|
||||
mode=["eager"]))
|
||||
def testInitScope(self, distribution):
|
||||
@ -1143,13 +1170,7 @@ class MirroredVariableTest(test.TestCase, parameterized.TestCase):
|
||||
distribution.experimental_local_results(distribution.run(add)))
|
||||
self.assertAllEqual([2, 2], per_replica_results)
|
||||
|
||||
@combinations.generate(
|
||||
combinations.combine(
|
||||
distribution=[
|
||||
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
|
||||
strategy_combinations.tpu_strategy,
|
||||
],
|
||||
mode=["graph", "eager"]))
|
||||
@combinations.generate(mirrored_and_tpu_strategy_combinations())
|
||||
def testAssignAdd(self, distribution):
|
||||
with distribution.scope():
|
||||
v = variable_scope.variable(
|
||||
@ -1456,15 +1477,6 @@ class SyncOnReadVariablePropertiesTest(test.TestCase):
|
||||
self.assertEqual(2., self.evaluate(add1(replica_local)))
|
||||
|
||||
|
||||
def mirrored_and_tpu_strategy_combinations():
|
||||
return combinations.combine(
|
||||
distribution=[
|
||||
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
|
||||
strategy_combinations.tpu_strategy,
|
||||
],
|
||||
mode=["graph", "eager"])
|
||||
|
||||
|
||||
# TODO(b/144432582): Add variable aggregation type to combinations to simplify
|
||||
# tests.
|
||||
def strategy_and_run_tf_function_combinations():
|
||||
@ -1478,6 +1490,7 @@ def strategy_and_run_tf_function_combinations():
|
||||
experimental_run_tf_function=[True, False]) + combinations.combine(
|
||||
distribution=[
|
||||
strategy_combinations.tpu_strategy,
|
||||
strategy_combinations.tpu_strategy_packed_var,
|
||||
],
|
||||
mode=["graph", "eager"],
|
||||
experimental_run_tf_function=[True])
|
||||
|
@ -61,8 +61,14 @@ def on_write_assign_sub(var, value, use_locking=False, name=None,
|
||||
|
||||
|
||||
def assign_on_each_device(var, assign_func, value, read_value):
|
||||
update = control_flow_ops.group(
|
||||
tuple(assign_func(v.device, v, value) for v in var._values)) # pylint: disable=protected-access
|
||||
"""Update the variable on each replica with the given assign_func and value."""
|
||||
if var._packed_variable is not None: # pylint: disable=protected-access
|
||||
update = control_flow_ops.group(
|
||||
tuple(
|
||||
assign_func(d, var._packed_variable, value) for d in var._devices)) # pylint: disable=protected-access
|
||||
else:
|
||||
update = control_flow_ops.group(
|
||||
tuple(assign_func(v.device, v, value) for v in var._values)) # pylint: disable=protected-access
|
||||
if not read_value:
|
||||
return update
|
||||
with ops.control_dependencies([update] if update else []):
|
||||
@ -104,7 +110,7 @@ def on_read_assign_cross_replica(var, value, read_value=True):
|
||||
# TODO(anjs): Should this be over all the replicas in sync since we
|
||||
# call `reduce` on the variable during read?
|
||||
if var.aggregation == vs.VariableAggregation.SUM:
|
||||
tensor = math_ops.cast(tensor / len(var._values), var.dtype) # pylint: disable=protected-access
|
||||
tensor = math_ops.cast(tensor / len(var._devices), var.dtype) # pylint: disable=protected-access
|
||||
return assign_on_each_device(var, assign_on_device, tensor,
|
||||
read_value)
|
||||
|
||||
|
@ -298,7 +298,8 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
|
||||
self._pivot = pivot
|
||||
self._replicated_vars = {}
|
||||
|
||||
def get_replicated_var_handle(self, name, vars_, is_mirrored=False):
|
||||
def get_replicated_var_handle(self, name, vars_, is_mirrored=False,
|
||||
is_packed=False):
|
||||
"""Returns a variable handle for replicated TPU variable 'var'.
|
||||
|
||||
This is a method used by an experimental replicated variable implementation
|
||||
@ -309,6 +310,7 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
|
||||
vars_: The replicated TPU variables.
|
||||
is_mirrored: Whether the variables are mirrored, which guarantees the
|
||||
values in each replica are always the same.
|
||||
is_packed: Whether the replicated variables are packed into one variable.
|
||||
|
||||
Returns:
|
||||
The handle of the TPU replicated input node.
|
||||
@ -320,7 +322,7 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
|
||||
if handle is not None:
|
||||
return handle
|
||||
|
||||
if device_assignment is not None:
|
||||
if device_assignment is not None and not is_packed:
|
||||
# Find a variable copy for each replica in the device assignment.
|
||||
# Note that the order of devices for replicas for the variable and the
|
||||
# device assignment might not match.
|
||||
@ -356,7 +358,8 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
|
||||
graph._set_control_flow_context(self.outer_context)
|
||||
handle = tpu_ops.tpu_replicated_input([v.handle for v in replicated_vars],
|
||||
name=name + "/handle",
|
||||
is_mirrored_variable=is_mirrored)
|
||||
is_mirrored_variable=is_mirrored,
|
||||
is_packed=is_packed)
|
||||
graph._set_control_flow_context(saved_context)
|
||||
# pylint: enable=protected-access
|
||||
self._replicated_vars[name] = handle
|
||||
|
Loading…
Reference in New Issue
Block a user