Merge pull request #41636 from geetachavan1/cherrypicks_F4YCK
[CherryPick 2.3] Going back to forcing embedding layer variables on the CPU even within a tf.function as this is breaking some user code.
This commit is contained in:
commit
ca2b7ba75c
@ -132,8 +132,7 @@ class Embedding(Layer):
|
||||
# right now. Checking for the presence of GPUs to avoid complicating the
|
||||
# TPU codepaths which can handle sparse optimizers. But if we are within
|
||||
# a tf.function, we go back the graph mode logic and rely on the placer.
|
||||
if (context.executing_eagerly() and context.context().num_gpus() and
|
||||
not ops.inside_function()):
|
||||
if context.executing_eagerly() and context.context().num_gpus():
|
||||
with ops.device('cpu:0'):
|
||||
self.embeddings = self.add_weight(
|
||||
shape=(self.input_dim, self.output_dim),
|
||||
|
Loading…
Reference in New Issue
Block a user