Support packed variables in DistributedVariable. Add an option to enable packed variables in TPUStrategy.

PiperOrigin-RevId: 317234665
Change-Id: I09e806cb8261815cd87a6d98817556dd8f7e8ed7

Parent: 4d54ef3139
Commit: 7e6e549c46
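The option added here is a private attribute on the strategy object rather than a constructor argument. A minimal sketch of how it is toggled, following the helpers used by the tests in this change (it assumes a reachable TPU worker; the attribute `_enable_packed_variable_in_eager_mode` is internal and the setup calls mirror the test utilities below):

from tensorflow.python.distribute import tpu_strategy as tpu_lib
from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver
from tensorflow.python.eager import remote
from tensorflow.python.tpu import tpu_strategy_util

# Connect to the TPU system the same way the tests below do; an empty "tpu"
# argument targets a locally attached TPU in these tests.
resolver = tpu_cluster_resolver.TPUClusterResolver(tpu="")
remote.connect_to_cluster(resolver)
tpu_strategy_util.initialize_tpu_system(resolver)

strategy = tpu_lib.TPUStrategy(resolver)
# Opt in to packed variables: DistributedVariables created under this
# strategy's scope then capture a single packed handle per function graph
# instead of one handle per replica.
strategy._enable_packed_variable_in_eager_mode = True  # pylint: disable=protected-access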
@@ -654,6 +654,7 @@ tpu_py_test(
         "//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib",
         "//tensorflow/python/eager:remote",
         "//tensorflow/python/eager:test",
+        "@absl_py//absl/testing:parameterized",
     ],
 )
 
@@ -787,6 +788,7 @@ py_library(
     name = "tpu_values",
     srcs = ["tpu_values.py"],
     deps = [
+        ":packed_distributed_variable",
         ":values",
         "//tensorflow/python:framework_ops",
         "//tensorflow/python:math_ops",
@@ -1602,7 +1604,7 @@ distribute_py_test(
     srcs = ["saved_model_save_load_test.py"],
     full_precision = True,
     main = "saved_model_save_load_test.py",
-    shard_count = 5,
+    shard_count = 7,
     tags = [
         "multi_and_single_gpu",
         "no_rocm",
@@ -1635,7 +1637,7 @@ distribute_py_test(
     srcs = ["saved_model_mixed_api_test.py"],
     full_precision = True,
     main = "saved_model_mixed_api_test.py",
-    shard_count = 5,
+    shard_count = 7,
     tags = [
         "multi_and_single_gpu",
         "no_rocm",
@@ -103,6 +103,7 @@ class TrainingCheckpointTests(test.TestCase, parameterized.TestCase):
               strategy_combinations.mirrored_strategy_with_one_cpu,
               strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
               strategy_combinations.tpu_strategy,
+              strategy_combinations.tpu_strategy_packed_var,
               strategy_combinations.central_storage_strategy_with_two_gpus,
           ],
           mode=["eager"]))
@@ -138,6 +139,7 @@ class TrainingCheckpointTests(test.TestCase, parameterized.TestCase):
               strategy_combinations.mirrored_strategy_with_one_cpu,
               strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
               strategy_combinations.tpu_strategy,
+              strategy_combinations.tpu_strategy_packed_var,
               strategy_combinations.central_storage_strategy_with_two_gpus,
           ],
           mode=["eager"]))
@@ -197,7 +197,8 @@ class InputIterationTest(test.TestCase, parameterized.TestCase,
       combinations.combine(
           distribution=[
               strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
-              strategy_combinations.tpu_strategy
+              strategy_combinations.tpu_strategy,
+              strategy_combinations.tpu_strategy_packed_var,
           ],
           mode=["eager"]))
   def testNestedOutput(self, distribution):
@@ -748,6 +749,10 @@ class InputIterationTest(test.TestCase, parameterized.TestCase,
           mode=["eager"]
       ))
   def testMultiDeviceDataCapturedFunction(self, distribution):
+    if getattr(distribution, "_enable_packed_variable_in_eager_mode", False):
+      self.skipTest(
+          "Dataset captured function doesn't support packed tensors yet "
+          "(b/145922293).")
     inputs = constant_op.constant([2., 3.])
     dataset = lambda _: dataset_ops.Dataset.from_tensor_slices(inputs).repeat(5)
     input_iterator = iter(
@@ -148,6 +148,9 @@ def select_replica_mirrored(replica_id, structured):
         raise TypeError(
             "Expected value to be mirrored across replicas: %s in %s." %
            (x, structured))
+      packed_var = getattr(x, "_packed_variable", None)
+      if packed_var is not None:
+        return packed_var
      return x.values[replica_id]
    else:
      return x
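The effect of the hunk above is that, when a mirrored variable carries a packed companion, replica selection hands the packed variable (one handle shared by all devices) to the replica function instead of the per-replica component, so a traced function captures a single input. Restated as a stand-alone helper (the name `_select_for_replica` is hypothetical, used only for illustration):

def _select_for_replica(replica_id, mirrored_value):
  # If the DistributedVariable carries a packed variable, return it so the
  # function graph captures one packed handle for all replicas; otherwise
  # fall back to the per-replica component, as before this change.
  packed_var = getattr(mirrored_value, "_packed_variable", None)
  if packed_var is not None:
    return packed_var
  return mirrored_value.values[replica_id]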
@@ -42,7 +42,7 @@ class PackedDistributedVariable(resource_variable_ops.BaseResourceVariable):
       name: Optional name for the variable. Defaults to `'Variable'` and gets
         uniquified automatically.
     """
-    if not context.executing_eagerly():
+    if not ops.executing_eagerly_outside_functions():
       raise ValueError(
           "PackedDistributedVariable should be created in eager mode.")
     if not distributed_variables:
@@ -84,6 +84,9 @@ class PackedDistributedVariable(resource_variable_ops.BaseResourceVariable):
   def devices(self):
     return self._devices
 
+  def on_device(self, device):
+    return PackedVarAndDevice(self, device)
+
   def get_var_on_device(self, device):
     for i, d in enumerate(self._devices):
       if d == device:
@@ -100,6 +103,9 @@ class PackedDistributedVariable(resource_variable_ops.BaseResourceVariable):
 
   @property
   def handle(self):
+    if context.executing_eagerly():
+      return self.get_var_on_current_device().handle
+    else:
       return self._handle
 
   def _read_variable_op(self):
@@ -269,6 +275,7 @@ class PackedVarAndDevice(object):
 
   @property
   def handle(self):
+    with ops.device(self._device):
       return self._var.handle
 
   @property
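Taken together with the test hunks that follow, the new pieces can be exercised directly. A minimal sketch, assuming eager mode and two logical CPU devices (the test file configures extra logical devices for this; the device strings here are illustrative):

from tensorflow.python.distribute import packed_distributed_variable
from tensorflow.python.eager import def_function
from tensorflow.python.framework import ops
from tensorflow.python.ops import resource_variable_ops

# One component variable per device, packed into a single parallel handle.
with ops.device('/cpu:0'):
  v0 = resource_variable_ops.ResourceVariable(1.0, name='var0')
with ops.device('/cpu:1'):
  v1 = resource_variable_ops.ResourceVariable(2.0, name='var1')
packed_var = packed_distributed_variable.PackedDistributedVariable([v0, v1])

# The new on_device() helper returns a PackedVarAndDevice pinned to one of the
# packed devices, so reads and writes through it target that component.
var_on_cpu1 = packed_var.on_device(packed_var.devices[1])

# Eagerly, `handle` now resolves to the component on the current device, so it
# is not packed; inside a tf.function the packed handle is used. This is what
# the flipped assertFalse/assertTrue assertions in the test below check.
assert not packed_var.handle.is_packed

@def_function.function
def read_on_cpu0():
  # Inside a traced function the packed handle (one handle for all devices) is
  # captured; a device scope picks which component the read targets.
  assert packed_var.handle.is_packed
  with ops.device('/cpu:0'):
    return packed_var.value()

read_on_cpu0()  # Expected to return 1.0 here, the /cpu:0 component.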
@@ -46,7 +46,7 @@ class PackedDistributedVariableTest(test.TestCase):
       v1 = resource_variable_ops.ResourceVariable(2.0, name='var1')
 
     packed_var = packed_distributed_variable.PackedDistributedVariable([v0, v1])
-    self.assertTrue(packed_var.handle.is_packed)
+    self.assertFalse(packed_var.handle.is_packed)
     self.assertTrue(packed_var.is_initialized)
 
     with ops.device('/cpu:0'):
@@ -61,6 +61,7 @@ class PackedDistributedVariableTest(test.TestCase):
 
     @def_function.function
     def update_var():
+      self.assertTrue(packed_var.handle.is_packed)
      with ops.device('/cpu:0'):
         packed_var.assign_add(3.0).assign_sub(1.0)
         read0 = packed_var.value()
@@ -85,7 +86,7 @@ class PackedDistributedVariableTest(test.TestCase):
 
     packed_var0 = packed_distributed_variable.PackedVarAndDevice(
         packed_var, device0)
-    self.assertTrue(packed_var0.handle.is_packed)
+    self.assertFalse(packed_var0.handle.is_packed)
     self.assertAllEqual(math_ops.mul(packed_var0, 2.0), 2.0)
 
     packed_var1 = packed_distributed_variable.PackedVarAndDevice(
@@ -94,6 +95,7 @@ class PackedDistributedVariableTest(test.TestCase):
 
     @def_function.function
     def func():
+      self.assertTrue(packed_var.handle.is_packed)
       var0 = packed_distributed_variable.PackedVarAndDevice(packed_var, device0)
       var0.assign_add(3.0)
       var1 = packed_distributed_variable.PackedVarAndDevice(packed_var, device1)
@@ -58,6 +58,7 @@ strategies = [
     strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
     strategy_combinations.mirrored_strategy_with_two_gpus,
     strategy_combinations.tpu_strategy,
+    strategy_combinations.tpu_strategy_packed_var,
     strategy_combinations.central_storage_strategy_with_two_gpus,
 ]
 
@@ -53,7 +53,11 @@ _did_connect_to_cluster = False
 
 
 # pylint: disable=missing-docstring
-def _get_tpu_strategy_creator(steps_per_run, use_single_core=False, **kwargs):
+def _get_tpu_strategy_creator(steps_per_run,
+                              use_single_core=False,
+                              enable_packed_variable=False,
+                              **kwargs):
 
   def _create_tpu_strategy():
     global _did_connect_to_cluster
 
@@ -87,10 +91,13 @@ def _get_tpu_strategy_creator(steps_per_run, use_single_core=False, **kwargs):
 
     # Steps per run is only supported in TF 1.x
     if tf2.enabled():
-      return tpu_lib.TPUStrategy(resolver, device_assignment, **kwargs)
+      strategy = tpu_lib.TPUStrategy(resolver, device_assignment, **kwargs)
     else:
-      return tpu_lib.TPUStrategyV1(resolver, steps_per_run,
+      strategy = tpu_lib.TPUStrategyV1(resolver, steps_per_run,
                                        device_assignment, **kwargs)
+    strategy._enable_packed_variable_in_eager_mode = enable_packed_variable  # pylint: disable=protected-access
+    return strategy
 
   return _create_tpu_strategy
 
 
@@ -117,6 +124,10 @@ one_device_strategy_gpu_on_worker_1 = combinations.NamedDistribution(
     required_gpus=1)
 tpu_strategy = combinations.NamedDistribution(
     "TPU", _get_tpu_strategy_creator(steps_per_run=2), required_tpu=True)
+tpu_strategy_packed_var = combinations.NamedDistribution(
+    "TPUPackedVar",
+    _get_tpu_strategy_creator(steps_per_run=2, enable_packed_variable=True),
+    required_tpu=True)
 tpu_strategy_one_step = combinations.NamedDistribution(
     "TPUOneStep", _get_tpu_strategy_creator(steps_per_run=1), required_tpu=True)
 tpu_strategy_one_core = combinations.NamedDistribution(
@@ -286,6 +297,7 @@ strategies_minus_default_and_tpu = [
 tpu_strategies = [
     tpu_strategy,  # steps_per_run=2
     tpu_strategy_one_step,
+    tpu_strategy_packed_var,
     cloud_tpu_strategy,
 ]
 
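Downstream test files opt into the new combination exactly like the existing `tpu_strategy` entry; the checkpointing and input-iteration hunks earlier in this change follow this shape. A small sketch of a hypothetical test using it (the test class and method names here are illustrative, not part of the change):

from absl.testing import parameterized

from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.eager import test


class PackedVarSmokeTest(test.TestCase, parameterized.TestCase):

  @combinations.generate(
      combinations.combine(
          distribution=[
              strategy_combinations.tpu_strategy,
              strategy_combinations.tpu_strategy_packed_var,
          ],
          mode=["eager"]))
  def test_runs_with_and_without_packing(self, distribution):
    # The same test body runs once per named distribution; the packed-variable
    # flavor differs only in how variable handles are captured by functions.
    self.assertGreaterEqual(distribution.num_replicas_in_sync, 1)


if __name__ == "__main__":
  test.main()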
@@ -141,6 +141,10 @@ class TPUStrategy(distribute_lib.Strategy):
         "num_workers").set(self.extended.num_hosts)
     distribute_lib.distribution_strategy_replica_gauge.get_cell(
         "num_replicas_per_worker").set(self.extended.num_replicas_per_host)
+    # Packed variable is used to reduce the overhead of function execution.
+    # For a DistributedVariable, only one variable handle is captured into a
+    # function graph. It's only supported in eager mode.
+    self._enable_packed_variable_in_eager_mode = False
 
   # TODO(cjfj): Modify `_call_for_each_replica` in `TPUExtended` such that this
   # can use the default implementation.
@@ -185,6 +189,10 @@ class TPUStrategyV1(distribute_lib.StrategyV1):
         "num_workers").set(self.extended.num_hosts)
     distribute_lib.distribution_strategy_replica_gauge.get_cell(
         "num_replicas_per_worker").set(self.extended.num_replicas_per_host)
+    # Packed variable is used to reduce the overhead of function execution.
+    # For a DistributedVariable, only one variable handle is captured into a
+    # function graph. It's only supported in eager mode.
+    self._enable_packed_variable_in_eager_mode = False
 
   @property
   def steps_per_run(self):
@@ -671,20 +679,29 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
       return cross_device_ops_lib.reduce_non_distributed_value(
           reduce_op, value, destinations, self._num_replicas_in_sync)
 
+    value_list = value.values
+    # pylint: disable=protected-access
+    if isinstance(
+        value,
+        values.DistributedVariable) and value._packed_variable is not None:
+      value_list = tuple(
+          value._packed_variable.on_device(d)
+          for d in value._packed_variable.devices)
+    # pylint: enable=protected-access
+
     # Currently XLA op by op mode has a limit for the number of inputs for a
     # single op, thus we break one `add_n` op into a group of `add_n` ops to
     # work around the constraint.
     # TODO(cjfj): Detect when it is possible to use `cross_replica_sum`.
     if len(value.values) <= _XLA_OP_BY_OP_INPUTS_LIMIT:
-      output = math_ops.add_n(value.values)
+      output = math_ops.add_n(value_list)
     else:
-      output = array_ops.zeros_like(
-          value.values[0], dtype=value.values[0].dtype)
-      for i in range(0, len(value.values), _XLA_OP_BY_OP_INPUTS_LIMIT):
-        output += math_ops.add_n(value.values[i:i + _XLA_OP_BY_OP_INPUTS_LIMIT])
+      output = array_ops.zeros_like(value_list[0], dtype=value_list[0].dtype)
+      for i in range(0, len(value_list), _XLA_OP_BY_OP_INPUTS_LIMIT):
+        output += math_ops.add_n(value_list[i:i + _XLA_OP_BY_OP_INPUTS_LIMIT])
 
     if reduce_op == reduce_util.ReduceOp.MEAN:
-      output *= (1. / len(value.values))
+      output *= (1. / len(value_list))
 
     devices = cross_device_ops_lib.get_devices_from(destinations)
 
@@ -710,17 +727,28 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
     else:
       return (fn(var, *args, **kwargs),)
 
-    # Otherwise, we revert to MirroredStrategy behavior and update each variable
-    # directly.
+    # Otherwise, we revert to MirroredStrategy behavior and update the variable
+    # on each replica directly.
     updates = []
-    for i, v in enumerate(var.values):
+    values_and_devices = []
+    packed_var = var._packed_variable  # pylint: disable=protected-access
+    if packed_var is not None:
+      for device in packed_var.devices:
+        values_and_devices.append((packed_var, device))
+    else:
+      for value in var.values:
+        values_and_devices.append((value, value.device))
+
+    for i, value_and_device in enumerate(values_and_devices):
+      value = value_and_device[0]
+      device = value_and_device[1]
       name = "update_%d" % i
-      with ops.device(v.device), \
+      with ops.device(device), \
            distribute_lib.UpdateContext(i), \
            ops.name_scope(name):
         # If args and kwargs are not mirrored, the value is returned as is.
         updates.append(
-            fn(v, *distribute_utils.select_replica_mirrored(i, args),
+            fn(value, *distribute_utils.select_replica_mirrored(i, args),
               **distribute_utils.select_replica_mirrored(i, kwargs)))
     return distribute_utils.update_regroup(self, updates, group)
 
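The `value_list` rewrite above leaves the chunked-sum workaround intact; only the inputs change (packed per-device views instead of per-replica components). The chunking itself can be read in isolation as follows. This is a sketch: the limit value below is illustrative, the real constant is `_XLA_OP_BY_OP_INPUTS_LIMIT` in this file.

from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops

_INPUTS_LIMIT = 10  # illustrative stand-in for _XLA_OP_BY_OP_INPUTS_LIMIT


def chunked_add_n(tensors):
  # XLA op-by-op mode caps the number of inputs to a single op, so a long list
  # is summed in fixed-size groups and the partial sums are accumulated.
  if len(tensors) <= _INPUTS_LIMIT:
    return math_ops.add_n(tensors)
  output = array_ops.zeros_like(tensors[0], dtype=tensors[0].dtype)
  for i in range(0, len(tensors), _INPUTS_LIMIT):
    output += math_ops.add_n(tensors[i:i + _INPUTS_LIMIT])
  return output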
@@ -18,6 +18,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+from absl.testing import parameterized
+
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.distribute import distribute_lib
 from tensorflow.python.distribute import distribution_strategy_context
@@ -64,14 +66,17 @@ def get_tpu_cluster_resolver():
   return resolver
 
 
-def get_tpu_strategy():
+def get_tpu_strategy(enable_packed_var=False):
   resolver = get_tpu_cluster_resolver()
   remote.connect_to_cluster(resolver)
   tpu_strategy_util.initialize_tpu_system(resolver)
-  return tpu_lib.TPUStrategy(resolver)
+  strategy = tpu_lib.TPUStrategy(resolver)
+  strategy._enable_packed_variable_in_eager_mode = enable_packed_var
+  return strategy
 
 
-class TPUStrategyTest(test.TestCase):
+# TPU tests which don't use TPUStrategy.
+class TPUTest(test.TestCase):
 
   def test_multiple_initialize_system(self):
     resolver = get_tpu_cluster_resolver()
@@ -82,177 +87,6 @@ class TPUStrategyTest(test.TestCase):
       tpu_strategy_util.initialize_tpu_system(resolver)
       self.assertRegex(str(mock_log.call_args), "already been initialized")
 
-  def test_sequential_experimental_runs(self):
-    resolver = get_tpu_cluster_resolver()
-    remote.connect_to_cluster(resolver)
-    topology = tpu_strategy_util.initialize_tpu_system(resolver)
-    # Computation replicated to all cores.
-    device_assignment = device_assignment_lib.DeviceAssignment.build(
-        topology, num_replicas=2)
-    strategy = tpu_lib.TPUStrategy(
-        resolver, device_assignment=device_assignment)
-
-    # Computation on the 1st core.
-    device_assignment2 = device_assignment_lib.DeviceAssignment.build(
-        topology, num_replicas=1)
-    strategy2 = tpu_lib.TPUStrategy(
-        resolver, device_assignment=device_assignment2)
-
-    def computation(x):
-      return math_ops.square(x)
-
-    @def_function.function
-    def train_step():
-      outputs = strategy.experimental_local_results(
-          strategy.run(computation, args=([2., 2.],)))
-      outputs2 = strategy2.run(
-          computation, args=([outputs[0]],))
-      return outputs2
-
-    self.assertAllEqual([[16., 16.]], train_step())
-
-  def test_device_switch_case(self):
-    strategy = get_tpu_strategy()
-    with strategy.scope():
-      a = variables.Variable(1)
-
-    inference_iteration = variables.Variable(-1)
-
-    def inference_fn(x, i):
-      return a + x + i
-
-    @def_function.function
-    def run_inference(x):
-
-      def do_inference(device, inference_fn, i):
-        with ops.device(device):
-          return inference_fn(x, i)
-
-      branch_fns = {
-          0: (lambda: do_inference("/device:TPU:0", inference_fn, 0)),
-          1: (lambda: do_inference("/device:TPU:1", inference_fn, 1)),
-      }
-      branch_index = inference_iteration.assign_add(1, use_locking=True) % 2
-      return control_flow_ops.switch_case(branch_index, branch_fns)
-
-    self.assertAllEqual(2., run_inference(1))  # Use TPU core 0.
-    self.assertAllEqual(3., run_inference(1))  # Use TPU core 1.
-
-  def test_recover_from_compilation_failures(self):
-    # TODO(b/148150981): Stop skipping this test once recovery works
-    # for non-local TPU.
-    if FLAGS.tpu:
-      self.skipTest("Recovery fails for non-local TPU, see b/148150981")
-
-    # Disable automatic outside compilation.
-    config.set_soft_device_placement(False)
-    strategy = get_tpu_strategy()
-
-    @def_function.function
-    def compilation_failure_run():
-
-      def computation():
-        return random_ops.random_gamma([10], [0.5, 1.5])
-
-      return strategy.run(computation)
-
-    with self.assertRaisesRegexp(errors.InvalidArgumentError,
-                                 "TPU compilation failed"):
-      compilation_failure_run()
-
-    @def_function.function
-    def good_run():
-
-      def computation():
-        return random_ops.random_normal([10])
-
-      return strategy.run(computation)
-
-    good_run()
-
-  def test_dynamic_shape_with_outside_compilation_failure(self):
-    # Enable automatic outside compilation.
-    config.set_soft_device_placement(True)
-    strategy = get_tpu_strategy()
-    dataset = dataset_ops.Dataset.from_tensors(("string", 1.0)).repeat().batch(
-        2, drop_remainder=False)
-    dataset = strategy.experimental_distribute_dataset(dataset)
-    iterator = iter(dataset)
-
-    @def_function.function
-    def train_fn(iterator):
-
-      def step_fn(inputs):
-        _, inputs = inputs
-        return math_ops.reduce_sum(inputs)
-
-      return strategy.experimental_local_results(
-          strategy.run(step_fn, args=(next(iterator),)))
-
-    with self.assertRaisesRegex(errors.InternalError, "Compilation failure"):
-      logging.info(train_fn(iterator))
-
-  def test_computation_on_subset_cores(self):
-    resolver = get_tpu_cluster_resolver()
-    remote.connect_to_cluster(resolver)
-    topology = tpu_strategy_util.initialize_tpu_system(resolver)
-    all_core_strategy = tpu_lib.TPUStrategy(resolver)
-
-    with all_core_strategy.scope():
-      v = variables.Variable(0.0,
-                             aggregation=variables.VariableAggregation.MEAN)
-
-    # Computation on the 1st core.
-    device_assignment = device_assignment_lib.DeviceAssignment.build(
-        topology, num_replicas=1)
-    first_core_strategy = tpu_lib.TPUStrategy(
-        resolver, device_assignment=device_assignment)
-
-    # Computation on the 2nd core.
-    device_assignment2 = device_assignment_lib.DeviceAssignment(
-        topology, [[[0, 0, 0, 1]]])
-    second_core_strategy = tpu_lib.TPUStrategy(
-        resolver, device_assignment=device_assignment2)
-
-    @def_function.function
-    def train_step():
-
-      def step_fn():
-        return v + 1.0
-
-      all_core_strategy.run(step_fn)
-      r1 = first_core_strategy.run(step_fn)
-      r2 = second_core_strategy.run(step_fn)
-      return r1 + r2
-
-    train_step()
-    self.assertAllEqual(2., train_step())
-
-  def test_worker_devices_on_subset_cores(self):
-    resolver = get_tpu_cluster_resolver()
-    remote.connect_to_cluster(resolver)
-    topology = tpu_strategy_util.initialize_tpu_system(resolver)
-
-    # Strategy for the 1st core.
-    device_assignment = device_assignment_lib.DeviceAssignment.build(
-        topology, num_replicas=1)
-    first_core_strategy = tpu_lib.TPUStrategy(
-        resolver, device_assignment=device_assignment)
-
-    # Strategy for the 2nd core.
-    device_assignment2 = device_assignment_lib.DeviceAssignment(
-        topology, [[[0, 0, 0, 1]]])
-    second_core_strategy = tpu_lib.TPUStrategy(
-        resolver, device_assignment=device_assignment2)
-
-    self.assertLen(first_core_strategy.extended.worker_devices, 1)
-    self.assertEndsWith(first_core_strategy.extended.worker_devices[0],
-                        "device:TPU:0")
-
-    self.assertLen(second_core_strategy.extended.worker_devices, 1)
-    self.assertEndsWith(second_core_strategy.extended.worker_devices[0],
-                        "device:TPU:1")
-
   def test_tpu_tf_function_same_device(self):
     with ops.device("/device:TPU:0"):
       a = variables.Variable(1)
@@ -288,8 +122,194 @@ class TPUStrategyTest(test.TestCase):
     result = bar() + 1
     self.assertAllEqual(result, 2)
 
-  def test_control_output_in_while_body_fn(self):
-    strategy = get_tpu_strategy()
+
+@parameterized.named_parameters([("PackedVar", True), ("", False)])
+class TPUStrategyTest(test.TestCase, parameterized.TestCase):
+
+  def test_sequential_experimental_runs(self, enable_packed_var):
+    resolver = get_tpu_cluster_resolver()
+    remote.connect_to_cluster(resolver)
+    topology = tpu_strategy_util.initialize_tpu_system(resolver)
+    # Computation replicated to all cores.
+    device_assignment = device_assignment_lib.DeviceAssignment.build(
+        topology, num_replicas=2)
+    strategy = tpu_lib.TPUStrategy(
+        resolver, device_assignment=device_assignment)
+    strategy._enable_packed_variable_in_eager_mode = enable_packed_var
+
+    # Computation on the 1st core.
+    device_assignment2 = device_assignment_lib.DeviceAssignment.build(
+        topology, num_replicas=1)
+    strategy2 = tpu_lib.TPUStrategy(
+        resolver, device_assignment=device_assignment2)
+
+    def computation(x):
+      return math_ops.square(x)
+
+    @def_function.function
+    def train_step():
+      outputs = strategy.experimental_local_results(
+          strategy.run(computation, args=([2., 2.],)))
+      outputs2 = strategy2.run(
+          computation, args=([outputs[0]],))
+      return outputs2
+
+    self.assertAllEqual([[16., 16.]], train_step())
+
+  def test_device_switch_case(self, enable_packed_var):
+    strategy = get_tpu_strategy(enable_packed_var)
+    with strategy.scope():
+      a = variables.Variable(1)
+
+    inference_iteration = variables.Variable(-1)
+
+    def inference_fn(x, i):
+      return a + x + i
+
+    @def_function.function
+    def run_inference(x):
+
+      def do_inference(device, inference_fn, i):
+        with ops.device(device):
+          return inference_fn(x, i)
+
+      branch_fns = {
+          0: (lambda: do_inference("/device:TPU:0", inference_fn, 0)),
+          1: (lambda: do_inference("/device:TPU:1", inference_fn, 1)),
+      }
+      branch_index = inference_iteration.assign_add(1, use_locking=True) % 2
+      return control_flow_ops.switch_case(branch_index, branch_fns)
+
+    self.assertAllEqual(2., run_inference(1))  # Use TPU core 0.
+    self.assertAllEqual(3., run_inference(1))  # Use TPU core 1.
+
+  def test_recover_from_compilation_failures(self, enable_packed_var):
+    # TODO(b/148150981): Stop skipping this test once recovery works
+    # for non-local TPU.
+    if FLAGS.tpu:
+      self.skipTest("Recovery fails for non-local TPU, see b/148150981")
+
+    # Disable automatic outside compilation.
+    config.set_soft_device_placement(False)
+    strategy = get_tpu_strategy(enable_packed_var)
+
+    @def_function.function
+    def compilation_failure_run():
+
+      def computation():
+        return random_ops.random_gamma([10], [0.5, 1.5])
+
+      return strategy.run(computation)
+
+    with self.assertRaisesRegex(errors.InvalidArgumentError,
+                                "TPU compilation failed"):
+      compilation_failure_run()
+
+    @def_function.function
+    def good_run():
+
+      def computation():
+        return random_ops.random_normal([10])
+
+      return strategy.run(computation)
+
+    good_run()
+
+  def test_dynamic_shape_with_outside_compilation_failure(
+      self, enable_packed_var):
+    # Enable automatic outside compilation.
+    config.set_soft_device_placement(True)
+    strategy = get_tpu_strategy(enable_packed_var)
+    dataset = dataset_ops.Dataset.from_tensors(("string", 1.0)).repeat().batch(
+        2, drop_remainder=False)
+    dataset = strategy.experimental_distribute_dataset(dataset)
+    iterator = iter(dataset)
+
+    @def_function.function
+    def train_fn(iterator):
+
+      def step_fn(inputs):
+        _, inputs = inputs
+        return math_ops.reduce_sum(inputs)
+
+      return strategy.experimental_local_results(
+          strategy.run(step_fn, args=(next(iterator),)))
+
+    with self.assertRaisesRegex(errors.InternalError, "Compilation failure"):
+      logging.info(train_fn(iterator))
+
+  def test_computation_on_subset_cores(self, enable_packed_var):
+    resolver = get_tpu_cluster_resolver()
+    remote.connect_to_cluster(resolver)
+    topology = tpu_strategy_util.initialize_tpu_system(resolver)
+    all_core_strategy = tpu_lib.TPUStrategy(resolver)
+    all_core_strategy._enable_packed_variable_in_eager_mode = enable_packed_var
+
+    with all_core_strategy.scope():
+      v = variables.Variable(0.0,
+                             aggregation=variables.VariableAggregation.MEAN)
+
+    # Computation on the 1st core.
+    device_assignment = device_assignment_lib.DeviceAssignment.build(
+        topology, num_replicas=1)
+    first_core_strategy = tpu_lib.TPUStrategy(
+        resolver, device_assignment=device_assignment)
+    first_core_strategy._enable_packed_variable_in_eager_mode = (
+        enable_packed_var)
+
+    # Computation on the 2nd core.
+    device_assignment2 = device_assignment_lib.DeviceAssignment(
+        topology, [[[0, 0, 0, 1]]])
+    second_core_strategy = tpu_lib.TPUStrategy(
+        resolver, device_assignment=device_assignment2)
+    second_core_strategy._enable_packed_variable_in_eager_mode = (
+        enable_packed_var)
+
+    @def_function.function
+    def train_step():
+
+      def step_fn():
+        return v + 1.0
+
+      all_core_strategy.run(step_fn)
+      r1 = first_core_strategy.run(step_fn)
+      r2 = second_core_strategy.run(step_fn)
+      return r1 + r2
+
+    train_step()
+    self.assertAllEqual(2., train_step())
+
+  def test_worker_devices_on_subset_cores(self, enable_packed_var):
+    resolver = get_tpu_cluster_resolver()
+    remote.connect_to_cluster(resolver)
+    topology = tpu_strategy_util.initialize_tpu_system(resolver)
+
+    # Strategy for the 1st core.
+    device_assignment = device_assignment_lib.DeviceAssignment.build(
+        topology, num_replicas=1)
+    first_core_strategy = tpu_lib.TPUStrategy(
+        resolver, device_assignment=device_assignment)
+    first_core_strategy._enable_packed_variable_in_eager_mode = (
+        enable_packed_var)
+
+    # Strategy for the 2nd core.
+    device_assignment2 = device_assignment_lib.DeviceAssignment(
+        topology, [[[0, 0, 0, 1]]])
+    second_core_strategy = tpu_lib.TPUStrategy(
+        resolver, device_assignment=device_assignment2)
+    second_core_strategy._enable_packed_variable_in_eager_mode = (
+        enable_packed_var)
+
+    self.assertLen(first_core_strategy.extended.worker_devices, 1)
+    self.assertEndsWith(first_core_strategy.extended.worker_devices[0],
+                        "device:TPU:0")
+
+    self.assertLen(second_core_strategy.extended.worker_devices, 1)
+    self.assertEndsWith(second_core_strategy.extended.worker_devices[0],
+                        "device:TPU:1")
+
+  def test_control_output_in_while_body_fn(self, enable_packed_var):
+    strategy = get_tpu_strategy(enable_packed_var)
 
     with strategy.scope():
       v = variables.Variable(
@@ -307,8 +327,8 @@ class TPUStrategyTest(test.TestCase):
     train_step()
     self.assertEqual(2.0, v.numpy())
 
-  def test_cluster_in_graph_and_while_body_fn(self):
-    strategy = get_tpu_strategy()
+  def test_cluster_in_graph_and_while_body_fn(self, enable_packed_var):
+    strategy = get_tpu_strategy(enable_packed_var)
 
     @def_function.function
     def train_step():
@@ -328,8 +348,8 @@ class TPUStrategyTest(test.TestCase):
     sum_val = train_step().numpy().astype(float)
     self.assertEqual(sum_val, strategy.num_replicas_in_sync * 10)
 
-  def test_two_clusters_with_same_fn(self):
-    strategy = get_tpu_strategy()
+  def test_two_clusters_with_same_fn(self, enable_packed_var):
+    strategy = get_tpu_strategy(enable_packed_var)
 
     @def_function.function
     def foo(x):
@@ -342,8 +362,8 @@ class TPUStrategyTest(test.TestCase):
 
     bar(1)
 
-  def test_using_external_variable_inside_tf_function(self):
-    strategy = get_tpu_strategy()
+  def test_using_external_variable_inside_tf_function(self, enable_packed_var):
+    strategy = get_tpu_strategy(enable_packed_var)
     dataset = dataset_ops.Dataset.range(
         strategy.num_replicas_in_sync * 2,
         output_type=dtypes.float32).batch(strategy.num_replicas_in_sync)
@@ -364,8 +384,8 @@ class TPUStrategyTest(test.TestCase):
         strategy.experimental_local_results(train_step(next(input_iterator))))
 
   # TODO(b/145574622): Remove this test once it is re-enabled in values_test.py.
-  def test_all_reduce_on_sync_on_read_variable(self):
-    strategy = get_tpu_strategy()
+  def test_all_reduce_on_sync_on_read_variable(self, enable_packed_var):
+    strategy = get_tpu_strategy(enable_packed_var)
     dataset = dataset_ops.Dataset.range(
         strategy.num_replicas_in_sync, output_type=dtypes.float32).batch(
             strategy.num_replicas_in_sync, drop_remainder=True)
@@ -404,8 +424,8 @@ class TPUStrategyTest(test.TestCase):
     self.assertAllEqual((0.,), w.read_value())
 
   # TODO(b/140633529): Re-enable the test.
-  def disable_test_experimental_run_output_on_device(self):
-    strategy = get_tpu_strategy()
+  def disable_test_experimental_run_output_on_device(self, enable_packed_var):
+    strategy = get_tpu_strategy(enable_packed_var)
 
     def computation(x):
       return math_ops.square(x)
@@ -423,8 +443,8 @@ class TPUStrategyTest(test.TestCase):
     self.assertAllEqual("/job:localhost/replica:0/task:0/device:TPU:1",
                         results[1].backing_device)
 
-  def test_composite_input(self):
-    strategy = get_tpu_strategy()
+  def test_composite_input(self, enable_packed_var):
+    strategy = get_tpu_strategy(enable_packed_var)
     if strategy.num_replicas_in_sync != 2:
       self.skipTest("Test assumes two replicas.")
 
@@ -463,8 +483,9 @@ class TPUStrategyTest(test.TestCase):
     self.assertAllEqual(result,
                         [[[0.0, 1.0], [3.0, 8.0]], [[0.0, 1.0], [3.0, 8.0]]])
 
-  def test_composite_input_dynamic_shapes_outside_compilation(self):
-    strategy = get_tpu_strategy()
+  def test_composite_input_dynamic_shapes_outside_compilation(
+      self, enable_packed_var):
+    strategy = get_tpu_strategy(enable_packed_var)
     if strategy.num_replicas_in_sync != 2:
       self.skipTest("Test assumes two replicas.")
 
@@ -506,11 +527,11 @@ class TPUStrategyTest(test.TestCase):
     result = sparse_lookup(dataset)
     self.assertAllEqual(result, [[0.0, 2.0], [1.5, 5.0]])
 
-  def test_per_device_tracing_of_mirrored_variables(self):
+  def test_per_device_tracing_of_mirrored_variables(self, enable_packed_var):
     # Define trace_count as a list to avoid python scoping error
     trace_count = [0]
 
-    strategy = get_tpu_strategy()
+    strategy = get_tpu_strategy(enable_packed_var)
     with strategy.scope():
       variable = variables.Variable(0.0)
 
@@ -527,7 +548,10 @@ class TPUStrategyTest(test.TestCase):
 
     with strategy.scope():
       update_variable.get_concrete_function()
-      self.assertEqual(trace_count[0], len(strategy.extended.worker_devices))
+      self.assertLen(strategy.extended.worker_devices, trace_count[0])
 
 
+class TPUStrategyDataPrefetchTest(test.TestCase):
+
   def test_prefetch_to_device_default(self):
     strategy = get_tpu_strategy()
@@ -24,6 +24,7 @@ from __future__ import print_function
 
 import contextlib
 
+from tensorflow.python.distribute import packed_distributed_variable as packed
 from tensorflow.python.distribute import values
 from tensorflow.python.eager import context
 from tensorflow.python.eager import tape
@@ -46,15 +47,27 @@ def _maybe_enter_graph(tensor):
     yield
 
 
+@contextlib.contextmanager
+def _maybe_on_device(var):
+  # Add a device scope for packed variables.
+  if isinstance(var, packed.PackedVarAndDevice):
+    with ops.device(var.device):
+      yield
+  else:
+    yield
+
+
 def _make_raw_assign_fn(raw_assign_fn):  # pylint: disable=missing-docstring
 
   def assign_fn(var, value, use_locking=False, name=None, read_value=True):  # pylint: disable=missing-docstring
     del use_locking  # Unused.
 
-    with _maybe_enter_graph(var.handle):
+    handle = var.handle
+    with _maybe_enter_graph(handle), _maybe_on_device(var):
       op = raw_assign_fn(
-          var.handle, ops.convert_to_tensor(value, dtype=var.dtype), name=name)
+          handle,
+          ops.convert_to_tensor(value, dtype=var.dtype),
+          name=name)
       with ops.control_dependencies([op]):
         return var._read_variable_op() if read_value else op  # pylint: disable=protected-access
 
@@ -97,23 +110,37 @@ class TPUVariableMixin(object):
 
   @property
   def handle(self):
+    """The handle by which this variable can be accessed."""
     # If we're in a tpu.rewrite(), return the replicated handle.
     tpu_context = enclosing_tpu_context()
     if tpu_context is None or context.executing_eagerly():
       return self._get_on_device_or_primary().handle
     else:
-      return tpu_context.get_replicated_var_handle(self._handle_id,
-                                                   self._values,
-                                                   self._is_mirrored())
+      is_packed = self._packed_var is not None
+      val = self._values
+      if is_packed:
+        val = [self._packed_var]
+
+      return tpu_context.get_replicated_var_handle(self._handle_id, val,
                                                    self._is_mirrored(),
+                                                   is_packed)
 
   @property
   def device(self):
     return self.handle.device
 
   def _read_variable_op(self):
+    """Reads the value of this variable."""
     if self.trainable:
       tape.variable_accessed(self)
-    return gen_resource_variable_ops.read_variable_op(self.handle, self.dtype)
+    handle = self.handle
+    if getattr(handle, "is_packed", False):
+      # Add a device scope for a packed variable handle.
+      with ops.device(self._get_on_device_or_primary().device):
+        return gen_resource_variable_ops.read_variable_op(handle, self.dtype)
+    else:
+      return gen_resource_variable_ops.read_variable_op(handle, self.dtype)
 
   def read_value(self):
     if enclosing_tpu_context() is None:
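For orientation, the factory patched above is consumed elsewhere in this file roughly as follows (a sketch of the pre-existing wiring, not something added by this change), so the new `_maybe_on_device` scope applies to every assign flavor of a TPU variable:

from tensorflow.python.ops import gen_resource_variable_ops


def assign_sub(var, value, use_locking=False, name=None, read_value=True):
  # Wrap the raw resource op once; _make_raw_assign_fn adds the graph scope
  # and, after this change, the device scope for PackedVarAndDevice inputs.
  return _make_raw_assign_fn(gen_resource_variable_ops.assign_sub_variable_op)(
      var, value, use_locking=use_locking, name=name, read_value=read_value)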
@@ -472,6 +472,12 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable,
     # variable.
     self._var_policy = var_policy
 
+  @property
+  def _devices(self):
+    if self._packed_var is not None:
+      return tuple(d for d in self._packed_var.devices)
+    return tuple(v.device for v in self._values)
+
   def is_initialized(self, name=None):
     """Identifies if all the component variables are initialized.
 
@@ -482,6 +488,8 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable,
       The op that evaluates to True or False depending on if all the
       component variables are initialized.
     """
+    if self._packed_var is not None:
+      return self._packed_var.is_initialized()
     result = self._primary.is_initialized()
     # We iterate through the list of values except the last one to allow us to
     # name the final `logical_and` op the same name that is passed by the user
@@ -552,6 +560,10 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable,
   def aggregation(self):
     return self._aggregation
 
+  @property
+  def _packed_variable(self):
+    return self._packed_var
+
   @property
   def handle(self):
     replica_id = values_util.get_current_replica_id_as_int()
@@ -559,6 +571,8 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable,
       raise ValueError("`handle` is not available outside the replica context"
                        " or a `tf.distribute.Strategy.update()` call.")
     else:
+      if self._packed_var is not None:
+        return self._packed_var.handle
       return self._values[replica_id].handle
 
   def eval(self, session=None):
@@ -607,6 +621,33 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable,
   def _in_graph_mode(self):
     return self._primary._in_graph_mode  # pylint: disable=protected-access
 
+  def _get_replica(self, replica_id):
+    """Returns the value on a device with the given replica_id."""
+    if self._packed_var is not None:
+      return self._packed_var.on_device(self._devices[replica_id])
+    return self._values[replica_id]
+
+  def _get(self):
+    """Returns the value for the current device or raises a ValueError."""
+    replica_id = values_util.get_current_replica_id_as_int()
+    if replica_id is None:
+      return self._get_cross_replica()
+    else:
+      return self._get_replica(replica_id)
+
+  def _get_on_device_or_primary(self):
+    """Returns value in same replica or device if possible, else the _primary."""
+    replica_id = values_util.get_current_replica_id_as_int()
+    if replica_id is None:
+      # Try to find a value on the current device.
+      current_device = device_util.canonicalize(device_util.current())
+      for i, value in enumerate(self._values):
+        if device_util.canonicalize(value.device) == current_device:
+          return self._get_replica(i)
+      return self._get_replica(0)
+    else:
+      return self._get_replica(replica_id)
+
   def read_value(self):
     with ds_context.enter_or_assert_strategy(self._distribute_strategy):
       return array_ops.identity(self._get())
@@ -778,7 +819,8 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable,
     if ds_context.in_cross_replica_context():
       update_replica_id = distribute_lib.get_update_replica_id()
       if update_replica_id is not None:
-        return update_fn(self._values[update_replica_id], value, **kwargs)
+        replica_value = self._get_replica(update_replica_id)
+        return update_fn(replica_value, value, **kwargs)
       return self._update_cross_replica(update_fn, value, **kwargs)
     else:
       values_util.assert_replica_context(self.distribute_strategy)
@@ -802,6 +844,7 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable,
       obj_map[v] = new_obj
       resource_map[v.handle] = new_obj.handle
     obj_map[self] = new_obj
+    resource_map[self.handle] = new_obj.handle
     resource_map[self] = new_obj.handle
     return obj_map, resource_map
 
@@ -835,6 +878,12 @@ class _MirroredSaveable(saveable_object_util.ResourceVariableSaveable):
   def restore(self, restored_tensors, restored_shapes):
     """Restore the same value into all variables."""
     tensor, = restored_tensors
+    packed_var = self._mirrored_variable._packed_variable  # pylint: disable=protected-access
+    if packed_var is not None:
+      return control_flow_ops.group(
+          tuple(
+              values_util.assign_on_device(d, packed_var, tensor)
+              for d in packed_var.devices))
     return control_flow_ops.group(
         tuple(
             values_util.assign_on_device(v.device, v, tensor)
@@ -1013,7 +1062,7 @@ class SyncOnReadVariable(DistributedVariable):
 
   def _get_cross_replica(self):
     if self._aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA:
-      return self._primary
+      return self._get_replica(0)
 
     with ds_context.enter_or_assert_strategy(self._distribute_strategy):
      return self._distribute_strategy.reduce(
|
|||||||
from tensorflow.python.eager import def_function
|
from tensorflow.python.eager import def_function
|
||||||
from tensorflow.python.eager import test
|
from tensorflow.python.eager import test
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import device
|
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import indexed_slices
|
from tensorflow.python.framework import indexed_slices
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
@ -234,11 +233,11 @@ class DistributedValuesTest(test.TestCase, parameterized.TestCase):
|
|||||||
distribution=[
|
distribution=[
|
||||||
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
|
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
|
||||||
strategy_combinations.tpu_strategy,
|
strategy_combinations.tpu_strategy,
|
||||||
|
strategy_combinations.tpu_strategy_packed_var,
|
||||||
# TODO(b/137795644): support CentralStroageStrategy
|
# TODO(b/137795644): support CentralStroageStrategy
|
||||||
# strategy_combinations.central_storage_strategy_with_two_gpus,
|
# strategy_combinations.central_storage_strategy_with_two_gpus,
|
||||||
],
|
],
|
||||||
mode=["eager"]
|
mode=["eager"]))
|
||||||
))
|
|
||||||
def testMakeDistributedValueDefaultDevicePlacement(self, distribution):
|
def testMakeDistributedValueDefaultDevicePlacement(self, distribution):
|
||||||
if not tf2.enabled():
|
if not tf2.enabled():
|
||||||
self.skipTest("Only V2 is supported.")
|
self.skipTest("Only V2 is supported.")
|
||||||
@@ -259,11 +258,11 @@ class DistributedValuesTest(test.TestCase, parameterized.TestCase):
           distribution=[
               strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
               strategy_combinations.tpu_strategy,
+              strategy_combinations.tpu_strategy_packed_var,
               # TODO(b/137795644): support CentralStroageStrategy
               # strategy_combinations.central_storage_strategy_with_two_gpus,
           ],
-          mode=["eager"]
-      ))
+          mode=["eager"]))
   def testMakeDistributedValueExplicitDevicePlacement(self, distribution):
     if not tf2.enabled():
       self.skipTest("Only V2 is supported.")
@@ -384,6 +383,16 @@ def _make_mirrored():
   return mirrored
 
 
+def mirrored_and_tpu_strategy_combinations():
+  return combinations.combine(
+      distribution=[
+          strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
+          strategy_combinations.tpu_strategy,
+          strategy_combinations.tpu_strategy_packed_var,
+      ],
+      mode=["graph", "eager"])
+
+
 class RegroupAndSelectDeviceTest(test.TestCase, parameterized.TestCase):
 
   def _is_per_replica(self, result, expected, klass=values.PerReplica):
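The new helper centralizes the (strategy, mode) grid that later tests pick up via @combinations.generate(mirrored_and_tpu_strategy_combinations()). A rough, framework-free sketch of how such a grid expands; the strategy names are abbreviations, not the real strategy_combinations objects:

strategies = ["mirrored_gpu_and_cpu", "tpu", "tpu_packed_var"]
modes = ["graph", "eager"]
grid = [(s, m) for s in strategies for m in modes]
assert len(grid) == 6  # each decorated test method runs once per pair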
@@ -563,6 +572,7 @@ class RegroupAndSelectDeviceTest(test.TestCase, parameterized.TestCase):
           strategy_combinations.mirrored_strategy_with_one_cpu,
           strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
           strategy_combinations.tpu_strategy,
+          strategy_combinations.tpu_strategy_packed_var,
           strategy_combinations.central_storage_strategy_with_two_gpus,
       ],
       synchronization=[
@@ -708,29 +718,40 @@ class DistributedVariableTest(test.TestCase, parameterized.TestCase):
     self.evaluate(
         distribution.experimental_local_results(distribution.run(assign)))
 
-  def testPackedVariable(self, distribution, synchronization, aggregation):
-    with distribution.scope():
-      v0 = variables_lib.Variable(
-          0., synchronization=synchronization, aggregation=aggregation)
-    if not isinstance(v0, values.DistributedVariable):
-      self.skipTest("This test doesn't apply to non DistributedVariables")
-
-    self.assertEqual(v0._packed_var, None)
-
-    device_type = device.DeviceSpec.from_string(v0._devices[0]).device_type
-    for d in v0._devices:
-      if device.DeviceSpec.from_string(d).device_type != device_type:
-        self.skipTest("Packing variables on devices of different types "
-                      "is not supported yet.")
+
+@combinations.generate(
+    combinations.combine(
+        distribution=[
+            strategy_combinations.mirrored_strategy_with_one_cpu,
+            strategy_combinations.tpu_strategy,
+        ],
+        mode=["eager"]))
+class PackedDistributedVariableTest(test.TestCase, parameterized.TestCase):
+
+  def testPackedVariable(self, distribution):
+    with distribution.scope():
+      v0 = variables_lib.Variable(0.)
+    self.assertIsNone(v0._packed_var)
 
     distribution._enable_packed_variable_in_eager_mode = True
     with distribution.scope():
-      v1 = variables_lib.Variable(
-          0., synchronization=synchronization, aggregation=aggregation)
-    if ops.executing_eagerly_outside_functions():
+      v1 = variables_lib.Variable(0)
       self.assertIsInstance(v1._packed_var, packed.PackedDistributedVariable)
-    else:
-      self.assertEqual(v1._packed_var, None)
+
+    devices = v1._devices
+    for i in range(1, len(devices)):
+      with distribute_lib.ReplicaContext(distribution, i):
+        v1.assign(i)
+    val = v1._get()
+    self.assertIsInstance(val, packed.PackedVarAndDevice)
+    self.assertEqual(val.device, devices[0])
+    self.assertEqual(self.evaluate(val.read_value()), 0)
+    for i in range(0, len(devices)):
+      with distribute_lib.ReplicaContext(distribution, i):
+        val = v1._get()
+        self.assertIsInstance(val, packed.PackedVarAndDevice)
+        self.assertEqual(val.device, devices[i])
+        self.assertEqual(self.evaluate(val.read_value()), i)
 
 
 class MirroredVariableTest(test.TestCase, parameterized.TestCase):
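The new test reads and writes a packed variable through per-device views (packed.PackedVarAndDevice). A toy, framework-free analogy of that layout, purely illustrative and not the TensorFlow classes: one container owns all per-device values, and a lightweight view pins reads and writes to a single device slot.

class ToyPackedVariable:
  def __init__(self, devices, initial_value):
    self.devices = list(devices)
    self._values = {d: initial_value for d in self.devices}

  def on_device(self, device):
    return ToyPackedVarAndDevice(self, device)


class ToyPackedVarAndDevice:
  def __init__(self, packed, device):
    self._packed, self.device = packed, device

  def assign(self, value):
    self._packed._values[self.device] = value

  def read_value(self):
    return self._packed._values[self.device]


toy = ToyPackedVariable(["/device:CPU:0", "/device:CPU:1"], 0)
toy.on_device("/device:CPU:1").assign(1)
assert toy.on_device("/device:CPU:0").read_value() == 0
assert toy.on_device("/device:CPU:1").read_value() == 1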
@@ -920,6 +941,7 @@ class MirroredVariableTest(test.TestCase, parameterized.TestCase):
           distribution=[
               strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
               strategy_combinations.tpu_strategy,
+              strategy_combinations.tpu_strategy_packed_var,
           ],
           mode=["eager"]))
   def testAssignValueInReplicaContextWithoutAggregation(self, distribution):
@@ -943,6 +965,7 @@ class MirroredVariableTest(test.TestCase, parameterized.TestCase):
               strategy_combinations.mirrored_strategy_with_one_cpu,
               strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
               strategy_combinations.tpu_strategy,
+              strategy_combinations.tpu_strategy_packed_var,
           ],
           mode=["graph", "eager"]))
   def testValueInReplicaContext(self, distribution):
@@ -968,6 +991,7 @@ class MirroredVariableTest(test.TestCase, parameterized.TestCase):
               strategy_combinations.mirrored_strategy_with_one_cpu,
               strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
               strategy_combinations.tpu_strategy,
+              strategy_combinations.tpu_strategy_packed_var,
           ],
           mode=["graph", "eager"]))
   def testAssignOutOfScope(self, distribution):
@@ -1041,6 +1065,7 @@ class MirroredVariableTest(test.TestCase, parameterized.TestCase):
           distribution=[
               strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
               strategy_combinations.tpu_strategy,
+              strategy_combinations.tpu_strategy_packed_var,
           ],
           mode=["eager"]))
   def testInitializedToSameValueInsideEagerRun(self, distribution):
@@ -1066,6 +1091,7 @@ class MirroredVariableTest(test.TestCase, parameterized.TestCase):
               strategy_combinations.mirrored_strategy_with_one_cpu,
               strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
               strategy_combinations.tpu_strategy,
+              strategy_combinations.tpu_strategy_packed_var,
           ],
           mode=["graph", "eager"]))
   def testAggregationOnlyFirstReplica(self, distribution):
@@ -1093,6 +1119,7 @@ class MirroredVariableTest(test.TestCase, parameterized.TestCase):
           distribution=[
               strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
               strategy_combinations.tpu_strategy,
+              strategy_combinations.tpu_strategy_packed_var,
           ],
           mode=["eager"]))
   def testInitScope(self, distribution):
@@ -1143,13 +1170,7 @@ class MirroredVariableTest(test.TestCase, parameterized.TestCase):
         distribution.experimental_local_results(distribution.run(add)))
     self.assertAllEqual([2, 2], per_replica_results)
 
-  @combinations.generate(
-      combinations.combine(
-          distribution=[
-              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
-              strategy_combinations.tpu_strategy,
-          ],
-          mode=["graph", "eager"]))
+  @combinations.generate(mirrored_and_tpu_strategy_combinations())
   def testAssignAdd(self, distribution):
     with distribution.scope():
       v = variable_scope.variable(
@@ -1456,15 +1477,6 @@ class SyncOnReadVariablePropertiesTest(test.TestCase):
     self.assertEqual(2., self.evaluate(add1(replica_local)))
 
 
-def mirrored_and_tpu_strategy_combinations():
-  return combinations.combine(
-      distribution=[
-          strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
-          strategy_combinations.tpu_strategy,
-      ],
-      mode=["graph", "eager"])
-
-
 # TODO(b/144432582): Add variable aggregation type to combinations to simplify
 # tests.
 def strategy_and_run_tf_function_combinations():
@@ -1478,6 +1490,7 @@ def strategy_and_run_tf_function_combinations():
       experimental_run_tf_function=[True, False]) + combinations.combine(
           distribution=[
              strategy_combinations.tpu_strategy,
+             strategy_combinations.tpu_strategy_packed_var,
          ],
          mode=["graph", "eager"],
          experimental_run_tf_function=[True])
@@ -61,6 +61,12 @@ def on_write_assign_sub(var, value, use_locking=False, name=None,
 
 
 def assign_on_each_device(var, assign_func, value, read_value):
-  update = control_flow_ops.group(
-      tuple(assign_func(v.device, v, value) for v in var._values))  # pylint: disable=protected-access
+  """Update the variable on each replica with the given assign_func and value."""
+  if var._packed_variable is not None:  # pylint: disable=protected-access
+    update = control_flow_ops.group(
+        tuple(
+            assign_func(d, var._packed_variable, value) for d in var._devices))  # pylint: disable=protected-access
+  else:
+    update = control_flow_ops.group(
+        tuple(assign_func(v.device, v, value) for v in var._values))  # pylint: disable=protected-access
   if not read_value:
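assign_on_each_device now branches on whether a packed handle exists. A hedged sketch of that dispatch, using plain tf.Variable objects and CPU devices as stand-ins rather than the values_util implementation:

import tensorflow as tf

def _pinned_assign(device, var, value):
  # Plays the role of values_util.assign_on_device: run the update on `device`.
  with tf.device(device):
    return var.assign(value)

def assign_on_each_device_sketch(devices, components, packed_var, value):
  # Packed path: one shared handle, updated once per device.
  if packed_var is not None:
    return tf.group(*[_pinned_assign(d, packed_var, value) for d in devices])
  # Unpacked path: each component already lives on its own device.
  return tf.group(*[_pinned_assign(v.device, v, value) for v in components])

cpu = "/device:CPU:0"
v0, v1 = tf.Variable(0.0), tf.Variable(0.0)
assign_on_each_device_sketch([cpu, cpu], [v0, v1], None, 2.0)
assert v0.numpy() == v1.numpy() == 2.0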
@@ -104,7 +110,7 @@ def on_read_assign_cross_replica(var, value, read_value=True):
       # TODO(anjs): Should this be over all the replicas in sync since we
       # call `reduce` on the variable during read?
       if var.aggregation == vs.VariableAggregation.SUM:
-        tensor = math_ops.cast(tensor / len(var._values), var.dtype)  # pylint: disable=protected-access
+        tensor = math_ops.cast(tensor / len(var._devices), var.dtype)  # pylint: disable=protected-access
       return assign_on_each_device(var, assign_on_device, tensor,
                                    read_value)
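The divisor switches from len(var._values) to len(var._devices), presumably because the device count remains the well-defined quantity once a packed variable may back the per-replica components. The arithmetic being preserved is simple:

num_devices = 2
x = 6.0
stored = x / num_devices                 # what each replica stores on assign
assert sum([stored] * num_devices) == x  # a summing cross-replica read gets x back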
@@ -298,7 +298,8 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
     self._pivot = pivot
     self._replicated_vars = {}
 
-  def get_replicated_var_handle(self, name, vars_, is_mirrored=False):
+  def get_replicated_var_handle(self, name, vars_, is_mirrored=False,
+                                is_packed=False):
     """Returns a variable handle for replicated TPU variable 'var'.
 
     This is a method used by an experimental replicated variable implementation
@@ -309,6 +310,7 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
       vars_: The replicated TPU variables.
       is_mirrored: Whether the variables are mirrored, which guarantees the
         values in each replica are always the same.
+      is_packed: Whether the replicated variables are packed into one variable.
 
     Returns:
       The handle of the TPU replicated input node.
@@ -320,7 +322,7 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
     if handle is not None:
       return handle
 
-    if device_assignment is not None:
+    if device_assignment is not None and not is_packed:
       # Find a variable copy for each replica in the device assignment.
       # Note that the order of devices for replicas for the variable and the
       # device assignment might not match.
@@ -356,7 +358,8 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
       graph._set_control_flow_context(self.outer_context)
       handle = tpu_ops.tpu_replicated_input([v.handle for v in replicated_vars],
                                             name=name + "/handle",
-                                            is_mirrored_variable=is_mirrored)
+                                            is_mirrored_variable=is_mirrored,
+                                            is_packed=is_packed)
       graph._set_control_flow_context(saved_context)
       # pylint: enable=protected-access
       self._replicated_vars[name] = handle
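With is_packed set, the device-assignment matching earlier in this method is skipped because a packed variable has no per-replica copies to choose between. A framework-free sketch of that control flow, with illustrative names only:

def select_vars_for_replicas(vars_by_device, replica_devices, is_packed):
  if replica_devices is not None and not is_packed:
    # Unpacked: pick, for each replica, the variable copy on that replica's
    # device (stand-in for the real device-assignment matching).
    return [vars_by_device[d] for d in replica_devices]
  # Packed (or no device assignment): forward the variables unchanged.
  if isinstance(vars_by_device, dict):
    return list(vars_by_device.values())
  return list(vars_by_device)

unpacked = {"/device:TPU:0": "copy0", "/device:TPU:1": "copy1"}
assert select_vars_for_replicas(
    unpacked, ["/device:TPU:1", "/device:TPU:0"], is_packed=False) == ["copy1", "copy0"]
assert select_vars_for_replicas(
    ["packed_handle"], ["/device:TPU:0"], is_packed=True) == ["packed_handle"]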