Fix the embedding variable name on TPU for v2 shared embedding columns to match the name on the CPU side.
PiperOrigin-RevId: 254019702
This commit is contained in:
parent
1b0d51abed
commit
6523a23d55
@ -502,7 +502,7 @@ class _TPUSharedEmbeddingColumnV2(_TPUBaseEmbeddingColumn,
|
||||
# Note that in Feature Column V2, shared embeddings have no scope.
|
||||
_record_variable_scope_and_name(
|
||||
self.get_embedding_var_name(),
|
||||
'embedding_weights',
|
||||
self.shared_embedding_column_creator._name,
|
||||
is_shared_embedding=True)
|
||||
return tensor
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user