STT-tensorflow/tensorflow/python/distribute/tpu_strategy_test.py
Revan Sopher b38e885409 Preserve TPUDistributedVariables passed to TPUStrategy.run().
Normally run() converts all inputs to Tensor, which uses the current value of a variable.
Disabling the cast doesn't work, as a lower MLIR op will cast the variable to Tensor anyway.
Instead we partially apply any TPUDistributedVariables to the function before replicating the remaining inputs.

This doesn't enable any new functionality per se, as it was always possible to implicitly capture the variables in the Python function, but this matches behavior to MirroredStrategy and avoids a simple usability issue.

Note that this change of behavior might break anyone who was relying on the automatic cast happening, in which case the user-side fix would be to manually cast.

PiperOrigin-RevId: 351712553
Change-Id: If5b88e3d79007a273aa067d11e438025bd9f5ab5
2021-01-13 19:22:09 -08:00

1219 lines
42 KiB
Python

# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for TPUStrategy."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from absl.testing import parameterized
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import strategy_test_lib
from tensorflow.python.distribute import tpu_strategy as tpu_lib
from tensorflow.python.distribute import tpu_values
from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver
from tensorflow.python.eager import def_function
from tensorflow.python.eager import function
from tensorflow.python.eager import remote
from tensorflow.python.eager import test
from tensorflow.python.framework import config
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_spec
from tensorflow.python.module import module
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.platform import flags
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.tpu import device_assignment as device_assignment_lib
from tensorflow.python.tpu import tpu
from tensorflow.python.tpu import tpu_strategy_util
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import server_lib
from tensorflow.python.training.tracking import util
from tensorflow.python.util import nest
# Command-line flags selecting which TPU to run the tests against.
# An empty --tpu targets the local TPU; --project/--zone locate a Cloud TPU.
FLAGS = flags.FLAGS
flags.DEFINE_string("tpu", "", "Name of TPU to connect to.")
flags.DEFINE_string("project", None, "Name of GCP project with TPU.")
flags.DEFINE_string("zone", None, "Name of GCP zone with TPU.")
def get_tpu_cluster_resolver():
  """Builds a TPUClusterResolver from the --tpu/--zone/--project flags."""
  return tpu_cluster_resolver.TPUClusterResolver(
      tpu=FLAGS.tpu, zone=FLAGS.zone, project=FLAGS.project)
def get_tpu_strategy(enable_packed_var=False):
  """Connects to and initializes the TPU system, returning a TPUStrategyV2.

  Args:
    enable_packed_var: whether to enable packed variables in eager mode on
      the returned strategy (set via the private
      `_enable_packed_variable_in_eager_mode` attribute).

  Returns:
    A `tpu_lib.TPUStrategyV2` bound to the resolved TPU cluster.
  """
  cluster_resolver = get_tpu_cluster_resolver()
  remote.connect_to_cluster(cluster_resolver)
  tpu_strategy_util.initialize_tpu_system(cluster_resolver)
  tpu_strategy = tpu_lib.TPUStrategyV2(cluster_resolver)
  tpu_strategy._enable_packed_variable_in_eager_mode = enable_packed_var
  return tpu_strategy
# TPU tests which don't use TPUStrategy.
class TPUTest(test.TestCase):
  """TPU tests that place ops and functions on TPU devices directly,
  without going through TPUStrategy."""

  def test_multiple_initialize_system(self):
    """A second initialize_tpu_system call warns instead of failing."""
    resolver = get_tpu_cluster_resolver()
    remote.connect_to_cluster(resolver)
    tpu_strategy_util.initialize_tpu_system(resolver)

    with test.mock.patch.object(logging, "warning") as mock_log:
      tpu_strategy_util.initialize_tpu_system(resolver)
      self.assertRegex(str(mock_log.call_args), "already been initialized")

  def test_tpu_tf_function_same_device(self):
    """A non-inlined defun (_noinline) nested inside a TPU-placed
    tf.function still produces the right value."""
    with ops.device("/device:TPU:0"):
      a = variables.Variable(1)

    @function.defun_with_attributes(attributes={"_noinline": True})
    def get_a_plus_one():
      return a + 1

    @def_function.function(
        input_signature=[tensor_spec.TensorSpec([], dtypes.int32)])
    def foo(x):
      with ops.device("/device:TPU:0"):
        b = x + get_a_plus_one()
      return b + 1

    result = foo(a)
    self.assertAllEqual(4, result)

  def test_tpu_return_int32(self):
    """An int32 computed on TPU can be consumed by a CPU-placed op."""
    with ops.device("/device:TPU:0"):
      a = variables.Variable(0)

    @def_function.function
    def foo():
      return a + 1

    @def_function.function
    def bar():
      with ops.device("/device:TPU:1"):
        return foo()

    with ops.device("/device:CPU:0"):
      result = bar() + 1
      self.assertAllEqual(result, 2)

  def test_tpu_output_device(self):
    """The _OutputsOnOpDevice attribute keeps outputs on the op's device
    (TPU); without it outputs land on the host CPU."""

    def foo():
      return 1 + 1

    func1 = function.defun_with_attributes(
        foo, attributes={"_XlaMustCompile": False})
    func2 = function.defun_with_attributes(
        foo, attributes={
            "_OutputsOnOpDevice": True,
            "_XlaMustCompile": False
        })

    with ops.device("/device:TPU:0"):
      ret1 = func1()
      ret2 = func2()

    self.assertAllEqual(ret1.backing_device,
                        "/job:localhost/replica:0/task:0/device:CPU:0")
    self.assertAllEqual(ret2.backing_device,
                        "/job:localhost/replica:0/task:0/device:TPU:0")

  def test_on_demand_op_with_dynamic_output(self):
    """On-demand (eager) TPU ops whose output shape is data-dependent
    (where, repeat) return correct values."""
    with ops.device("/device:TPU:0"):
      where_output = array_ops.where([True, False, True])
    self.assertAllEqual(where_output, [[0], [2]])

    with ops.device("/device:TPU:0"):
      repeat_output = array_ops.repeat(math_ops.range(2), [1, 4])
    self.assertAllEqual(repeat_output, [0, 1, 1, 1, 1])
@parameterized.named_parameters([("PackedVar", True), ("", False)])
class TPUStrategyTest(test.TestCase, parameterized.TestCase):
  """Tests of TPUStrategy.run() and related input/output handling.

  Every test runs twice via the class-level parameterization: once with
  packed variables enabled in eager mode and once without.
  """

  def test_handle_in_cross_replica_context(self, enable_packed_var):
    """Reading a distributed variable's handle outside run() resolves to
    the primary (TPU:0) component."""
    strategy = get_tpu_strategy(enable_packed_var)
    with strategy.scope():
      v = variables.Variable(1.0)

    @def_function.function
    def func():
      self.assertEndsWith(v.handle.device, "device:TPU:0")
      return v + 1.0

    ret = func()
    self.assertAllEqual(ret, 2.0)

  def test_function_compile_with_xla(self, enable_packed_var):
    """A tf.function reading a strategy variable compiles on TPU:0."""
    strategy = get_tpu_strategy(enable_packed_var)
    with strategy.scope():
      v = variables.Variable(1.0)

    @def_function.function
    def func():
      return v.read_value() + 1.0

    with ops.device("/device:TPU:0"):
      self.assertAllEqual(func(), 2.0)

  def test_sequential_runs(self, enable_packed_var):
    """Two strategies (all cores, then one core) can run back to back
    inside one tf.function, feeding one's output into the other."""
    resolver = get_tpu_cluster_resolver()
    remote.connect_to_cluster(resolver)
    topology = tpu_strategy_util.initialize_tpu_system(resolver)
    # Computation replicated to all cores.
    device_assignment = device_assignment_lib.DeviceAssignment.build(
        topology, num_replicas=2)
    strategy = tpu_lib.TPUStrategyV2(
        resolver, experimental_device_assignment=device_assignment)
    strategy._enable_packed_variable_in_eager_mode = enable_packed_var

    # Computation on the 1st core.
    device_assignment2 = device_assignment_lib.DeviceAssignment.build(
        topology, num_replicas=1)
    strategy2 = tpu_lib.TPUStrategyV2(
        resolver, experimental_device_assignment=device_assignment2)

    def computation(x):
      return math_ops.square(x)

    @def_function.function
    def train_step():
      outputs = strategy.experimental_local_results(
          strategy.run(computation, args=([2., 2.],)))
      outputs2 = strategy2.run(
          computation, args=([outputs[0]],))
      return outputs2

    # square(square([2., 2.])) == [16., 16.]
    self.assertAllEqual([[16., 16.]], train_step())

  def test_device_switch_case(self, enable_packed_var):
    """switch_case alternates inference between TPU:0 and TPU:1 based on
    an iteration counter."""
    strategy = get_tpu_strategy(enable_packed_var)
    with strategy.scope():
      a = variables.Variable(1)

    inference_iteration = variables.Variable(-1)

    def inference_fn(x, i):
      return a + x + i

    @def_function.function
    def run_inference(x):

      def do_inference(device, inference_fn, i):
        with ops.device(device):
          return inference_fn(x, i)

      branch_fns = {
          0: (lambda: do_inference("/device:TPU:0", inference_fn, 0)),
          1: (lambda: do_inference("/device:TPU:1", inference_fn, 1)),
      }
      # Round-robin over the two branches.
      branch_index = inference_iteration.assign_add(1, use_locking=True) % 2
      return control_flow_ops.switch_case(branch_index, branch_fns)

    self.assertAllEqual(2., run_inference(1))  # Use TPU core 0.
    self.assertAllEqual(3., run_inference(1))  # Use TPU core 1.

  def test_recover_from_compilation_failures(self, enable_packed_var):
    """After a compilation failure (random_gamma is uncompilable), a
    subsequent valid run should succeed."""
    # TODO(b/148150981): Stop skipping this test once recovery works
    # for non-local TPU.
    if FLAGS.tpu:
      self.skipTest("Recovery fails for non-local TPU, see b/148150981")

    # Disable automatic outside compilation.
    config.set_soft_device_placement(False)
    strategy = get_tpu_strategy(enable_packed_var)

    @def_function.function
    def compilation_failure_run():

      def computation():
        return random_ops.random_gamma([10], [0.5, 1.5])

      return strategy.run(computation)

    with self.assertRaises(errors.OpError):
      compilation_failure_run()

    @def_function.function
    def good_run():

      def computation():
        return random_ops.random_normal([10])

      return strategy.run(computation)

    good_run()

  def test_dynamic_shape_with_outside_compilation_failure(
      self, enable_packed_var):
    """A string input with outside compilation enabled currently fails
    with InternalError."""
    # Enable automatic outside compilation.
    config.set_soft_device_placement(True)
    strategy = get_tpu_strategy(enable_packed_var)
    dataset = dataset_ops.Dataset.from_tensors(("string", 1.0)).repeat().batch(
        2, drop_remainder=False)
    dataset = strategy.experimental_distribute_dataset(dataset)
    iterator = iter(dataset)

    @def_function.function
    def train_fn(iterator):

      def step_fn(inputs):
        _, inputs = inputs
        return math_ops.reduce_sum(inputs)

      return strategy.experimental_local_results(
          strategy.run(step_fn, args=(next(iterator),)))

    with self.assertRaises(errors.InternalError):
      logging.info(train_fn(iterator))

  def test_computation_on_subset_cores(self, enable_packed_var):
    """A variable created under the all-core strategy is readable from
    single-core strategies pinned to each core."""
    resolver = get_tpu_cluster_resolver()
    remote.connect_to_cluster(resolver)
    topology = tpu_strategy_util.initialize_tpu_system(resolver)
    all_core_strategy = tpu_lib.TPUStrategyV2(resolver)
    all_core_strategy._enable_packed_variable_in_eager_mode = enable_packed_var

    with all_core_strategy.scope():
      v = variables.Variable(0.0,
                             aggregation=variables.VariableAggregation.MEAN)

    # Computation on the 1st core.
    device_assignment = device_assignment_lib.DeviceAssignment.build(
        topology, num_replicas=1)
    first_core_strategy = tpu_lib.TPUStrategyV2(
        resolver, experimental_device_assignment=device_assignment)
    first_core_strategy._enable_packed_variable_in_eager_mode = (
        enable_packed_var)

    # Computation on the 2nd core.
    device_assignment2 = device_assignment_lib.DeviceAssignment(
        topology, [[[0, 0, 0, 1]]])
    second_core_strategy = tpu_lib.TPUStrategyV2(
        resolver, experimental_device_assignment=device_assignment2)
    second_core_strategy._enable_packed_variable_in_eager_mode = (
        enable_packed_var)

    @def_function.function
    def train_step():

      def step_fn():
        return v + 1.0

      all_core_strategy.run(step_fn)
      r1 = first_core_strategy.run(step_fn)
      r2 = second_core_strategy.run(step_fn)
      return r1 + r2

    train_step()
    self.assertAllEqual(2., train_step())

  def test_worker_devices_on_subset_cores(self, enable_packed_var):
    """Single-core strategies report exactly their assigned core as the
    sole worker device."""
    resolver = get_tpu_cluster_resolver()
    remote.connect_to_cluster(resolver)
    topology = tpu_strategy_util.initialize_tpu_system(resolver)

    # Strategy for the 1st core.
    device_assignment = device_assignment_lib.DeviceAssignment.build(
        topology, num_replicas=1)
    first_core_strategy = tpu_lib.TPUStrategyV2(
        resolver, experimental_device_assignment=device_assignment)
    first_core_strategy._enable_packed_variable_in_eager_mode = (
        enable_packed_var)

    # Strategy for the 2nd core.
    device_assignment2 = device_assignment_lib.DeviceAssignment(
        topology, [[[0, 0, 0, 1]]])
    second_core_strategy = tpu_lib.TPUStrategyV2(
        resolver, experimental_device_assignment=device_assignment2)
    second_core_strategy._enable_packed_variable_in_eager_mode = (
        enable_packed_var)

    self.assertLen(first_core_strategy.extended.worker_devices, 1)
    self.assertEndsWith(first_core_strategy.extended.worker_devices[0],
                        "device:TPU:0")

    self.assertLen(second_core_strategy.extended.worker_devices, 1)
    self.assertEndsWith(second_core_strategy.extended.worker_devices[0],
                        "device:TPU:1")

  def test_control_output_in_while_body_fn(self, enable_packed_var):
    """A side-effect-only step_fn (assign_add) inside a while loop still
    executes on every iteration."""
    strategy = get_tpu_strategy(enable_packed_var)
    with strategy.scope():
      v = variables.Variable(
          0.0, aggregation=variables.VariableAggregation.MEAN)

    @def_function.function
    def train_step():

      def step_fn():
        v.assign_add(1)

      for _ in math_ops.range(2):
        strategy.run(step_fn)

    train_step()
    self.assertEqual(2.0, v.numpy())

  def test_cluster_in_graph_and_while_body_fn(self, enable_packed_var):
    """strategy.run works both at graph level and inside a while body,
    threading the per-replica value through loop iterations."""
    strategy = get_tpu_strategy(enable_packed_var)

    @def_function.function
    def train_step():

      def step_fn(prev):
        s = prev + 1
        return s

      def init_fn():
        return array_ops.zeros(shape=())

      prev = strategy.run(init_fn)
      for _ in math_ops.range(10):
        prev = strategy.run(step_fn, args=(prev,))
      return strategy.reduce(reduce_util.ReduceOp.SUM, prev, axis=None)

    sum_val = train_step().numpy().astype(float)
    self.assertEqual(sum_val, strategy.num_replicas_in_sync * 10)

  def test_two_clusters_with_same_fn(self, enable_packed_var):
    """Calling the same replicated function twice inside one tf.function
    (two TPU clusters) does not fail."""
    strategy = get_tpu_strategy(enable_packed_var)

    @def_function.function
    def foo(x):
      return strategy.run(lambda x: x + 1, (x,))

    @def_function.function
    def bar(x):
      foo(x)
      return foo(x)

    bar(1)

  def test_tpu_variable_run_argument(self, enable_packed_var):
    # TPUStrategy.run() casts inputs to Tensor, but has logic to preserve
    # variables to avoid unintuitive errors.
    # Here we test that a TPUDistributedVariable passed to TPUStrategy.run()
    # remains a variable.

    strategy = get_tpu_strategy(enable_packed_var)

    with strategy.scope():
      tpu_variable = variables.Variable(1)

    def replica_step(first_arg, variable):
      del first_arg  # Just here to make sure we're not relying on arg position.

      if variable is not None:
        self.assertIsInstance(variable, tpu_values.TPUDistributedVariable)

    @def_function.function
    def step():
      strategy.run(
          replica_step, args=(
              2,
              tpu_variable,
          ))

    step()

  def test_tpu_run_arg_parsing(self, enable_packed_var):
    """Exercises run()'s argument-parsing rules for variables passed via
    positional args, *args, kwargs and keyword-only parameters."""
    strategy = get_tpu_strategy(enable_packed_var)

    with strategy.scope():
      tpu_vars = [variables.Variable(1)]

    def only_star_args(*args):
      del args

    def pos_and_star_args(first_arg, *args):
      del first_arg
      del args

    def named_args(first_arg, second_arg):
      del first_arg
      del second_arg

    def star_args_and_kw_only(*args, kw):
      del args
      del kw

    # pylint:disable=function-redefined
    @def_function.function
    def step():
      strategy.run(only_star_args, args=(2,))

    step()

    @def_function.function
    def step():
      strategy.run(named_args, kwargs={"first_arg": 2, "second_arg": 3})

    step()

    with self.assertRaisesRegex(TypeError, r"got multiple values for argument"):

      @def_function.function
      def step():
        strategy.run(
            named_args, args=(1,), kwargs={
                "first_arg": 2,
                "second_arg": 3
            })

      step()

    with self.assertRaisesRegex(ValueError,
                                r"cannot handle Variables passed to \*args"):

      @def_function.function
      def step():
        strategy.run(
            only_star_args, args=(
                2,
                tpu_vars,
            ))

      step()

    @def_function.function
    def step():
      strategy.run(pos_and_star_args, args=(2, 3, 4))

    step()

    @def_function.function
    def step():
      strategy.run(star_args_and_kw_only, args=(2, 3), kwargs={"kw": tpu_vars})

    step()

    with self.assertRaisesRegex(ValueError,
                                r"mix of positional args and \*args"):

      @def_function.function
      def step():
        strategy.run(pos_and_star_args, args=(tpu_vars, 3, 4))

      step()

    with self.assertRaisesRegex(ValueError, r"Too many positional arguments"):

      @def_function.function
      def step():
        strategy.run(named_args, args=(2, 3, 4))

      step()
    # pylint:enable=function-redefined

  def test_using_external_variable_inside_tf_function(self, enable_packed_var):
    """A variable created outside the strategy scope can be captured and
    read inside a replicated function."""
    strategy = get_tpu_strategy(enable_packed_var)
    dataset = dataset_ops.Dataset.range(
        strategy.num_replicas_in_sync * 2,
        output_type=dtypes.float32).batch(strategy.num_replicas_in_sync)
    input_iterator = iter(strategy.experimental_distribute_dataset(dataset))

    v = variables.Variable(2.0)

    @def_function.function
    def train_step(data):
      def computation(inputs):
        return inputs + v
      return strategy.run(computation, args=(data,))

    expected_result = [[x + 2.] for x in range(0, strategy.num_replicas_in_sync)
                      ]
    self.assertAllEqual(
        expected_result,
        strategy.experimental_local_results(train_step(next(input_iterator))))

  # TODO(b/145574622): Remove this test once it is re-enabled in values_test.py.
  def test_all_reduce_on_sync_on_read_variable(self, enable_packed_var):
    """all_reduce on an ON_READ / ONLY_FIRST_REPLICA variable sums across
    replicas; the variable itself reads back its initial value."""
    strategy = get_tpu_strategy(enable_packed_var)
    dataset = dataset_ops.Dataset.range(
        strategy.num_replicas_in_sync, output_type=dtypes.float32).batch(
            strategy.num_replicas_in_sync, drop_remainder=True)
    input_iterator = iter(strategy.experimental_distribute_dataset(dataset))

    with strategy.scope():
      w = variables.Variable(
          (0.,),
          shape=(1,),
          trainable=False,
          synchronization=variables.VariableSynchronization.ON_READ,
          aggregation=variables.VariableAggregation.ONLY_FIRST_REPLICA)

    @def_function.function
    def run(iterator):

      def computation(x):
        w.assign(x + w)
        return w

      def all_reduce(x):
        ctx = distribution_strategy_context.get_replica_context()
        return ctx.all_reduce("SUM", w) + x

      outputs = strategy.run(computation, args=(next(iterator),))
      outputs2 = strategy.experimental_local_results(
          strategy.run(all_reduce, args=(outputs,)))
      return outputs2

    data = range(0, strategy.num_replicas_in_sync)
    data_sum = sum(data)
    expected_result = [
        [x + data_sum] for x in range(0, strategy.num_replicas_in_sync)
    ]
    self.assertAllEqual(expected_result, run(input_iterator))
    self.assertAllEqual((0.,), w.read_value())

  def test_run_output_on_device(self, enable_packed_var):
    """Per-replica outputs of run() stay on their respective TPU cores."""
    strategy = get_tpu_strategy(enable_packed_var)

    def computation(x):
      return math_ops.square(x)

    @def_function.function
    def train_step():
      outputs = strategy.experimental_local_results(
          strategy.run(computation, args=(2,)))
      return outputs

    results = train_step()
    self.assertAllEqual([4., 4.], results)
    self.assertAllEqual("/job:localhost/replica:0/task:0/device:TPU:0",
                        results[0].backing_device)
    self.assertAllEqual("/job:localhost/replica:0/task:0/device:TPU:1",
                        results[1].backing_device)

  def test_composite_input_output(self, enable_packed_var):
    """SparseTensor inputs can be passed through run() and returned
    (flat tuple output)."""
    strategy = get_tpu_strategy(enable_packed_var)
    if strategy.num_replicas_in_sync != 2:
      self.skipTest("Test assumes two replicas.")

    with strategy.scope():
      table = variables.Variable(
          initial_value=[[0.0, 1.0], [3.0, 7.0]], dtype=dtypes.float32)

    @def_function.function
    def sparse_lookup(iterator):

      def tpu_function(sparse):
        # Assumes dense_shape is (2, *)
        looked_up = array_ops.gather(table, sparse.values)
        segment_sum = math_ops.unsorted_segment_sum(
            looked_up, sparse.indices[:, 0], 2)
        return sparse, segment_sum

      return nest.map_structure(
          strategy.experimental_local_results,
          strategy.run(tpu_function, args=(next(iterator),)))

    def dataset_fn(_):
      dataset = dataset_ops.Dataset.range(2)

      def make_sparse(_):
        return sparse_tensor.SparseTensor(
            indices=array_ops.constant([[0, 0], [1, 0], [1, 1]],
                                       dtype=dtypes.int64),
            values=array_ops.constant([0, 0, 1], dtype=dtypes.int32),
            dense_shape=array_ops.constant([2, 2], dtype=dtypes.int64))

      return dataset.map(make_sparse)

    dataset = iter(
        strategy.distribute_datasets_from_function(
            dataset_fn,
            distribute_lib.InputOptions(experimental_prefetch_to_device=False)))

    sparse, result = sparse_lookup(dataset)

    # All replicas return identical results.
    for replica in range(strategy.num_replicas_in_sync):
      self.assertIsInstance(sparse[replica], sparse_tensor.SparseTensor)
      self.assertAllEqual(sparse[replica].indices, [[0, 0], [1, 0], [1, 1]])
      self.assertAllEqual(sparse[replica].values, [0, 0, 1])
      self.assertAllEqual(sparse[replica].dense_shape, [2, 2])
      self.assertAllEqual(result[replica], [[0.0, 1.0], [3.0, 8.0]])

  def test_composite_input_non_flat_output(self, enable_packed_var):
    """Same as test_composite_input_output but with a dict (non-flat)
    return structure from the replicated function."""
    strategy = get_tpu_strategy(enable_packed_var)
    if strategy.num_replicas_in_sync != 2:
      self.skipTest("Test assumes two replicas.")

    with strategy.scope():
      table = variables.Variable(
          initial_value=[[0.0, 1.0], [3.0, 7.0]], dtype=dtypes.float32)

    @def_function.function
    def sparse_lookup(iterator):

      def tpu_function(sparse):
        # Assumes dense_shape is (2, *)
        looked_up = array_ops.gather(table, sparse.values)
        segment_sum = math_ops.unsorted_segment_sum(
            looked_up, sparse.indices[:, 0], 2)
        return {"sparse": sparse, "segment_sum": segment_sum}

      return nest.map_structure(
          strategy.experimental_local_results,
          strategy.run(tpu_function, args=(next(iterator),)))

    def dataset_fn(_):
      dataset = dataset_ops.Dataset.range(2)

      def make_sparse(_):
        return sparse_tensor.SparseTensor(
            indices=array_ops.constant([[0, 0], [1, 0], [1, 1]],
                                       dtype=dtypes.int64),
            values=array_ops.constant([0, 0, 1], dtype=dtypes.int32),
            dense_shape=array_ops.constant([2, 2], dtype=dtypes.int64))

      return dataset.map(make_sparse)

    dataset = iter(
        strategy.distribute_datasets_from_function(
            dataset_fn,
            distribute_lib.InputOptions(experimental_prefetch_to_device=False)))

    output = sparse_lookup(dataset)

    # All replicas return identical results.
    for replica in range(strategy.num_replicas_in_sync):
      self.assertIsInstance(output["sparse"][replica],
                            sparse_tensor.SparseTensor)
      self.assertAllEqual(output["sparse"][replica].indices,
                          [[0, 0], [1, 0], [1, 1]])
      self.assertAllEqual(output["sparse"][replica].values, [0, 0, 1])
      self.assertAllEqual(output["sparse"][replica].dense_shape, [2, 2])
      self.assertAllEqual(output["segment_sum"][replica],
                          [[0.0, 1.0], [3.0, 8.0]])

  def test_composite_input_dynamic_shapes_outside_compilation(
      self, enable_packed_var):
    """Per-replica SparseTensors with differing dense shapes are handled
    by running the lookup via outside compilation."""
    strategy = get_tpu_strategy(enable_packed_var)
    if strategy.num_replicas_in_sync != 2:
      self.skipTest("Test assumes two replicas.")

    table = variables.Variable(
        initial_value=[[0.0, 1.0], [3.0, 7.0]], dtype=dtypes.float32)

    @def_function.function
    def sparse_lookup(iterator):

      def tpu_function(sparse):
        lookup = tpu.outside_compilation(
            embedding_ops.safe_embedding_lookup_sparse, table, sparse)
        return math_ops.reduce_sum(lookup, axis=0)

      return strategy.experimental_local_results(
          strategy.run(tpu_function, args=(next(iterator),)))

    def dataset_fn(_):
      dataset = dataset_ops.Dataset.range(2)

      def make_sparse(i):
        # Each element has a different number of values / dense shape.
        indices = array_ops.constant([[0, 0], [1, 0], [1, 1]],
                                     dtype=dtypes.int64)[0:2 + i]
        values = array_ops.constant([0, 0, 1], dtype=dtypes.int32)[0:2 + i]
        shape = [
            array_ops.constant([2], dtype=dtypes.int64),
            array_ops.expand_dims(1 + i, axis=0)
        ]
        dense_shape = array_ops.concat(shape, axis=0)
        return sparse_tensor.SparseTensor(
            indices=indices, values=values, dense_shape=dense_shape)

      return dataset.map(make_sparse)

    dataset = iter(
        strategy.distribute_datasets_from_function(
            dataset_fn,
            options=distribute_lib.InputOptions(
                experimental_prefetch_to_device=False)))

    result = sparse_lookup(dataset)
    self.assertAllEqual(result, [[0.0, 2.0], [1.5, 5.0]])

  def test_per_device_tracing_of_mirrored_variables(self, enable_packed_var):
    """A function reading a mirrored variable is retraced once per
    distinct worker device."""
    # Define trace_count as a list to avoid python scoping error
    trace_count = [0]

    strategy = get_tpu_strategy(enable_packed_var)
    with strategy.scope():
      variable = variables.Variable(0.0)

    @def_function.function
    def add_one():
      trace_count[0] = trace_count[0] + 1
      return math_ops.add(variable, constant_op.constant(1.0))

    @def_function.function
    def update_variable():
      for device in set(strategy.extended.worker_devices):
        with ops.device(device):
          add_one()

    with strategy.scope():
      update_variable.get_concrete_function()
      self.assertLen(strategy.extended.worker_devices, trace_count[0])
class TPUStrategyDataPrefetchTest(test.TestCase):
  """Tests for where distributed dataset elements land (TPU vs. host CPU),
  and for rejection of composite-tensor datasets when prefetching to TPU."""

  def test_prefetch_to_device_default(self):
    """By default, distributed dataset elements are prefetched onto TPU."""
    strategy = get_tpu_strategy()
    dataset = dataset_ops.Dataset.range(
        strategy.num_replicas_in_sync * 2,
        output_type=dtypes.float32).batch(strategy.num_replicas_in_sync)

    # Check default, should prefetch to TPU.
    dataset_item = next(iter(strategy.experimental_distribute_dataset(dataset)))
    dataset_location = tf_device.DeviceSpec.from_string(
        dataset_item.values[0].device)
    self.assertEqual(dataset_location.device_type, "TPU")

  def test_prefetch_to_device_tpu(self):
    """Explicit experimental_prefetch_to_device=True lands data on TPU."""
    strategy = get_tpu_strategy()
    dataset = dataset_ops.Dataset.range(
        strategy.num_replicas_in_sync * 2,
        output_type=dtypes.float32).batch(strategy.num_replicas_in_sync)

    input_options = distribute_lib.InputOptions(
        experimental_prefetch_to_device=True)
    dataset_item = next(iter(strategy.experimental_distribute_dataset(
        dataset, options=input_options)))
    dataset_location = tf_device.DeviceSpec.from_string(
        dataset_item.values[0].device)
    self.assertEqual(dataset_location.device_type, "TPU")

  def test_prefetch_to_device_cpu(self):
    """experimental_prefetch_to_device=False keeps data on the host CPU."""
    strategy = get_tpu_strategy()
    dataset = dataset_ops.Dataset.range(
        strategy.num_replicas_in_sync * 2,
        output_type=dtypes.float32).batch(strategy.num_replicas_in_sync)

    # Should be CPU when prefetch_to_device is False.
    input_options = distribute_lib.InputOptions(
        experimental_prefetch_to_device=False)
    dataset_item = next(iter(strategy.experimental_distribute_dataset(
        dataset, options=input_options)))
    dataset_location = tf_device.DeviceSpec.from_string(
        dataset_item.values[0].device)
    self.assertEqual(dataset_location.device_type, "CPU")

  def test_prefetch_to_device_sparse_dataset(self):
    """Distributing a sparse dataset with default (TPU) prefetch raises."""
    strategy = get_tpu_strategy()
    # Values here aren't important.
    dataset = dataset_ops.Dataset.from_tensors(
        sparse_tensor.SparseTensor(indices=[[0, 0], [0, 1], [1, 0]],
                                   values=[1, 2, 3],
                                   dense_shape=[2, 2]))
    dataset = dataset.repeat()
    dataset = dataset.batch(strategy.num_replicas_in_sync)

    with self.assertRaisesRegex(ValueError, "TPUStrategy does not support"):
      iter(strategy.experimental_distribute_dataset(dataset))

  def test_prefetch_to_device_ragged_dataset(self):
    """Distributing a ragged dataset with default (TPU) prefetch raises."""
    strategy = get_tpu_strategy()
    # Values here aren't important.
    dataset = dataset_ops.Dataset.from_tensors(
        ragged_tensor.RaggedTensor.from_row_splits(
            values=[1, 2, 3],
            row_splits=[0, 2, 3]))
    dataset = dataset.repeat()
    dataset = dataset.batch(strategy.num_replicas_in_sync)

    with self.assertRaisesRegex(ValueError, "TPUStrategy does not support"):
      iter(strategy.experimental_distribute_dataset(dataset))

  def test_prefetch_to_device_sparse_dataset_fn(self):
    """Same sparse-dataset rejection via distribute_datasets_from_function."""
    strategy = get_tpu_strategy()
    def dataset_fn(ctx):
      del ctx
      # Values here aren't important.
      dataset = dataset_ops.Dataset.from_tensors(
          sparse_tensor.SparseTensor(indices=[[0, 0], [0, 1], [1, 0]],
                                     values=[1, 2, 3],
                                     dense_shape=[2, 2]))
      dataset = dataset.repeat()
      return dataset.batch(strategy.num_replicas_in_sync)

    with self.assertRaisesRegex(ValueError, "TPUStrategy does not support"):
      iter(strategy.distribute_datasets_from_function(dataset_fn))

  def test_prefetch_to_device_ragged_dataset_fn(self):
    """Same ragged-dataset rejection via distribute_datasets_from_function."""
    strategy = get_tpu_strategy()
    def dataset_fn(ctx):
      del ctx
      # Values here aren't important.
      dataset = dataset_ops.Dataset.from_tensors(
          ragged_tensor.RaggedTensor.from_row_splits(
              values=[1, 2, 3],
              row_splits=[0, 2, 3]))
      dataset = dataset.repeat()
      return dataset.batch(strategy.num_replicas_in_sync)

    with self.assertRaisesRegex(ValueError, "TPUStrategy does not support"):
      iter(strategy.distribute_datasets_from_function(dataset_fn))
class TPUStrategyDistributionTest(
strategy_test_lib.DistributionTestBase,
strategy_test_lib.TwoDeviceDistributionTestBase):
  def test_update_config_proto(self):
    """update_config_proto injects the resolver's cluster def and sets
    isolate_session_state."""
    resolver = get_tpu_cluster_resolver()
    remote.connect_to_cluster(resolver)
    tpu_strategy_util.initialize_tpu_system(resolver)
    strategy = tpu_lib.TPUStrategyV2(resolver)

    config_proto = config_pb2.ConfigProto()
    cluster_spec = server_lib.ClusterSpec({"worker": ["fake1", "fake2"]})
    with test.mock.patch.object(
        resolver, "cluster_spec", return_value=cluster_spec):
      new_config = strategy.update_config_proto(config_proto)

    # Verify cluster_def.
    self.assertProtoEquals(cluster_spec.as_cluster_def(),
                           new_config.cluster_def)

    # Verify isolate_session_state
    self.assertTrue(new_config.isolate_session_state)
  def test_make_input_fn_iterable(self):
    """Delegates to the base-class iterable-input check with two replicas
    sharing one input pipeline."""
    dataset_fn = lambda: dataset_ops.Dataset.range(10)
    expected_values = [[i, i+1] for i in range(0, 10, 2)]
    distribution = get_tpu_strategy()
    input_fn = self._input_fn_to_test_input_context(
        dataset_fn,
        expected_num_replicas_in_sync=2,
        expected_num_input_pipelines=1,
        expected_input_pipeline_id=0)
    self._test_input_fn_iterable(distribution, input_fn, expected_values)
  def test_make_input_fn_iterator(self):
    """Delegates to the base-class iterator-input check with two replicas
    sharing one input pipeline."""
    dataset_fn = lambda: dataset_ops.Dataset.range(10)
    expected_values = [[i, i+1] for i in range(0, 10, 2)]
    distribution = get_tpu_strategy()
    input_fn = self._input_fn_to_test_input_context(
        dataset_fn,
        expected_num_replicas_in_sync=2,
        expected_num_input_pipelines=1,
        expected_input_pipeline_id=0)
    iterator = distribution.make_input_fn_iterator(input_fn)
    self._test_input_fn_iterator(
        iterator,
        distribution.extended.worker_devices,
        expected_values)
  def test_num_replicas_in_sync(self):
    """The default strategy reports two replicas in sync."""
    strategy = get_tpu_strategy()
    self.assertEqual(2, strategy.num_replicas_in_sync)
  def test_call_and_merge_exceptions(self):
    """Delegates to the strategy_test_lib base-class check."""
    strategy = get_tpu_strategy()
    self._test_call_and_merge_exceptions(strategy)
  def test_numpy_dataset(self):
    """Delegates to the strategy_test_lib base-class check (in-function)."""
    strategy = get_tpu_strategy()
    self._test_numpy_dataset(strategy, run_in_function=True)
  def test_global_step_update(self):
    """Delegates to the strategy_test_lib base-class check."""
    strategy = get_tpu_strategy()
    self._test_global_step_update(strategy)
  def test_run(self):
    """Delegates to the strategy_test_lib base-class check (in-function)."""
    strategy = get_tpu_strategy()
    self._test_run(strategy, run_in_function=True)
  def test_summary_for_replica_zero_only(self):
    """Delegates to the strategy_test_lib base-class check."""
    strategy = get_tpu_strategy()
    self._test_summary_for_replica_zero_only(strategy)
  def test_all_reduce_sum(self):
    """Delegates to the strategy_test_lib base-class check (in-function)."""
    strategy = get_tpu_strategy()
    self._test_all_reduce_sum(strategy, run_in_function=True)
  def test_all_reduce_sum_gradients(self):
    """Delegates to the strategy_test_lib base-class check (in-function)."""
    strategy = get_tpu_strategy()
    self._test_all_reduce_sum_gradients(strategy, run_in_function=True)
  def test_all_reduce_sum_gradient_tape(self):
    """Delegates to the strategy_test_lib base-class check (in-function)."""
    strategy = get_tpu_strategy()
    self._test_all_reduce_sum_gradient_tape(strategy, run_in_function=True)
  def test_all_reduce_mean(self):
    """Delegates to the strategy_test_lib base-class check (in-function)."""
    strategy = get_tpu_strategy()
    self._test_all_reduce_mean(strategy, run_in_function=True)
  def test_all_reduce_mean_gradients(self):
    """Delegates to the strategy_test_lib base-class check (in-function)."""
    strategy = get_tpu_strategy()
    self._test_all_reduce_mean_gradients(strategy, run_in_function=True)
  def test_all_reduce_mean_gradient_tape(self):
    """Delegates to the strategy_test_lib base-class check (in-function)."""
    strategy = get_tpu_strategy()
    self._test_all_reduce_mean_gradient_tape(strategy, run_in_function=True)
  def test_reduce(self):
    """strategy.reduce(MEAN) over squared per-replica inputs [2., 3.]
    yields (4 + 9) / 2 == 6.5."""
    strategy = get_tpu_strategy()
    inputs = strategy.make_input_fn_iterator(
        lambda _: dataset_ops.Dataset.from_tensor_slices([2., 3.]))
    self.evaluate(inputs.initialize())
    per_replica_outputs = strategy.run(
        def_function.function(math_ops.square), args=(next(inputs),))

    with strategy.scope():
      mean = strategy.reduce(reduce_util.ReduceOp.MEAN, per_replica_outputs,
                             axis=None)
      self.assertEqual(6.5, self.evaluate(mean))
  def test_constraint(self):
    """A variable constraint (here clamping to 1.) is applied when
    explicitly invoked via variable.constraint."""
    strategy = get_tpu_strategy()

    with strategy.scope():
      variable = variables.Variable(initial_value=2.,
                                    constraint=lambda x: 0. * x + 1.)
      self.assertEqual(variable.value().numpy(), 2)

      @def_function.function
      def update_variable():
        variable.assign_add(1)
        variable.assign(variable.constraint(variable))

      update_variable()
      self.assertEqual(variable.value().numpy(), 1)
  def test_trainable_variables(self):
    """Delegates to the strategy_test_lib base-class check."""
    strategy = get_tpu_strategy()
    self._test_trainable_variable(strategy)
  def test_model_parallelism(self):
    """With a two-logical-device assignment, variables placed on logical
    devices 0/1 land on TPU:0/TPU:1 and compute on replicated cores."""
    resolver = get_tpu_cluster_resolver()
    remote.connect_to_cluster(resolver)
    topology = tpu_strategy_util.initialize_tpu_system(resolver)
    # One replica spanning two cores (logical devices 0 and 1).
    device_assignment = device_assignment_lib.DeviceAssignment(
        topology, core_assignment=[[[0, 0, 0, 0], [0, 0, 0, 1]]])
    strategy = tpu_lib.TPUStrategyV2(
        resolver,
        experimental_device_assignment=device_assignment)

    with strategy.scope():
      v = variables.Variable(2.)
      with strategy.extended.experimental_logical_device(1):
        w = variables.Variable(3.)

    self.assertLen(strategy.experimental_local_results(v), 1)
    self.assertLen(strategy.experimental_local_results(w), 1)
    self.assertEqual("/job:localhost/replica:0/task:0/device:TPU:0",
                     strategy.experimental_local_results(v)[0].device)
    self.assertEqual("/job:localhost/replica:0/task:0/device:TPU:1",
                     strategy.experimental_local_results(w)[0].device)

    logical_devices = []

    @def_function.function
    def f(x):
      replica_ctx = distribution_strategy_context.get_replica_context()
      with replica_ctx.experimental_logical_device(0):
        y = v * x
      with replica_ctx.experimental_logical_device(1):
        z = w * y
      logical_devices.append((y.device, z.device))
      return z

    result = strategy.run(f, args=(5.,))

    self.assertEqual(
        [("/device:TPU_REPLICATED_CORE:0", "/device:TPU_REPLICATED_CORE:1")],
        logical_devices)

    with self.cached_session():
      self.evaluate(variables.global_variables_initializer())
      # z = w * (v * x) = 3 * 2 * 5 = 30.
      self.assertEqual(30., self.evaluate(result))
  def test_model_parallelism_checkpointing(self):
    """Checkpoints and restores variables placed on separate logical devices."""

    class PartitionedModel(module.Module):
      # Two-variable model whose weights live on different logical devices
      # of a single model-parallel replica.

      def __init__(self, v, w):
        super(PartitionedModel, self).__init__()
        assert distribution_strategy_context.has_strategy()
        strategy = distribution_strategy_context.get_strategy()

        # Pin each weight to its own logical device within the replica.
        with strategy.extended.experimental_logical_device(0):
          self.v = variables.Variable(v)
        with strategy.extended.experimental_logical_device(1):
          self.w = variables.Variable(w)

      def __call__(self, x):
        # Computes w * (v * x), with each multiply placed on the logical
        # device that owns the corresponding weight.
        replica_ctx = distribution_strategy_context.get_replica_context()
        with replica_ctx.experimental_logical_device(0):
          y = self.v * x
        with replica_ctx.experimental_logical_device(1):
          z = self.w * y
        return z

      def change_weights_op(self, v_new, w_new):
        # Groups both assignments into one op so a single evaluate() call
        # updates both weights.
        return control_flow_ops.group([self.v.assign(v_new),
                                       self.w.assign(w_new)])

    resolver = get_tpu_cluster_resolver()
    remote.connect_to_cluster(resolver)
    topology = tpu_strategy_util.initialize_tpu_system(resolver)
    # One replica spanning two cores, i.e. two logical devices.
    device_assignment = device_assignment_lib.DeviceAssignment(
        topology, core_assignment=[[[0, 0, 0, 0], [0, 0, 0, 1]]])
    strategy = tpu_lib.TPUStrategyV2(
        resolver,
        experimental_device_assignment=device_assignment)

    with strategy.scope():
      model = PartitionedModel(2., 3.)

    checkpoint_dir = self.get_temp_dir()
    checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
    checkpoint = util.Checkpoint(model=model)

    with self.cached_session() as sess:
      self.evaluate(variables.global_variables_initializer())
      # Save (v=2, w=3), then overwrite the weights with (v=1, w=4).
      checkpoint.save(file_prefix=checkpoint_prefix)
      self.evaluate(model.change_weights_op(1., 4.))
      # 1 * 5 * 4 == 20 with the overwritten weights.
      result = strategy.run(def_function.function(model), args=(5.0,))
      self.assertEqual(20., self.evaluate(result))

      status = checkpoint.restore(
          checkpoint_management.latest_checkpoint(checkpoint_dir))
      status.run_restore_ops(sess)  # must run restore op in non-eager mode.
      status.assert_consumed()
      status.assert_existing_objects_matched()
      # 2 * 5 * 3 == 30 with the restored weights.
      result = strategy.run(def_function.function(model), args=(5.0,))
      self.assertEqual(30., self.evaluate(result))
class DeviceAssignmentTest(test.TestCase):
  """Tests for explicit TPU DeviceAssignment handling in TPUStrategyV2."""

  def test_core_assignment(self):
    """A single-core assignment yields one one-core replica on task 0."""
    cluster_resolver = get_tpu_cluster_resolver()
    remote.connect_to_cluster(cluster_resolver)
    topology = tpu_strategy_util.initialize_tpu_system(cluster_resolver)
    assignment = device_assignment_lib.DeviceAssignment(
        topology, core_assignment=[[[0, 0, 0, 0]]])
    self.assertAllEqual([[[0, 0, 0, 0]]], assignment.core_assignment)
    self.assertEqual(1, assignment.num_cores_per_replica)
    self.assertEqual(1, assignment.num_replicas)
    self.assertEqual("/task:0/device:TPU:0", assignment.tpu_device())
    self.assertEqual("/task:0/device:CPU:0", assignment.host_device())

  def test_device_assignment_strategy_properties(self):
    """Strategy host/replica counts reflect a single-core assignment."""
    cluster_resolver = get_tpu_cluster_resolver()
    remote.connect_to_cluster(cluster_resolver)
    topology = tpu_strategy_util.initialize_tpu_system(cluster_resolver)
    assignment = device_assignment_lib.DeviceAssignment(
        topology, core_assignment=[[[0, 0, 0, 0]]])
    strategy = tpu_lib.TPUStrategyV2(
        cluster_resolver,
        experimental_device_assignment=assignment)
    self.assertEqual(strategy.extended.num_hosts, 1)
    self.assertEqual(strategy.num_replicas_in_sync, 1)
    self.assertEqual(strategy.extended.num_replicas_per_host, 1)  # pylint: disable=protected-access

  def test_device_assignment_constants(self):
    """SINGLE_CORE_ASSIGNMENT behaves like an explicit one-core assignment."""
    cluster_resolver = get_tpu_cluster_resolver()
    remote.connect_to_cluster(cluster_resolver)
    topology = tpu_strategy_util.initialize_tpu_system(cluster_resolver)
    assignment = device_assignment_lib.DeviceAssignment(
        topology,
        core_assignment=device_assignment_lib.SINGLE_CORE_ASSIGNMENT)
    self.assertAllEqual([[[0, 0, 0, 0]]], assignment.core_assignment)
    self.assertEqual(1, assignment.num_cores_per_replica)
    self.assertEqual(1, assignment.num_replicas)
    self.assertEqual("/task:0/device:TPU:0", assignment.tpu_device())
    self.assertEqual("/task:0/device:CPU:0", assignment.host_device())

  def test_variables_mismatched_device_assignment(self):
    """A variable created under one strategy reads correctly under another
    whose replica-to-device order is reversed."""
    cluster_resolver = get_tpu_cluster_resolver()
    remote.connect_to_cluster(cluster_resolver)
    topology = tpu_strategy_util.initialize_tpu_system(cluster_resolver)

    strategy0 = tpu_lib.TPUStrategyV2(cluster_resolver)
    self.assertEqual(
        ("/job:localhost/replica:0/task:0/device:TPU:0",
         "/job:localhost/replica:0/task:0/device:TPU:1"),
        strategy0.extended.worker_devices)

    with strategy0.scope():
      v = variables.Variable(1.)

    # Give the two per-replica copies distinct values so placement is visible.
    v1_assign_op = strategy0.experimental_local_results(v)[1].assign(42.)
    with self.cached_session():
      self.evaluate(variables.global_variables_initializer())
      self.evaluate(v1_assign_op)
      self.assertAllEqual([1., 42.],
                          self.evaluate(
                              strategy0.experimental_local_results(v)))

    # Second strategy has devices reversed relative to the first.
    assignment = device_assignment_lib.DeviceAssignment(
        topology, core_assignment=[[[0, 0, 0, 1]], [[0, 0, 0, 0]]])
    strategy1 = tpu_lib.TPUStrategyV2(
        cluster_resolver,
        experimental_device_assignment=assignment)
    self.assertEqual(
        ("/job:localhost/replica:0/task:0/device:TPU:1",
         "/job:localhost/replica:0/task:0/device:TPU:0"),
        strategy1.extended.worker_devices)

    # Reading under the reversed strategy swaps the per-replica values.
    v_read = strategy1.run(def_function.function(v.read_value))
    with self.cached_session():
      self.assertAllEqual([42., 1.],
                          self.evaluate(
                              strategy0.experimental_local_results(v_read)))
# Run all tests in this module when invoked as a script.
if __name__ == "__main__":
  test.main()