Update tpu_embedding_v2.py to use the new API for prefetching data to host memory.

PiperOrigin-RevId: 316742491
Change-Id: I6803c798256578a284d9ef190d79bf2e35f9ce6a
Bruce Fontaine 2020-06-16 13:07:46 -07:00 committed by TensorFlower Gardener
parent 93a441910f
commit 2a0ad47926
2 changed files with 152 additions and 28 deletions
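In summary, prefetching embedding inputs to host memory is now requested through the public `tf.distribute.InputOptions(experimental_prefetch_to_device=False)` dataset option instead of the private `strategy.extended._set_prefetch_on_host(True)` hook. A minimal sketch of the new usage (the `dataset_fn` below is a toy placeholder, and `strategy` is assumed to be an initialized TPUStrategy):

```python
import tensorflow as tf

def dataset_fn(input_context):
  # Toy per-worker pipeline; a real one would shard using `input_context`.
  del input_context
  return tf.data.Dataset.from_tensors(
      {"feature": tf.constant([[0], [1]])}).repeat()

# Keep embedding input tensors on the host CPU; TPUEmbedding.enqueue now
# rejects tensors that were prefetched onto the TPU.
distributed_dataset = strategy.experimental_distribute_datasets_from_function(
    dataset_fn=dataset_fn,
    options=tf.distribute.InputOptions(
        experimental_prefetch_to_device=False))
dataset_iterator = iter(distributed_dataset)
```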

tensorflow/python/tpu/tpu_embedding_v2.py

@@ -31,6 +31,7 @@ from tensorflow.python.distribute import sharded_variable
from tensorflow.python.distribute import tpu_strategy
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
@@ -139,6 +140,18 @@ class TPUEmbedding(tracking.AutoTrackable):
optimizer=tf.tpu.experimental.embedding.SGD(0.1))
```
When creating a distributed dataset that is to be passed to the enqueue
operation, a special input option must be specified:
```python
distributed_dataset = (
    strategy.experimental_distribute_datasets_from_function(
        dataset_fn=...,
        options=tf.distribute.InputOptions(
            experimental_prefetch_to_device=False)))
dataset_iterator = iter(distributed_dataset)
```
To use this API on TPU, you should use a custom training loop. Below is an
example of a training and evaluation step:
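A rough sketch of the shape of such a training step (assuming an initialized `embedding`, a `model` function, and batches that split into embedding and dense features; none of these names come from the snippet above):

```python
@tf.function
def training_step(dataset_iterator):

  def tpu_step(tpu_features):
    with tf.GradientTape() as tape:
      activations = embedding.dequeue()  # embedding lookups run on the TPU
      tape.watch(activations)
      loss = model(tpu_features, activations)  # `model` is assumed here
    gradients = tape.gradient(loss, activations)
    embedding.apply_gradients(gradients)  # send gradients back to the tables

  # The enqueue runs on the host CPU, which is why the distributed dataset
  # must not prefetch its tensors to the TPU.
  embedding_features, tpu_features = next(dataset_iterator)
  embedding.enqueue(embedding_features, training=True)
  strategy.run(tpu_step, args=(tpu_features,))
```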
@@ -309,10 +322,6 @@ class TPUEmbedding(tracking.AutoTrackable):
# We need the list of host devices for the load/retrieve operations.
self._hosts = get_list_of_hosts(self._strategy)
# TODO(bfontain) Remove this once we have an official way of splitting
# prefetch between host and device.
self._strategy.extended._set_prefetch_on_host(True) # pylint: disable=protected-access
# We generally use the per core batch size, but will have the user pass
# in a global batch size.
self._batch_size = batch_size // self._strategy.num_replicas_in_sync
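# For example, a global batch size of 128 with 8 replicas in sync yields a
# per-core batch size of 16.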
@@ -507,7 +516,11 @@ class TPUEmbedding(tracking.AutoTrackable):
with strategy.scope():
  embedding = tf.tpu.experimental.embedding.TPUEmbedding(...)

distributed_dataset = strategy.experimental_distribute_dataset(...)
distributed_dataset = (
    strategy.experimental_distribute_datasets_from_function(
        dataset_fn=...,
        options=tf.distribute.InputOptions(
            experimental_prefetch_to_device=False)))
dataset_iterator = iter(distributed_dataset)
@tf.function
@@ -594,7 +607,11 @@ class TPUEmbedding(tracking.AutoTrackable):
with strategy.scope():
  embedding = tf.tpu.experimental.embedding.TPUEmbedding(...)

distributed_dataset = strategy.experimental_distribute_dataset(...)
distributed_dataset = (
    strategy.experimental_distribute_datasets_from_function(
        dataset_fn=...,
        options=tf.distribute.InputOptions(
            experimental_prefetch_to_device=False)))
dataset_iterator = iter(distributed_dataset)
@tf.function
@@ -1004,6 +1021,24 @@ class TPUEmbedding(tracking.AutoTrackable):
input_tensor.op.name,
input_tensor.op.type))
def _raise_error_for_inputs_not_on_cpu(self, features):
  """Checks all tensors in `features` to see if they are placed on the CPU."""
  # expand_composites here is important: we need to check the device of each
  # underlying tensor.
  for path, input_tensor in nest.flatten_with_joined_string_paths(
      features, expand_composites=True):
    spec = tf_device.DeviceSpec.from_string(input_tensor.device)
    if spec.device_type == "TPU":
      raise ValueError(
          "Received input tensor {} which is on a TPU input device {}. Input "
          "tensors for TPU embeddings must be placed on the CPU. Please "
          "ensure that your dataset is prefetching tensors to the host by "
          "setting the 'experimental_prefetch_to_device' option of the "
          "dataset distribution function. See the documentation of the "
          "enqueue method for an example.".format(
              path, input_tensor.device))
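# As a hypothetical illustration of the check above: DeviceSpec parses the
# fully qualified device string carried by each tensor, e.g.
#
#   tf_device.DeviceSpec.from_string(
#       "/job:worker/replica:0/task:0/device:TPU:0").device_type == "TPU"
#   tf_device.DeviceSpec.from_string(
#       "/job:worker/replica:0/task:0/device:CPU:0").device_type == "CPU"
#
# so only tensors whose device resolves to a TPU trigger the ValueError.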
def enqueue(self, features, weights=None, training=True, name=None):
"""Enqueues id tensors for embedding lookup.
@@ -1021,7 +1056,11 @@ class TPUEmbedding(tracking.AutoTrackable):
with strategy.scope():
  embedding = tf.tpu.experimental.embedding.TPUEmbedding(...)

distributed_dataset = strategy.experimental_distribute_dataset(...)
distributed_dataset = (
    strategy.experimental_distribute_datasets_from_function(
        dataset_fn=...,
        options=tf.distribute.InputOptions(
            experimental_prefetch_to_device=False)))
dataset_iterator = iter(distributed_dataset)
@tf.function
@@ -1091,6 +1130,7 @@ class TPUEmbedding(tracking.AutoTrackable):
flat_weights = nest.flatten(weights)
flat_features = nest.flatten_with_joined_string_paths(self._feature_config)
self._raise_error_for_inputs_not_on_cpu(features)
in_tpu_context = self._raise_error_for_incorrect_control_flow_context()
# If we are in a tpu_context, automatically apply outside compilation.
if in_tpu_context:
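For context, the outside-compilation branch above hoists the host-side enqueue out of the TPU program when `enqueue` is called inside `strategy.run`. A rough sketch of that pattern (the `_enqueue_on_host` helper is hypothetical; `tpu.outside_compilation` is the real primitive):

```python
from tensorflow.python.tpu import tpu

def _enqueue_on_host(features):
  # Placeholder for building and running the host-side enqueue op.
  ...

def _maybe_enqueue(features, in_tpu_context):
  if in_tpu_context:
    # Inside a TPU-compiled function: wrap the enqueue so it executes on the
    # host CPU instead of being compiled into the TPU program.
    tpu.outside_compilation(_enqueue_on_host, features)
  else:
    _enqueue_on_host(features)
```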

tensorflow/python/tpu/tpu_embedding_v2_test.py

@@ -28,6 +28,7 @@ import numpy as np
from tensorflow.python.compat import v2_compat
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.distribute import tpu_strategy
from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver
@@ -443,7 +444,10 @@ class TPUEmbeddingTest(parameterized.TestCase, test.TestCase):
def test_pass_none_to_apply_gradients(self):
strategy, mid_level_api, _ = self._create_strategy_and_mid_level('sgd')
dataset = self._create_sparse_dataset(strategy)
data = next(iter(strategy.experimental_distribute_dataset(dataset)))
data = next(iter(strategy.experimental_distribute_dataset(
    dataset,
    options=distribute_lib.InputOptions(
        experimental_prefetch_to_device=False))))
@def_function.function
def embedding_and_set_gradients(data):
@@ -527,7 +531,10 @@ class TPUEmbeddingTest(parameterized.TestCase, test.TestCase):
strategy, mid_level_api, _ = self._create_strategy_and_mid_level('sgd')
input_fn = self._create_dense_input_fn(strategy, include_weights=True)
dist = strategy.experimental_distribute_datasets_from_function(input_fn)
dist = strategy.experimental_distribute_datasets_from_function(
    input_fn,
    options=distribute_lib.InputOptions(
        experimental_prefetch_to_device=False))
dist_iter = iter(dist)
@def_function.function
@@ -547,8 +554,14 @@ class TPUEmbeddingTest(parameterized.TestCase, test.TestCase):
sparse = self._create_sparse_dataset(strategy)
ragged = self._create_ragged_dataset(strategy, include_weights=True)
sparse_iter = iter(strategy.experimental_distribute_dataset(sparse))
ragged_iter = iter(strategy.experimental_distribute_dataset(ragged))
sparse_iter = iter(strategy.experimental_distribute_dataset(
    sparse,
    options=distribute_lib.InputOptions(
        experimental_prefetch_to_device=False)))
ragged_iter = iter(strategy.experimental_distribute_dataset(
    ragged,
    options=distribute_lib.InputOptions(
        experimental_prefetch_to_device=False)))
@def_function.function
def test_fn():
@@ -569,8 +582,14 @@ class TPUEmbeddingTest(parameterized.TestCase, test.TestCase):
sparse = self._create_sparse_dataset(strategy, include_weights=True)
ragged = self._create_ragged_dataset(strategy)
sparse_iter = iter(strategy.experimental_distribute_dataset(sparse))
ragged_iter = iter(strategy.experimental_distribute_dataset(ragged))
sparse_iter = iter(strategy.experimental_distribute_dataset(
    sparse,
    options=distribute_lib.InputOptions(
        experimental_prefetch_to_device=False)))
ragged_iter = iter(strategy.experimental_distribute_dataset(
    ragged,
    options=distribute_lib.InputOptions(
        experimental_prefetch_to_device=False)))
@def_function.function
def test_fn():
@@ -591,8 +610,14 @@ class TPUEmbeddingTest(parameterized.TestCase, test.TestCase):
sparse = self._create_sparse_dataset(strategy)
ragged = self._create_ragged_dataset(strategy)
sparse_iter = iter(strategy.experimental_distribute_dataset(sparse))
ragged_iter = iter(strategy.experimental_distribute_dataset(ragged))
sparse_iter = iter(strategy.experimental_distribute_dataset(
    sparse,
    options=distribute_lib.InputOptions(
        experimental_prefetch_to_device=False)))
ragged_iter = iter(strategy.experimental_distribute_dataset(
    ragged,
    options=distribute_lib.InputOptions(
        experimental_prefetch_to_device=False)))
@def_function.function
def test_fn():
@@ -613,7 +638,10 @@ class TPUEmbeddingTest(parameterized.TestCase, test.TestCase):
strategy, mid_level_api, _ = self._create_strategy_and_mid_level('sgd')
sparse = self._create_sparse_dataset(strategy)
sparse_iter = iter(strategy.experimental_distribute_dataset(sparse))
sparse_iter = iter(strategy.experimental_distribute_dataset(
    sparse,
    options=distribute_lib.InputOptions(
        experimental_prefetch_to_device=False)))
@def_function.function
def test_fn():
@@ -633,7 +661,10 @@ class TPUEmbeddingTest(parameterized.TestCase, test.TestCase):
strategy, mid_level_api, _ = self._create_strategy_and_mid_level('sgd')
sparse = self._create_sparse_dataset(strategy, include_weights=True)
sparse_iter = iter(strategy.experimental_distribute_dataset(sparse))
sparse_iter = iter(strategy.experimental_distribute_dataset(
    sparse,
    options=distribute_lib.InputOptions(
        experimental_prefetch_to_device=False)))
@def_function.function
def test_fn():
@@ -654,8 +685,14 @@ class TPUEmbeddingTest(parameterized.TestCase, test.TestCase):
sparse = self._create_sparse_dataset(strategy)
ragged = self._create_ragged_dataset(strategy)
sparse_iter = iter(strategy.experimental_distribute_dataset(sparse))
ragged_iter = iter(strategy.experimental_distribute_dataset(ragged))
sparse_iter = iter(strategy.experimental_distribute_dataset(
    sparse,
    options=distribute_lib.InputOptions(
        experimental_prefetch_to_device=False)))
ragged_iter = iter(strategy.experimental_distribute_dataset(
    ragged,
    options=distribute_lib.InputOptions(
        experimental_prefetch_to_device=False)))
@def_function.function
def test_fn():
@@ -678,6 +715,26 @@ class TPUEmbeddingTest(parameterized.TestCase, test.TestCase):
ragged0 = self._get_replica_numpy(ragged_activations, strategy, 0)
self.assertAllClose(sparse0, ragged0)
def test_enqueue_cpu_tensor(self):
  strategy, mid_level_api, _ = self._create_strategy_and_mid_level('sgd')
  input_fn = self._create_dense_input_fn(strategy)
  sparse_iter = iter(strategy.experimental_distribute_datasets_from_function(
      input_fn))

  @def_function.function
  def test_fn():
    def get_activations():
      return mid_level_api.dequeue()

    sparse_features = next(sparse_iter)
    mid_level_api.enqueue(sparse_features, training=False)
    sparse_activations = strategy.run(get_activations)
    return sparse_activations

  with self.assertRaisesRegex(ValueError, 'which is on a TPU input device'):
    test_fn()
@parameterized.parameters(True, False)
def test_enqueue_with_weights(self, ragged):
strategy, mid_level_api, _ = self._create_strategy_and_mid_level('sgd')
@@ -689,7 +746,10 @@ class TPUEmbeddingTest(parameterized.TestCase, test.TestCase):
dataset = self._create_sparse_dataset(strategy, include_weights=True,
                                      weight=weight)
dataset_iter = iter(strategy.experimental_distribute_dataset(dataset))
dataset_iter = iter(strategy.experimental_distribute_dataset(
    dataset,
    options=distribute_lib.InputOptions(
        experimental_prefetch_to_device=False)))
@def_function.function
def enqueue_and_get(features, weights):
@@ -727,7 +787,10 @@ class TPUEmbeddingTest(parameterized.TestCase, test.TestCase):
strategy, mid_level_api, _ = self._create_strategy_and_mid_level('sgd')
dataset = self._create_sparse_dataset(strategy)
dataset_iter = iter(strategy.experimental_distribute_dataset(dataset))
dataset_iter = iter(strategy.experimental_distribute_dataset(
    dataset,
    options=distribute_lib.InputOptions(
        experimental_prefetch_to_device=False)))
@def_function.function
def enqueue_with_outside_compilation(data):
@@ -761,7 +824,10 @@ class TPUEmbeddingTest(parameterized.TestCase, test.TestCase):
strategy, mid_level_api, _ = self._create_strategy_and_mid_level('sgd')
dataset = self._create_sparse_dataset(strategy)
dataset_iter = iter(strategy.experimental_distribute_dataset(dataset))
dataset_iter = iter(strategy.experimental_distribute_dataset(
    dataset,
    options=distribute_lib.InputOptions(
        experimental_prefetch_to_device=False)))
# This is one way to force the enqueue in some control flow. @tf.functions
# aren't inlined in the calling tf.function. An alternative would be to
@@ -785,7 +851,10 @@ class TPUEmbeddingTest(parameterized.TestCase, test.TestCase):
def test_enqueue_with_outside_compilation_non_direct_input(self):
strategy, mid_level_api, _ = self._create_strategy_and_mid_level('sgd')
dataset = self._create_sparse_dataset(strategy)
dataset_iter = iter(strategy.experimental_distribute_dataset(dataset))
dataset_iter = iter(strategy.experimental_distribute_dataset(
    dataset,
    options=distribute_lib.InputOptions(
        experimental_prefetch_to_device=False)))
@def_function.function
def enqueue_with_outside_compilation():
@@ -804,7 +873,10 @@ class TPUEmbeddingTest(parameterized.TestCase, test.TestCase):
def test_enqueue_with_outside_compilation_auto_mode(self):
strategy, mid_level_api, _ = self._create_strategy_and_mid_level('sgd')
dataset = self._create_sparse_dataset(strategy)
dataset_iter = iter(strategy.experimental_distribute_dataset(dataset))
dataset_iter = iter(strategy.experimental_distribute_dataset(
    dataset,
    options=distribute_lib.InputOptions(
        experimental_prefetch_to_device=False)))
@def_function.function
def enqueue_with_no_gradient_apply(data):
@@ -883,7 +955,10 @@ class TPUEmbeddingTest(parameterized.TestCase, test.TestCase):
self._create_strategy_and_mid_level(optimizer_name))
dataset = self._create_sparse_dataset(strategy)
dist = strategy.experimental_distribute_dataset(dataset)
dist = strategy.experimental_distribute_dataset(
    dataset,
    options=distribute_lib.InputOptions(
        experimental_prefetch_to_device=False))
dist_iter = iter(dist)
@def_function.function
@@ -1175,7 +1250,10 @@ class TPUEmbeddingTest(parameterized.TestCase, test.TestCase):
strategy, mid_level_api, _ = self._create_strategy_and_mid_level('sgd')
input_fn = self._create_dense_input_fn(strategy)
dist = strategy.experimental_distribute_datasets_from_function(input_fn)
dist = strategy.experimental_distribute_datasets_from_function(
    input_fn,
    options=distribute_lib.InputOptions(
        experimental_prefetch_to_device=False))
dist_iter = iter(dist)
@def_function.function
@@ -1235,7 +1313,10 @@ class TPUEmbeddingTest(parameterized.TestCase, test.TestCase):
def input_fn(ctx):
  del ctx
  return dataset_ops.DatasetV2.from_tensors(feature).repeat()

dist = strategy.experimental_distribute_datasets_from_function(input_fn)
dist = strategy.experimental_distribute_datasets_from_function(
    input_fn,
    options=distribute_lib.InputOptions(
        experimental_prefetch_to_device=False))
dist_iter = iter(dist)
@def_function.function
@@ -1364,7 +1445,10 @@ class TPUEmbeddingTest(parameterized.TestCase, test.TestCase):
optimizer=optimizer)
dataset = self._create_sparse_dataset(strategy)
data = next(iter(strategy.experimental_distribute_dataset(dataset)))
data = next(iter(strategy.experimental_distribute_dataset(
    dataset,
    options=distribute_lib.InputOptions(
        experimental_prefetch_to_device=False))))
@def_function.function
def embedding_and_set_gradients(data):