Fix error check in TPUEmbedding to work when used in outside compilation.
PiperOrigin-RevId: 316922750 Change-Id: Ie6b4c83e54f3e6d90fbe38fe8da0eea84312c382
This commit is contained in:
parent
2a8bbb92b7
commit
c8dd07ae28
@ -1024,11 +1024,8 @@ class TPUEmbedding(tracking.AutoTrackable):
|
||||
def _raise_error_for_inputs_not_on_cpu(self, features):
|
||||
"""Checks all tensors in features to see are placed on the CPU."""
|
||||
|
||||
# expand_composites here is important, we need to check the device of each
|
||||
# underlying tensor.
|
||||
for path, input_tensor in nest.flatten_with_joined_string_paths(
|
||||
features, expand_composites=True):
|
||||
spec = tf_device.DeviceSpec.from_string(input_tensor.device)
|
||||
def check_device(path, device_string):
|
||||
spec = tf_device.DeviceSpec.from_string(device_string)
|
||||
if spec.device_type == "TPU":
|
||||
raise ValueError(
|
||||
"Received input tensor {} which is on a TPU input device {}. Input "
|
||||
@ -1037,7 +1034,18 @@ class TPUEmbedding(tracking.AutoTrackable):
|
||||
"setting the 'experimental_prefetch_to_device' option of the "
|
||||
"dataset distribution function. See the documentation of the "
|
||||
"enqueue method for an example.".format(
|
||||
path, input_tensor.device))
|
||||
path, device_string))
|
||||
|
||||
# expand_composites here is important, we need to check the device of each
|
||||
# underlying tensor.
|
||||
for path, input_tensor in nest.flatten_with_joined_string_paths(
|
||||
features, expand_composites=True):
|
||||
if (input_tensor.op.type == "Identity" and
|
||||
input_tensor.op.inputs[0].op.type == "TPUReplicatedInput"):
|
||||
for tensor in input_tensor.op.inputs[0].op.inputs:
|
||||
check_device(path, tensor.device)
|
||||
else:
|
||||
check_device(path, input_tensor.device)
|
||||
|
||||
def enqueue(self, features, weights=None, training=True, name=None):
|
||||
"""Enqueues id tensors for embedding lookup.
|
||||
|
@ -727,10 +727,33 @@ class TPUEmbeddingTest(parameterized.TestCase, test.TestCase):
|
||||
def get_activations():
|
||||
return mid_level_api.dequeue()
|
||||
|
||||
sparse_features = next(sparse_iter)
|
||||
mid_level_api.enqueue(sparse_features, training=False)
|
||||
sparse_activations = strategy.run(get_activations)
|
||||
return sparse_activations
|
||||
features = next(sparse_iter)
|
||||
mid_level_api.enqueue(features, training=False)
|
||||
activations = strategy.run(get_activations)
|
||||
return activations
|
||||
|
||||
with self.assertRaisesRegex(ValueError, 'which is on a TPU input device'):
|
||||
test_fn()
|
||||
|
||||
@parameterized.parameters([True, False])
|
||||
def test_enqueue_cpu_tensor_with_outside_compilation(self, use_mlir):
|
||||
if use_mlir:
|
||||
config.enable_mlir_bridge()
|
||||
|
||||
strategy, mid_level_api, _ = self._create_strategy_and_mid_level('sgd')
|
||||
|
||||
input_fn = self._create_dense_input_fn(strategy)
|
||||
sparse_iter = iter(strategy.experimental_distribute_datasets_from_function(
|
||||
input_fn))
|
||||
|
||||
@def_function.function
|
||||
def test_fn():
|
||||
def get_activations(features):
|
||||
mid_level_api.enqueue(features, training=False)
|
||||
return mid_level_api.dequeue()
|
||||
|
||||
activations = strategy.run(get_activations, args=(next(sparse_iter),))
|
||||
return activations
|
||||
|
||||
with self.assertRaisesRegex(ValueError, 'which is on a TPU input device'):
|
||||
test_fn()
|
||||
|
Loading…
Reference in New Issue
Block a user