Going back to forcing embedding layer variables on the CPU even within a tf.function, as the previous change was breaking some user code.

PiperOrigin-RevId: 321607029
Change-Id: Id159867f51b26e6604a1186d9ce526658ddd1e19
Rohan Jain 2020-07-16 11:21:21 -07:00 committed by Geeta Chavan
parent bb3c460114
commit d60d7d3c7e
1 changed file with 1 addition and 2 deletions

@@ -132,8 +132,7 @@ class Embedding(Layer):
     # right now. Checking for the presence of GPUs to avoid complicating the
     # TPU codepaths which can handle sparse optimizers. But if we are within
     # a tf.function, we go back the graph mode logic and rely on the placer.
-    if (context.executing_eagerly() and context.context().num_gpus() and
-        not ops.inside_function()):
+    if context.executing_eagerly() and context.context().num_gpus():
       with ops.device('cpu:0'):
         self.embeddings = self.add_weight(
             shape=(self.input_dim, self.output_dim),
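
For context, a minimal sketch of the user-visible effect of this change (written for this note, not part of the commit; assumes TensorFlow 2.x on a machine with a visible GPU, and the function name lookup is a hypothetical example):

import tensorflow as tf

# After this change, the embedding table is pinned to the CPU whenever the
# program is executing eagerly and a GPU is present, even if the layer is
# first built while tracing a tf.function.
layer = tf.keras.layers.Embedding(input_dim=1000, output_dim=64)

@tf.function
def lookup(ids):
  return layer(ids)  # first call builds the layer and creates `embeddings`

lookup(tf.constant([[1, 2, 3]]))
if tf.config.list_physical_devices('GPU'):
  print(layer.embeddings.device)  # expected to end with 'device:CPU:0'

Pinning the table to the CPU keeps sparse gradient updates on a device that has kernels for them; most sparse optimizers have no GPU kernels, which is why the eager-with-GPU case is special-cased while the TPU codepaths (which can handle sparse optimizers) are left alone.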