Fix error check in TPUEmbedding to work when used in outside compilation.

PiperOrigin-RevId: 316922750
Change-Id: Ie6b4c83e54f3e6d90fbe38fe8da0eea84312c382
This commit is contained in:
Bruce Fontaine 2020-06-17 11:00:38 -07:00 committed by TensorFlower Gardener
parent 2a8bbb92b7
commit c8dd07ae28
2 changed files with 41 additions and 10 deletions

View File

@@ -1024,11 +1024,8 @@ class TPUEmbedding(tracking.AutoTrackable):
def _raise_error_for_inputs_not_on_cpu(self, features):
"""Checks all tensors in features to see are placed on the CPU."""
# expand_composites here is important, we need to check the device of each
# underlying tensor.
for path, input_tensor in nest.flatten_with_joined_string_paths(
features, expand_composites=True):
spec = tf_device.DeviceSpec.from_string(input_tensor.device)
def check_device(path, device_string):
spec = tf_device.DeviceSpec.from_string(device_string)
if spec.device_type == "TPU":
raise ValueError(
"Received input tensor {} which is on a TPU input device {}. Input "
@@ -1037,7 +1034,18 @@ class TPUEmbedding(tracking.AutoTrackable):
"setting the 'experimental_prefetch_to_device' option of the "
"dataset distribution function. See the documentation of the "
"enqueue method for an example.".format(
path, input_tensor.device))
path, device_string))
# expand_composites here is important, we need to check the device of each
# underlying tensor.
for path, input_tensor in nest.flatten_with_joined_string_paths(
features, expand_composites=True):
if (input_tensor.op.type == "Identity" and
input_tensor.op.inputs[0].op.type == "TPUReplicatedInput"):
for tensor in input_tensor.op.inputs[0].op.inputs:
check_device(path, tensor.device)
else:
check_device(path, input_tensor.device)
def enqueue(self, features, weights=None, training=True, name=None):
"""Enqueues id tensors for embedding lookup.

View File

@@ -727,10 +727,33 @@ class TPUEmbeddingTest(parameterized.TestCase, test.TestCase):
def get_activations():
return mid_level_api.dequeue()
sparse_features = next(sparse_iter)
mid_level_api.enqueue(sparse_features, training=False)
sparse_activations = strategy.run(get_activations)
return sparse_activations
features = next(sparse_iter)
mid_level_api.enqueue(features, training=False)
activations = strategy.run(get_activations)
return activations
with self.assertRaisesRegex(ValueError, 'which is on a TPU input device'):
test_fn()
@parameterized.parameters([True, False])
def test_enqueue_cpu_tensor_with_outside_compilation(self, use_mlir):
  """Enqueueing a CPU-placed tensor from inside `strategy.run` must raise.

  Calling `enqueue` under outside compilation (i.e. from within the
  replicated `strategy.run` body) hands the API tensors that live on a TPU
  input device, which the error check in
  `_raise_error_for_inputs_not_on_cpu` is expected to reject.
  """
  if use_mlir:
    # Exercise both TF/XLA bridges; the error check must fire under each.
    config.enable_mlir_bridge()

  strategy, mid_level_api, _ = self._create_strategy_and_mid_level('sgd')
  dense_input_fn = self._create_dense_input_fn(strategy)
  dist_dataset = strategy.experimental_distribute_datasets_from_function(
      dense_input_fn)
  dist_iter = iter(dist_dataset)

  @def_function.function
  def test_fn():
    def get_activations(features):
      # enqueue is called inside the replicated context here, so the
      # features it sees are on TPU input devices.
      mid_level_api.enqueue(features, training=False)
      return mid_level_api.dequeue()

    return strategy.run(get_activations, args=(next(dist_iter),))

  with self.assertRaisesRegex(ValueError, 'which is on a TPU input device'):
    test_fn()