From 9bab0c89c4ffeeb780e7a3dc415ab888164b9b00 Mon Sep 17 00:00:00 2001 From: "candy.dc" Date: Thu, 26 Jul 2018 11:36:30 +0800 Subject: [PATCH 1/2] fix: No need to convert to tensor when using ResourceVariable in embedding_lookup, because ResourceVariable support ResourceGather OP. --- tensorflow/contrib/layers/python/layers/embedding_ops.py | 7 ++++--- tensorflow/python/feature_column/feature_column_v2.py | 7 ++++--- tensorflow/python/ops/embedding_ops.py | 7 ++++--- 3 files changed, 12 insertions(+), 9 deletions(-) diff --git a/tensorflow/contrib/layers/python/layers/embedding_ops.py b/tensorflow/contrib/layers/python/layers/embedding_ops.py index 60e1d85ea9c..897aed527da 100644 --- a/tensorflow/contrib/layers/python/layers/embedding_ops.py +++ b/tensorflow/contrib/layers/python/layers/embedding_ops.py @@ -112,9 +112,10 @@ def safe_embedding_lookup_sparse(embedding_weights, dtype = sparse_weights.dtype if sparse_weights is not None else None if isinstance(embedding_weights, variables.PartitionedVariable): embedding_weights = list(embedding_weights) - embedding_weights = [ - ops.convert_to_tensor(w, dtype=dtype) for w in embedding_weights - ] + if not isinstance(embedding_weights[0], resource_variable_ops.ResourceVariable): + embedding_weights = [ + ops.convert_to_tensor(w, dtype=dtype) for w in embedding_weights + ] contrib_tensor_util.assert_same_float_dtype(embedding_weights + [sparse_weights]) diff --git a/tensorflow/python/feature_column/feature_column_v2.py b/tensorflow/python/feature_column/feature_column_v2.py index b4dd23f58de..220a4f7ed66 100644 --- a/tensorflow/python/feature_column/feature_column_v2.py +++ b/tensorflow/python/feature_column/feature_column_v2.py @@ -3283,9 +3283,10 @@ def _safe_embedding_lookup_sparse(embedding_weights, raise ValueError('Missing embedding_weights %s.' % embedding_weights) dtype = sparse_weights.dtype if sparse_weights is not None else None - embedding_weights = [ - ops.convert_to_tensor(w, dtype=dtype) for w in embedding_weights - ] + if not isinstance(embedding_weights[0], resource_variable_ops.ResourceVariable): + embedding_weights = [ + ops.convert_to_tensor(w, dtype=dtype) for w in embedding_weights + ] with ops.name_scope(name, 'embedding_lookup', embedding_weights + [sparse_ids, diff --git a/tensorflow/python/ops/embedding_ops.py b/tensorflow/python/ops/embedding_ops.py index 27c2fa70176..fe422f5095d 100644 --- a/tensorflow/python/ops/embedding_ops.py +++ b/tensorflow/python/ops/embedding_ops.py @@ -545,9 +545,10 @@ def safe_embedding_lookup_sparse(embedding_weights, raise ValueError('Missing embedding_weights %s.' % embedding_weights) dtype = sparse_weights.dtype if sparse_weights is not None else None - embedding_weights = [ - ops.convert_to_tensor(w, dtype=dtype) for w in embedding_weights - ] + if not isinstance(embedding_weights[0], resource_variable_ops.ResourceVariable): + embedding_weights = [ + ops.convert_to_tensor(w, dtype=dtype) for w in embedding_weights + ] with ops.name_scope(name, 'embedding_lookup', embedding_weights + [sparse_ids, From ba5d214a6b5d131b693eff277cc3b56298a4721a Mon Sep 17 00:00:00 2001 From: Allen Lavoie Date: Fri, 21 Sep 2018 15:30:56 -0700 Subject: [PATCH 2/2] Fix lint errors --- tensorflow/contrib/layers/python/layers/embedding_ops.py | 3 ++- tensorflow/python/feature_column/feature_column_v2.py | 3 ++- tensorflow/python/ops/embedding_ops.py | 3 ++- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/tensorflow/contrib/layers/python/layers/embedding_ops.py b/tensorflow/contrib/layers/python/layers/embedding_ops.py index 897aed527da..17ee8c0733d 100644 --- a/tensorflow/contrib/layers/python/layers/embedding_ops.py +++ b/tensorflow/contrib/layers/python/layers/embedding_ops.py @@ -112,7 +112,8 @@ def safe_embedding_lookup_sparse(embedding_weights, dtype = sparse_weights.dtype if sparse_weights is not None else None if isinstance(embedding_weights, variables.PartitionedVariable): embedding_weights = list(embedding_weights) - if not isinstance(embedding_weights[0], resource_variable_ops.ResourceVariable): + if not isinstance(embedding_weights[0], + resource_variable_ops.ResourceVariable): embedding_weights = [ ops.convert_to_tensor(w, dtype=dtype) for w in embedding_weights ] diff --git a/tensorflow/python/feature_column/feature_column_v2.py b/tensorflow/python/feature_column/feature_column_v2.py index 220a4f7ed66..1a2213707cb 100644 --- a/tensorflow/python/feature_column/feature_column_v2.py +++ b/tensorflow/python/feature_column/feature_column_v2.py @@ -3283,7 +3283,8 @@ def _safe_embedding_lookup_sparse(embedding_weights, raise ValueError('Missing embedding_weights %s.' % embedding_weights) dtype = sparse_weights.dtype if sparse_weights is not None else None - if not isinstance(embedding_weights[0], resource_variable_ops.ResourceVariable): + if not isinstance(embedding_weights[0], + resource_variable_ops.ResourceVariable): embedding_weights = [ ops.convert_to_tensor(w, dtype=dtype) for w in embedding_weights ] diff --git a/tensorflow/python/ops/embedding_ops.py b/tensorflow/python/ops/embedding_ops.py index fe422f5095d..bcd135eedb1 100644 --- a/tensorflow/python/ops/embedding_ops.py +++ b/tensorflow/python/ops/embedding_ops.py @@ -545,7 +545,8 @@ def safe_embedding_lookup_sparse(embedding_weights, raise ValueError('Missing embedding_weights %s.' % embedding_weights) dtype = sparse_weights.dtype if sparse_weights is not None else None - if not isinstance(embedding_weights[0], resource_variable_ops.ResourceVariable): + if not isinstance(embedding_weights[0], + resource_variable_ops.ResourceVariable): embedding_weights = [ ops.convert_to_tensor(w, dtype=dtype) for w in embedding_weights ]