diff --git a/tensorflow/python/keras/layers/embeddings.py b/tensorflow/python/keras/layers/embeddings.py index 3f57fd6cb63..e30e93f02dc 100644 --- a/tensorflow/python/keras/layers/embeddings.py +++ b/tensorflow/python/keras/layers/embeddings.py @@ -129,8 +129,10 @@ class Embedding(Layer): # since it knows all kernels using the variable only exist on CPU. # When eager execution is enabled, the placement decision has to be made # right now. Checking for the presence of GPUs to avoid complicating the - # TPU codepaths which can handle sparse optimizers. - if context.executing_eagerly() and context.context().num_gpus(): + # TPU codepaths which can handle sparse optimizers. But if we are within + # a tf.function, we go back the graph mode logic and rely on the placer. + if (context.executing_eagerly() and context.context().num_gpus() and + not ops.inside_function()): with ops.device('cpu:0'): self.embeddings = self.add_weight( shape=(self.input_dim, self.output_dim),