diff --git a/tensorflow/python/tpu/tpu_embedding_v2.py b/tensorflow/python/tpu/tpu_embedding_v2.py index f7a383c440c..e5cfba7c587 100644 --- a/tensorflow/python/tpu/tpu_embedding_v2.py +++ b/tensorflow/python/tpu/tpu_embedding_v2.py @@ -1024,11 +1024,8 @@ class TPUEmbedding(tracking.AutoTrackable): def _raise_error_for_inputs_not_on_cpu(self, features): """Checks all tensors in features to see are placed on the CPU.""" - # expand_composites here is important, we need to check the device of each - # underlying tensor. - for path, input_tensor in nest.flatten_with_joined_string_paths( - features, expand_composites=True): - spec = tf_device.DeviceSpec.from_string(input_tensor.device) + def check_device(path, device_string): + spec = tf_device.DeviceSpec.from_string(device_string) if spec.device_type == "TPU": raise ValueError( "Received input tensor {} which is on a TPU input device {}. Input " @@ -1037,7 +1034,18 @@ class TPUEmbedding(tracking.AutoTrackable): "setting the 'experimental_prefetch_to_device' option of the " "dataset distribution function. See the documentation of the " "enqueue method for an example.".format( - path, input_tensor.device)) + path, device_string)) + + # expand_composites here is important, we need to check the device of each + # underlying tensor. + for path, input_tensor in nest.flatten_with_joined_string_paths( + features, expand_composites=True): + if (input_tensor.op.type == "Identity" and + input_tensor.op.inputs[0].op.type == "TPUReplicatedInput"): + for tensor in input_tensor.op.inputs[0].op.inputs: + check_device(path, tensor.device) + else: + check_device(path, input_tensor.device) def enqueue(self, features, weights=None, training=True, name=None): """Enqueues id tensors for embedding lookup. diff --git a/tensorflow/python/tpu/tpu_embedding_v2_test.py b/tensorflow/python/tpu/tpu_embedding_v2_test.py index ebaf2791055..ff09085f3f1 100644 --- a/tensorflow/python/tpu/tpu_embedding_v2_test.py +++ b/tensorflow/python/tpu/tpu_embedding_v2_test.py @@ -727,10 +727,33 @@ class TPUEmbeddingTest(parameterized.TestCase, test.TestCase): def get_activations(): return mid_level_api.dequeue() - sparse_features = next(sparse_iter) - mid_level_api.enqueue(sparse_features, training=False) - sparse_activations = strategy.run(get_activations) - return sparse_activations + features = next(sparse_iter) + mid_level_api.enqueue(features, training=False) + activations = strategy.run(get_activations) + return activations + + with self.assertRaisesRegex(ValueError, 'which is on a TPU input device'): + test_fn() + + @parameterized.parameters([True, False]) + def test_enqueue_cpu_tensor_with_outside_compilation(self, use_mlir): + if use_mlir: + config.enable_mlir_bridge() + + strategy, mid_level_api, _ = self._create_strategy_and_mid_level('sgd') + + input_fn = self._create_dense_input_fn(strategy) + sparse_iter = iter(strategy.experimental_distribute_datasets_from_function( + input_fn)) + + @def_function.function + def test_fn(): + def get_activations(features): + mid_level_api.enqueue(features, training=False) + return mid_level_api.dequeue() + + activations = strategy.run(get_activations, args=(next(sparse_iter),)) + return activations with self.assertRaisesRegex(ValueError, 'which is on a TPU input device'): test_fn()