Merge pull request #20539 from candyzone:master
PiperOrigin-RevId: 214121495
This commit is contained in:
commit
425e96f3ae
@ -112,9 +112,11 @@ def safe_embedding_lookup_sparse(embedding_weights,
|
|||||||
dtype = sparse_weights.dtype if sparse_weights is not None else None
|
dtype = sparse_weights.dtype if sparse_weights is not None else None
|
||||||
if isinstance(embedding_weights, variables.PartitionedVariable):
|
if isinstance(embedding_weights, variables.PartitionedVariable):
|
||||||
embedding_weights = list(embedding_weights)
|
embedding_weights = list(embedding_weights)
|
||||||
embedding_weights = [
|
if not isinstance(embedding_weights[0],
|
||||||
ops.convert_to_tensor(w, dtype=dtype) for w in embedding_weights
|
resource_variable_ops.ResourceVariable):
|
||||||
]
|
embedding_weights = [
|
||||||
|
ops.convert_to_tensor(w, dtype=dtype) for w in embedding_weights
|
||||||
|
]
|
||||||
|
|
||||||
contrib_tensor_util.assert_same_float_dtype(embedding_weights +
|
contrib_tensor_util.assert_same_float_dtype(embedding_weights +
|
||||||
[sparse_weights])
|
[sparse_weights])
|
||||||
|
@ -3433,9 +3433,11 @@ def _safe_embedding_lookup_sparse(embedding_weights,
|
|||||||
raise ValueError('Missing embedding_weights %s.' % embedding_weights)
|
raise ValueError('Missing embedding_weights %s.' % embedding_weights)
|
||||||
|
|
||||||
dtype = sparse_weights.dtype if sparse_weights is not None else None
|
dtype = sparse_weights.dtype if sparse_weights is not None else None
|
||||||
embedding_weights = [
|
if not isinstance(embedding_weights[0],
|
||||||
ops.convert_to_tensor(w, dtype=dtype) for w in embedding_weights
|
resource_variable_ops.ResourceVariable):
|
||||||
]
|
embedding_weights = [
|
||||||
|
ops.convert_to_tensor(w, dtype=dtype) for w in embedding_weights
|
||||||
|
]
|
||||||
|
|
||||||
with ops.name_scope(name, 'embedding_lookup',
|
with ops.name_scope(name, 'embedding_lookup',
|
||||||
embedding_weights + [sparse_ids,
|
embedding_weights + [sparse_ids,
|
||||||
|
@ -550,9 +550,11 @@ def safe_embedding_lookup_sparse(embedding_weights,
|
|||||||
raise ValueError('Missing embedding_weights %s.' % embedding_weights)
|
raise ValueError('Missing embedding_weights %s.' % embedding_weights)
|
||||||
|
|
||||||
dtype = sparse_weights.dtype if sparse_weights is not None else None
|
dtype = sparse_weights.dtype if sparse_weights is not None else None
|
||||||
embedding_weights = [
|
if not isinstance(embedding_weights[0],
|
||||||
ops.convert_to_tensor(w, dtype=dtype) for w in embedding_weights
|
resource_variable_ops.ResourceVariable):
|
||||||
]
|
embedding_weights = [
|
||||||
|
ops.convert_to_tensor(w, dtype=dtype) for w in embedding_weights
|
||||||
|
]
|
||||||
|
|
||||||
with ops.name_scope(name, 'embedding_lookup',
|
with ops.name_scope(name, 'embedding_lookup',
|
||||||
embedding_weights + [sparse_ids,
|
embedding_weights + [sparse_ids,
|
||||||
|
Loading…
Reference in New Issue
Block a user