Update tpu_embedding_v2.py to use the new API for prefetching data to host memory.
PiperOrigin-RevId: 316742491
Change-Id: I6803c798256578a284d9ef190d79bf2e35f9ce6a
parent 93a441910f
commit 2a0ad47926
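For reference, the pattern this change standardizes can be sketched as a small helper. This is a hypothetical illustration, not code from the commit: the helper name `make_host_iterator` and the `dataset_fn` argument are placeholders; only the `tf.distribute.InputOptions(experimental_prefetch_to_device=False)` usage comes from the diff below.

```python
import tensorflow as tf


def make_host_iterator(strategy, dataset_fn):
  """Hypothetical helper: distribute a dataset while keeping tensors on the host.

  TPUEmbedding.enqueue requires its input tensors to be in host (CPU) memory,
  so the distributed dataset is built with experimental_prefetch_to_device=False
  instead of the default device prefetching.
  """
  distributed_dataset = strategy.experimental_distribute_datasets_from_function(
      dataset_fn,
      options=tf.distribute.InputOptions(
          experimental_prefetch_to_device=False))
  return iter(distributed_dataset)
```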
@@ -31,6 +31,7 @@ from tensorflow.python.distribute import sharded_variable
 from tensorflow.python.distribute import tpu_strategy
 from tensorflow.python.eager import def_function
 from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import device as tf_device
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import sparse_tensor
@@ -139,6 +140,18 @@ class TPUEmbedding(tracking.AutoTrackable):
       optimizer=tf.tpu.experimental.embedding.SGD(0.1))
   ```
 
+  When creating a distributed dataset that is to be passed to the enqueue
+  operation a special input option must be specified:
+
+  ```python
+  distributed_dataset = (
+      strategy.experimental_distribute_datasets_from_function(
+          dataset_fn=...,
+          options=tf.distribute.InputOptions(
+              experimental_prefetch_to_device=False))
+  dataset_iterator = iter(distributed_dataset)
+  ```
+
   To use this API on TPU you should use a custom training loop. Below is an
   example of a training and evaluation step:
 
@@ -309,10 +322,6 @@ class TPUEmbedding(tracking.AutoTrackable):
     # We need to list of host devices for the load/retrieve operations.
     self._hosts = get_list_of_hosts(self._strategy)
 
-    # TODO(bfontain) Remove this once we have an official way of splitting
-    # prefetch between host and device.
-    self._strategy.extended._set_prefetch_on_host(True)  # pylint: disable=protected-access
-
     # We generally use the per core batch size, but will have the user pass
     # in a global batch size.
     self._batch_size = batch_size // self._strategy.num_replicas_in_sync
@@ -507,7 +516,11 @@ class TPUEmbedding(tracking.AutoTrackable):
     with strategy.scope():
       embedding = tf.tpu.experimental.embedding.TPUEmbedding(...)
 
-    distributed_dataset = strategy.experimental_distribute_dataset(...)
+    distributed_dataset = (
+        strategy.experimental_distribute_datasets_from_function(
+            dataset_fn=...,
+            options=tf.distribute.InputOptions(
+                experimental_prefetch_to_device=False))
     dataset_iterator = iter(distributed_dataset)
 
     @tf.function
@@ -594,7 +607,11 @@ class TPUEmbedding(tracking.AutoTrackable):
     with strategy.scope():
       embedding = tf.tpu.experimental.embedding.TPUEmbedding(...)
 
-    distributed_dataset = strategy.experimental_distribute_dataset(...)
+    distributed_dataset = (
+        strategy.experimental_distribute_datasets_from_function(
+            dataset_fn=...,
+            options=tf.distribute.InputOptions(
+                experimental_prefetch_to_device=False))
     dataset_iterator = iter(distributed_dataset)
 
     @tf.function
@@ -1004,6 +1021,24 @@ class TPUEmbedding(tracking.AutoTrackable):
               input_tensor.op.name,
               input_tensor.op.type))
 
+  def _raise_error_for_inputs_not_on_cpu(self, features):
+    """Checks all tensors in features to see are placed on the CPU."""
+
+    # expand_composites here is important, we need to check the device of each
+    # underlying tensor.
+    for path, input_tensor in nest.flatten_with_joined_string_paths(
+        features, expand_composites=True):
+      spec = tf_device.DeviceSpec.from_string(input_tensor.device)
+      if spec.device_type == "TPU":
+        raise ValueError(
+            "Received input tensor {} which is on a TPU input device {}. Input "
+            "tensors for TPU embeddings must be placed on the CPU. Please "
+            "ensure that your dataset is prefetching tensors to the host by "
+            "setting the 'experimental_prefetch_to_device' option of the "
+            "dataset distribution function. See the documentation of the "
+            "enqueue method for an example.".format(
+                path, input_tensor.device))
+
   def enqueue(self, features, weights=None, training=True, name=None):
     """Enqueues id tensors for embedding lookup.
 
@@ -1021,7 +1056,11 @@ class TPUEmbedding(tracking.AutoTrackable):
     with strategy.scope():
       embedding = tf.tpu.experimental.embedding.TPUEmbedding(...)
 
-    distributed_dataset = strategy.experimental_distribute_dataset(...)
+    distributed_dataset = (
+        strategy.experimental_distribute_datasets_from_function(
+            dataset_fn=...,
+            options=tf.distribute.InputOptions(
+                experimental_prefetch_to_device=False))
     dataset_iterator = iter(distributed_dataset)
 
     @tf.function
@@ -1091,6 +1130,7 @@ class TPUEmbedding(tracking.AutoTrackable):
     flat_weights = nest.flatten(weights)
     flat_features = nest.flatten_with_joined_string_paths(self._feature_config)
 
+    self._raise_error_for_inputs_not_on_cpu(features)
     in_tpu_context = self._raise_error_for_incorrect_control_flow_context()
     # If we are in a tpu_context, automatically apply outside compilation.
     if in_tpu_context:
@@ -28,6 +28,7 @@ import numpy as np
 
 from tensorflow.python.compat import v2_compat
 from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.distribute import distribute_lib
 from tensorflow.python.distribute import distribution_strategy_context
 from tensorflow.python.distribute import tpu_strategy
 from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver
@@ -443,7 +444,10 @@ class TPUEmbeddingTest(parameterized.TestCase, test.TestCase):
   def test_pass_none_to_apply_gradients(self):
     strategy, mid_level_api, _ = self._create_strategy_and_mid_level('sgd')
     dataset = self._create_sparse_dataset(strategy)
-    data = next(iter(strategy.experimental_distribute_dataset(dataset)))
+    data = next(iter(strategy.experimental_distribute_dataset(
+        dataset,
+        options=distribute_lib.InputOptions(
+            experimental_prefetch_to_device=False))))
 
     @def_function.function
     def embedding_and_set_gradients(data):
@@ -527,7 +531,10 @@ class TPUEmbeddingTest(parameterized.TestCase, test.TestCase):
     strategy, mid_level_api, _ = self._create_strategy_and_mid_level('sgd')
 
     input_fn = self._create_dense_input_fn(strategy, include_weights=True)
-    dist = strategy.experimental_distribute_datasets_from_function(input_fn)
+    dist = strategy.experimental_distribute_datasets_from_function(
+        input_fn,
+        options=distribute_lib.InputOptions(
+            experimental_prefetch_to_device=False))
     dist_iter = iter(dist)
 
     @def_function.function
@@ -547,8 +554,14 @@ class TPUEmbeddingTest(parameterized.TestCase, test.TestCase):
 
     sparse = self._create_sparse_dataset(strategy)
     ragged = self._create_ragged_dataset(strategy, include_weights=True)
-    sparse_iter = iter(strategy.experimental_distribute_dataset(sparse))
-    ragged_iter = iter(strategy.experimental_distribute_dataset(ragged))
+    sparse_iter = iter(strategy.experimental_distribute_dataset(
+        sparse,
+        options=distribute_lib.InputOptions(
+            experimental_prefetch_to_device=False)))
+    ragged_iter = iter(strategy.experimental_distribute_dataset(
+        ragged,
+        options=distribute_lib.InputOptions(
+            experimental_prefetch_to_device=False)))
 
     @def_function.function
     def test_fn():
@@ -569,8 +582,14 @@ class TPUEmbeddingTest(parameterized.TestCase, test.TestCase):
 
     sparse = self._create_sparse_dataset(strategy, include_weights=True)
     ragged = self._create_ragged_dataset(strategy)
-    sparse_iter = iter(strategy.experimental_distribute_dataset(sparse))
-    ragged_iter = iter(strategy.experimental_distribute_dataset(ragged))
+    sparse_iter = iter(strategy.experimental_distribute_dataset(
+        sparse,
+        options=distribute_lib.InputOptions(
+            experimental_prefetch_to_device=False)))
+    ragged_iter = iter(strategy.experimental_distribute_dataset(
+        ragged,
+        options=distribute_lib.InputOptions(
+            experimental_prefetch_to_device=False)))
 
     @def_function.function
     def test_fn():
@@ -591,8 +610,14 @@ class TPUEmbeddingTest(parameterized.TestCase, test.TestCase):
 
     sparse = self._create_sparse_dataset(strategy)
     ragged = self._create_ragged_dataset(strategy)
-    sparse_iter = iter(strategy.experimental_distribute_dataset(sparse))
-    ragged_iter = iter(strategy.experimental_distribute_dataset(ragged))
+    sparse_iter = iter(strategy.experimental_distribute_dataset(
+        sparse,
+        options=distribute_lib.InputOptions(
+            experimental_prefetch_to_device=False)))
+    ragged_iter = iter(strategy.experimental_distribute_dataset(
+        ragged,
+        options=distribute_lib.InputOptions(
+            experimental_prefetch_to_device=False)))
 
     @def_function.function
     def test_fn():
@@ -613,7 +638,10 @@ class TPUEmbeddingTest(parameterized.TestCase, test.TestCase):
     strategy, mid_level_api, _ = self._create_strategy_and_mid_level('sgd')
 
     sparse = self._create_sparse_dataset(strategy)
-    sparse_iter = iter(strategy.experimental_distribute_dataset(sparse))
+    sparse_iter = iter(strategy.experimental_distribute_dataset(
+        sparse,
+        options=distribute_lib.InputOptions(
+            experimental_prefetch_to_device=False)))
 
     @def_function.function
     def test_fn():
@@ -633,7 +661,10 @@ class TPUEmbeddingTest(parameterized.TestCase, test.TestCase):
     strategy, mid_level_api, _ = self._create_strategy_and_mid_level('sgd')
 
     sparse = self._create_sparse_dataset(strategy, include_weights=True)
-    sparse_iter = iter(strategy.experimental_distribute_dataset(sparse))
+    sparse_iter = iter(strategy.experimental_distribute_dataset(
+        sparse,
+        options=distribute_lib.InputOptions(
+            experimental_prefetch_to_device=False)))
 
     @def_function.function
     def test_fn():
@@ -654,8 +685,14 @@ class TPUEmbeddingTest(parameterized.TestCase, test.TestCase):
 
     sparse = self._create_sparse_dataset(strategy)
     ragged = self._create_ragged_dataset(strategy)
-    sparse_iter = iter(strategy.experimental_distribute_dataset(sparse))
-    ragged_iter = iter(strategy.experimental_distribute_dataset(ragged))
+    sparse_iter = iter(strategy.experimental_distribute_dataset(
+        sparse,
+        options=distribute_lib.InputOptions(
+            experimental_prefetch_to_device=False)))
+    ragged_iter = iter(strategy.experimental_distribute_dataset(
+        ragged,
+        options=distribute_lib.InputOptions(
+            experimental_prefetch_to_device=False)))
 
     @def_function.function
     def test_fn():
@@ -678,6 +715,26 @@ class TPUEmbeddingTest(parameterized.TestCase, test.TestCase):
     ragged0 = self._get_replica_numpy(ragged_activations, strategy, 0)
     self.assertAllClose(sparse0, ragged0)
 
+  def test_enqueue_cpu_tensor(self):
+    strategy, mid_level_api, _ = self._create_strategy_and_mid_level('sgd')
+
+    input_fn = self._create_dense_input_fn(strategy)
+    sparse_iter = iter(strategy.experimental_distribute_datasets_from_function(
+        input_fn))
+
+    @def_function.function
+    def test_fn():
+      def get_activations():
+        return mid_level_api.dequeue()
+
+      sparse_features = next(sparse_iter)
+      mid_level_api.enqueue(sparse_features, training=False)
+      sparse_activations = strategy.run(get_activations)
+      return sparse_activations
+
+    with self.assertRaisesRegex(ValueError, 'which is on a TPU input device'):
+      test_fn()
+
   @parameterized.parameters(True, False)
   def test_enqueue_with_weights(self, ragged):
     strategy, mid_level_api, _ = self._create_strategy_and_mid_level('sgd')
@@ -689,7 +746,10 @@ class TPUEmbeddingTest(parameterized.TestCase, test.TestCase):
     dataset = self._create_sparse_dataset(strategy, include_weights=True,
                                           weight=weight)
 
-    dataset_iter = iter(strategy.experimental_distribute_dataset(dataset))
+    dataset_iter = iter(strategy.experimental_distribute_dataset(
+        dataset,
+        options=distribute_lib.InputOptions(
+            experimental_prefetch_to_device=False)))
 
     @def_function.function
     def enqueue_and_get(features, weights):
@@ -727,7 +787,10 @@ class TPUEmbeddingTest(parameterized.TestCase, test.TestCase):
 
     strategy, mid_level_api, _ = self._create_strategy_and_mid_level('sgd')
     dataset = self._create_sparse_dataset(strategy)
-    dataset_iter = iter(strategy.experimental_distribute_dataset(dataset))
+    dataset_iter = iter(strategy.experimental_distribute_dataset(
+        dataset,
+        options=distribute_lib.InputOptions(
+            experimental_prefetch_to_device=False)))
 
     @def_function.function
     def enqueue_with_outside_compilation(data):
@@ -761,7 +824,10 @@ class TPUEmbeddingTest(parameterized.TestCase, test.TestCase):
 
     strategy, mid_level_api, _ = self._create_strategy_and_mid_level('sgd')
     dataset = self._create_sparse_dataset(strategy)
-    dataset_iter = iter(strategy.experimental_distribute_dataset(dataset))
+    dataset_iter = iter(strategy.experimental_distribute_dataset(
+        dataset,
+        options=distribute_lib.InputOptions(
+            experimental_prefetch_to_device=False)))
 
     # This is one way to force the enqueue in some control flow. @tf.functions
     # aren't inlined in the calling tf.function. An alternative would be to
@@ -785,7 +851,10 @@ class TPUEmbeddingTest(parameterized.TestCase, test.TestCase):
   def test_enqueue_with_outside_compilation_non_direct_input(self):
     strategy, mid_level_api, _ = self._create_strategy_and_mid_level('sgd')
     dataset = self._create_sparse_dataset(strategy)
-    dataset_iter = iter(strategy.experimental_distribute_dataset(dataset))
+    dataset_iter = iter(strategy.experimental_distribute_dataset(
+        dataset,
+        options=distribute_lib.InputOptions(
+            experimental_prefetch_to_device=False)))
 
     @def_function.function
     def enqueue_with_outside_compilation():
@@ -804,7 +873,10 @@ class TPUEmbeddingTest(parameterized.TestCase, test.TestCase):
   def test_enqueue_with_outside_compilation_auto_mode(self):
     strategy, mid_level_api, _ = self._create_strategy_and_mid_level('sgd')
     dataset = self._create_sparse_dataset(strategy)
-    dataset_iter = iter(strategy.experimental_distribute_dataset(dataset))
+    dataset_iter = iter(strategy.experimental_distribute_dataset(
+        dataset,
+        options=distribute_lib.InputOptions(
+            experimental_prefetch_to_device=False)))
 
     @def_function.function
     def enqueue_with_no_gradient_apply(data):
@@ -883,7 +955,10 @@ class TPUEmbeddingTest(parameterized.TestCase, test.TestCase):
         self._create_strategy_and_mid_level(optimizer_name))
 
     dataset = self._create_sparse_dataset(strategy)
-    dist = strategy.experimental_distribute_dataset(dataset)
+    dist = strategy.experimental_distribute_dataset(
+        dataset,
+        options=distribute_lib.InputOptions(
+            experimental_prefetch_to_device=False))
     dist_iter = iter(dist)
 
     @def_function.function
@@ -1175,7 +1250,10 @@ class TPUEmbeddingTest(parameterized.TestCase, test.TestCase):
     strategy, mid_level_api, _ = self._create_strategy_and_mid_level('sgd')
 
     input_fn = self._create_dense_input_fn(strategy)
-    dist = strategy.experimental_distribute_datasets_from_function(input_fn)
+    dist = strategy.experimental_distribute_datasets_from_function(
+        input_fn,
+        options=distribute_lib.InputOptions(
+            experimental_prefetch_to_device=False))
     dist_iter = iter(dist)
 
     @def_function.function
@@ -1235,7 +1313,10 @@ class TPUEmbeddingTest(parameterized.TestCase, test.TestCase):
     def input_fn(ctx):
       del ctx
       return dataset_ops.DatasetV2.from_tensors(feature).repeat()
-    dist = strategy.experimental_distribute_datasets_from_function(input_fn)
+    dist = strategy.experimental_distribute_datasets_from_function(
+        input_fn,
+        options=distribute_lib.InputOptions(
+            experimental_prefetch_to_device=False))
     dist_iter = iter(dist)
 
     @def_function.function
@@ -1364,7 +1445,10 @@ class TPUEmbeddingTest(parameterized.TestCase, test.TestCase):
         optimizer=optimizer)
 
     dataset = self._create_sparse_dataset(strategy)
-    data = next(iter(strategy.experimental_distribute_dataset(dataset)))
+    data = next(iter(strategy.experimental_distribute_dataset(
+        dataset,
+        options=distribute_lib.InputOptions(
+            experimental_prefetch_to_device=False))))
 
     @def_function.function
     def embedding_and_set_gradients(data):