fix: No need to convert to a tensor when using ResourceVariable in embedding_lookup,

because ResourceVariable supports the ResourceGather op.
This commit is contained in:
candy.dc 2018-07-26 11:36:30 +08:00
parent 5d92abe1e4
commit 9bab0c89c4
3 changed files with 12 additions and 9 deletions

View File

@ -112,9 +112,10 @@ def safe_embedding_lookup_sparse(embedding_weights,
dtype = sparse_weights.dtype if sparse_weights is not None else None
if isinstance(embedding_weights, variables.PartitionedVariable):
embedding_weights = list(embedding_weights)
embedding_weights = [
ops.convert_to_tensor(w, dtype=dtype) for w in embedding_weights
]
if not isinstance(embedding_weights[0], resource_variable_ops.ResourceVariable):
embedding_weights = [
ops.convert_to_tensor(w, dtype=dtype) for w in embedding_weights
]
contrib_tensor_util.assert_same_float_dtype(embedding_weights +
[sparse_weights])

View File

@ -3283,9 +3283,10 @@ def _safe_embedding_lookup_sparse(embedding_weights,
raise ValueError('Missing embedding_weights %s.' % embedding_weights)
dtype = sparse_weights.dtype if sparse_weights is not None else None
embedding_weights = [
ops.convert_to_tensor(w, dtype=dtype) for w in embedding_weights
]
if not isinstance(embedding_weights[0], resource_variable_ops.ResourceVariable):
embedding_weights = [
ops.convert_to_tensor(w, dtype=dtype) for w in embedding_weights
]
with ops.name_scope(name, 'embedding_lookup',
embedding_weights + [sparse_ids,

View File

@ -545,9 +545,10 @@ def safe_embedding_lookup_sparse(embedding_weights,
raise ValueError('Missing embedding_weights %s.' % embedding_weights)
dtype = sparse_weights.dtype if sparse_weights is not None else None
embedding_weights = [
ops.convert_to_tensor(w, dtype=dtype) for w in embedding_weights
]
if not isinstance(embedding_weights[0], resource_variable_ops.ResourceVariable):
embedding_weights = [
ops.convert_to_tensor(w, dtype=dtype) for w in embedding_weights
]
with ops.name_scope(name, 'embedding_lookup',
embedding_weights + [sparse_ids,