Internal change
PiperOrigin-RevId: 312090528 Change-Id: I474709513b01db8c24c50fd670029451c51cb622
This commit is contained in:
parent
55aee9e550
commit
46f7108d78
@ -129,8 +129,10 @@ class Embedding(Layer):
|
|||||||
# since it knows all kernels using the variable only exist on CPU.
|
# since it knows all kernels using the variable only exist on CPU.
|
||||||
# When eager execution is enabled, the placement decision has to be made
|
# When eager execution is enabled, the placement decision has to be made
|
||||||
# right now. Checking for the presence of GPUs to avoid complicating the
|
# right now. Checking for the presence of GPUs to avoid complicating the
|
||||||
# TPU codepaths which can handle sparse optimizers.
|
# TPU codepaths which can handle sparse optimizers. But if we are within
|
||||||
if context.executing_eagerly() and context.context().num_gpus():
|
# a tf.function, we go back the graph mode logic and rely on the placer.
|
||||||
|
if (context.executing_eagerly() and context.context().num_gpus() and
|
||||||
|
not ops.inside_function()):
|
||||||
with ops.device('cpu:0'):
|
with ops.device('cpu:0'):
|
||||||
self.embeddings = self.add_weight(
|
self.embeddings = self.add_weight(
|
||||||
shape=(self.input_dim, self.output_dim),
|
shape=(self.input_dim, self.output_dim),
|
||||||
|
Loading…
Reference in New Issue
Block a user