From c13a0bf808a4145fb2ffc894f4af1c95f21dbc79 Mon Sep 17 00:00:00 2001 From: ngc92 <7938269+ngc92@users.noreply.github.com> Date: Mon, 11 May 2020 13:08:35 +0300 Subject: [PATCH 1/5] add a with_values method to SparseTensor --- tensorflow/python/framework/sparse_tensor.py | 15 +++++++++++++++ tensorflow/python/framework/sparse_tensor_test.py | 8 ++++++++ 2 files changed, 23 insertions(+) diff --git a/tensorflow/python/framework/sparse_tensor.py b/tensorflow/python/framework/sparse_tensor.py index 76cb24f2cc6..ab7afabeae5 100644 --- a/tensorflow/python/framework/sparse_tensor.py +++ b/tensorflow/python/framework/sparse_tensor.py @@ -178,6 +178,21 @@ class SparseTensor(internal.NativeObject, composite_tensor.CompositeTensor): """ return self._values + def with_values(self, new_values): + """Returns a copy of `self` with `values` replaced by `new_values`. + + This method produces a new `SparseTensor` that has the same nonzero + indices, but updated values. + + Args: + new_values: The values of the new `SparseTensor. Needs to have the same + shape as the current `.values` `Tensor`. + + Returns: + A `SparseTensor` with identical indices but updated values. + """ + return SparseTensor(self._indices, new_values, self._dense_shape) + @property def op(self): """The `Operation` that produces `values` as an output.""" diff --git a/tensorflow/python/framework/sparse_tensor_test.py b/tensorflow/python/framework/sparse_tensor_test.py index 0d18af1fe2f..4db32065960 100644 --- a/tensorflow/python/framework/sparse_tensor_test.py +++ b/tensorflow/python/framework/sparse_tensor_test.py @@ -97,6 +97,14 @@ class SparseTensorTest(test_util.TensorFlowTestCase): self.assertIn(dense.op, sp.consumers()) self.assertIn(out.op, sp.consumers()) + def testWithValues(self): + source = sparse_tensor.SparseTensor( + indices=[[0, 0], [1, 2]], values=[1., 2], dense_shape=[3, 4]) + new_tensor = tensor.with_values([5.0, 1.0]) + self.assertAllEqual(new_tensor.indices, source.indices) + self.assertAllEqual(new_tensor.values, [5.0, 1.0]) + self.assertAllEqual(new_tensor.dense_shape, source.dense_shape) + class ConvertToTensorOrSparseTensorTest(test_util.TensorFlowTestCase): From b3d9b905e15df29817907ebe0c4f0ebe4819468c Mon Sep 17 00:00:00 2001 From: ngc92 <7938269+ngc92@users.noreply.github.com> Date: Tue, 12 May 2020 22:06:58 +0300 Subject: [PATCH 2/5] fixed typo --- tensorflow/python/framework/sparse_tensor_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/framework/sparse_tensor_test.py b/tensorflow/python/framework/sparse_tensor_test.py index 4db32065960..71693087c28 100644 --- a/tensorflow/python/framework/sparse_tensor_test.py +++ b/tensorflow/python/framework/sparse_tensor_test.py @@ -100,7 +100,7 @@ class SparseTensorTest(test_util.TensorFlowTestCase): def testWithValues(self): source = sparse_tensor.SparseTensor( indices=[[0, 0], [1, 2]], values=[1., 2], dense_shape=[3, 4]) - new_tensor = tensor.with_values([5.0, 1.0]) + new_tensor = source.with_values([5.0, 1.0]) self.assertAllEqual(new_tensor.indices, source.indices) self.assertAllEqual(new_tensor.values, [5.0, 1.0]) self.assertAllEqual(new_tensor.dense_shape, source.dense_shape) From 996b28e8341ff5698535ea05ee1b33d82e9d7bb0 Mon Sep 17 00:00:00 2001 From: ngc92 <7938269+ngc92@users.noreply.github.com> Date: Sat, 6 Jun 2020 20:45:54 +0300 Subject: [PATCH 3/5] addressing the review issues --- tensorflow/python/framework/sparse_tensor.py | 9 +++++---- tensorflow/python/framework/sparse_tensor_test.py | 5 +++++ .../api/golden/v1/tensorflow.sparse.-sparse-tensor.pbtxt | 4 ++++ .../api/golden/v2/tensorflow.sparse.-sparse-tensor.pbtxt | 4 ++++ 4 files changed, 18 insertions(+), 4 deletions(-) diff --git a/tensorflow/python/framework/sparse_tensor.py b/tensorflow/python/framework/sparse_tensor.py index ab7afabeae5..1592f7d6604 100644 --- a/tensorflow/python/framework/sparse_tensor.py +++ b/tensorflow/python/framework/sparse_tensor.py @@ -182,14 +182,15 @@ class SparseTensor(internal.NativeObject, composite_tensor.CompositeTensor): """Returns a copy of `self` with `values` replaced by `new_values`. This method produces a new `SparseTensor` that has the same nonzero - indices, but updated values. + `indices` and same `dense_shape`, but updated values. Args: - new_values: The values of the new `SparseTensor. Needs to have the same - shape as the current `.values` `Tensor`. + new_values: The values of the new `SparseTensor`. Needs to have the same + shape as the current `.values` `Tensor`. May have a different type + than the current `values`. Returns: - A `SparseTensor` with identical indices but updated values. + A `SparseTensor` with identical indices and shape but updated values. """ return SparseTensor(self._indices, new_values, self._dense_shape) diff --git a/tensorflow/python/framework/sparse_tensor_test.py b/tensorflow/python/framework/sparse_tensor_test.py index 71693087c28..736543f669b 100644 --- a/tensorflow/python/framework/sparse_tensor_test.py +++ b/tensorflow/python/framework/sparse_tensor_test.py @@ -24,6 +24,7 @@ import numpy as np from tensorflow.python.eager import context from tensorflow.python.eager import def_function from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import tensor_shape @@ -105,6 +106,10 @@ class SparseTensorTest(test_util.TensorFlowTestCase): self.assertAllEqual(new_tensor.values, [5.0, 1.0]) self.assertAllEqual(new_tensor.dense_shape, source.dense_shape) + # ensure new value's shape is checked + with self.assertRaises((errors.InvalidArgumentError, ValueError)): + source.with_values([[5.0, 1.0]]) + class ConvertToTensorOrSparseTensorTest(test_util.TensorFlowTestCase): diff --git a/tensorflow/tools/api/golden/v1/tensorflow.sparse.-sparse-tensor.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.sparse.-sparse-tensor.pbtxt index a49cd1ccc4d..e13dad8be69 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.sparse.-sparse-tensor.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.sparse.-sparse-tensor.pbtxt @@ -48,6 +48,10 @@ tf_class { name: "from_value" argspec: "args=[\'cls\', \'sparse_tensor_value\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "with_values" + argspec: "args=[\'self\', \'new_values\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "get_shape" argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.sparse.-sparse-tensor.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.sparse.-sparse-tensor.pbtxt index a49cd1ccc4d..e13dad8be69 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.sparse.-sparse-tensor.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.sparse.-sparse-tensor.pbtxt @@ -48,6 +48,10 @@ tf_class { name: "from_value" argspec: "args=[\'cls\', \'sparse_tensor_value\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "with_values" + argspec: "args=[\'self\', \'new_values\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "get_shape" argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" From 6be0e40016b01fbe4d56c981205e4ae6c5281864 Mon Sep 17 00:00:00 2001 From: ngc92 <7938269+ngc92@users.noreply.github.com> Date: Wed, 22 Jul 2020 13:41:36 +0300 Subject: [PATCH 4/5] updated goldens --- .../tools/api/golden/v1/tensorflow.-sparse-tensor.pbtxt | 4 ++++ .../api/golden/v1/tensorflow.sparse.-sparse-tensor.pbtxt | 8 ++++---- .../tools/api/golden/v2/tensorflow.-sparse-tensor.pbtxt | 4 ++++ .../api/golden/v2/tensorflow.sparse.-sparse-tensor.pbtxt | 8 ++++---- 4 files changed, 16 insertions(+), 8 deletions(-) diff --git a/tensorflow/tools/api/golden/v1/tensorflow.-sparse-tensor.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-sparse-tensor.pbtxt index aa89308999c..fe3a8222353 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.-sparse-tensor.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.-sparse-tensor.pbtxt @@ -52,4 +52,8 @@ tf_class { name: "get_shape" argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "with_values" + argspec: "args=[\'self\', \'new_values\'], varargs=None, keywords=None, defaults=None" + } } diff --git a/tensorflow/tools/api/golden/v1/tensorflow.sparse.-sparse-tensor.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.sparse.-sparse-tensor.pbtxt index e13dad8be69..f0efebb3c8b 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.sparse.-sparse-tensor.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.sparse.-sparse-tensor.pbtxt @@ -48,12 +48,12 @@ tf_class { name: "from_value" argspec: "args=[\'cls\', \'sparse_tensor_value\'], varargs=None, keywords=None, defaults=None" } - member_method { - name: "with_values" - argspec: "args=[\'self\', \'new_values\'], varargs=None, keywords=None, defaults=None" - } member_method { name: "get_shape" argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "with_values" + argspec: "args=[\'self\', \'new_values\'], varargs=None, keywords=None, defaults=None" + } } diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-sparse-tensor.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-sparse-tensor.pbtxt index aa89308999c..fe3a8222353 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.-sparse-tensor.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.-sparse-tensor.pbtxt @@ -52,4 +52,8 @@ tf_class { name: "get_shape" argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "with_values" + argspec: "args=[\'self\', \'new_values\'], varargs=None, keywords=None, defaults=None" + } } diff --git a/tensorflow/tools/api/golden/v2/tensorflow.sparse.-sparse-tensor.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.sparse.-sparse-tensor.pbtxt index e13dad8be69..f0efebb3c8b 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.sparse.-sparse-tensor.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.sparse.-sparse-tensor.pbtxt @@ -48,12 +48,12 @@ tf_class { name: "from_value" argspec: "args=[\'cls\', \'sparse_tensor_value\'], varargs=None, keywords=None, defaults=None" } - member_method { - name: "with_values" - argspec: "args=[\'self\', \'new_values\'], varargs=None, keywords=None, defaults=None" - } member_method { name: "get_shape" argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "with_values" + argspec: "args=[\'self\', \'new_values\'], varargs=None, keywords=None, defaults=None" + } } From 0fb2791e0f1148d47718e2e4d638120cbf9b541b Mon Sep 17 00:00:00 2001 From: ngc92 <7938269+ngc92@users.noreply.github.com> Date: Wed, 22 Jul 2020 14:08:13 +0300 Subject: [PATCH 5/5] fixed indentation --- .../python/framework/sparse_tensor_test.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/tensorflow/python/framework/sparse_tensor_test.py b/tensorflow/python/framework/sparse_tensor_test.py index 736543f669b..aee0894d85d 100644 --- a/tensorflow/python/framework/sparse_tensor_test.py +++ b/tensorflow/python/framework/sparse_tensor_test.py @@ -99,16 +99,16 @@ class SparseTensorTest(test_util.TensorFlowTestCase): self.assertIn(out.op, sp.consumers()) def testWithValues(self): - source = sparse_tensor.SparseTensor( - indices=[[0, 0], [1, 2]], values=[1., 2], dense_shape=[3, 4]) - new_tensor = source.with_values([5.0, 1.0]) - self.assertAllEqual(new_tensor.indices, source.indices) - self.assertAllEqual(new_tensor.values, [5.0, 1.0]) - self.assertAllEqual(new_tensor.dense_shape, source.dense_shape) + source = sparse_tensor.SparseTensor( + indices=[[0, 0], [1, 2]], values=[1., 2], dense_shape=[3, 4]) + new_tensor = source.with_values([5.0, 1.0]) + self.assertAllEqual(new_tensor.indices, source.indices) + self.assertAllEqual(new_tensor.values, [5.0, 1.0]) + self.assertAllEqual(new_tensor.dense_shape, source.dense_shape) - # ensure new value's shape is checked - with self.assertRaises((errors.InvalidArgumentError, ValueError)): - source.with_values([[5.0, 1.0]]) + # ensure new value's shape is checked + with self.assertRaises((errors.InvalidArgumentError, ValueError)): + source.with_values([[5.0, 1.0]]) class ConvertToTensorOrSparseTensorTest(test_util.TensorFlowTestCase):