From d60d7d3c7eb124fd8ff27fdfffc12147d474ad52 Mon Sep 17 00:00:00 2001
From: Rohan Jain <rohanj@google.com>
Date: Thu, 16 Jul 2020 11:21:21 -0700
Subject: [PATCH] Go back to forcing embedding layer variables onto the CPU
 even within a tf.function, as the previous change broke some user code.

PiperOrigin-RevId: 321607029
Change-Id: Id159867f51b26e6604a1186d9ce526658ddd1e19
---
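
Note (not part of the patch itself): a minimal sketch of the placement
behavior this change restores. The layer sizes below are illustrative
assumptions; `embeddings` is the attribute created by `add_weight` in the
hunk below.

  import tensorflow as tf

  # With eager execution and at least one GPU visible, the embedding
  # table should be pinned to cpu:0, since most sparse optimizers lack
  # GPU kernels. After this change the same forcing applies inside a
  # tf.function as well, rather than deferring to the placer.
  layer = tf.keras.layers.Embedding(input_dim=1000, output_dim=64)
  layer(tf.constant([[1, 2, 3]]))  # triggers build() and variable creation
  print(layer.embeddings.device)   # expected to report a CPU placement
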
 tensorflow/python/keras/layers/embeddings.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/tensorflow/python/keras/layers/embeddings.py b/tensorflow/python/keras/layers/embeddings.py
index 3444b3a7665..defa03409a2 100644
--- a/tensorflow/python/keras/layers/embeddings.py
+++ b/tensorflow/python/keras/layers/embeddings.py
@@ -132,8 +132,7 @@ class Embedding(Layer):
     # right now. Checking for the presence of GPUs to avoid complicating the
     # TPU codepaths which can handle sparse optimizers. But if we are within
     # a tf.function, we go back to the graph mode logic and rely on the placer.
-    if (context.executing_eagerly() and context.context().num_gpus() and
-        not ops.inside_function()):
+    if context.executing_eagerly() and context.context().num_gpus():
       with ops.device('cpu:0'):
         self.embeddings = self.add_weight(
             shape=(self.input_dim, self.output_dim),