diff --git a/tensorflow/python/keras/layers/embeddings.py b/tensorflow/python/keras/layers/embeddings.py index 3444b3a7665..defa03409a2 100644 --- a/tensorflow/python/keras/layers/embeddings.py +++ b/tensorflow/python/keras/layers/embeddings.py @@ -132,8 +132,7 @@ class Embedding(Layer): # right now. Checking for the presence of GPUs to avoid complicating the # TPU codepaths which can handle sparse optimizers. But if we are within # a tf.function, we go back the graph mode logic and rely on the placer. - if (context.executing_eagerly() and context.context().num_gpus() and - not ops.inside_function()): + if context.executing_eagerly() and context.context().num_gpus(): with ops.device('cpu:0'): self.embeddings = self.add_weight( shape=(self.input_dim, self.output_dim),