Merge pull request from ngc92:with_values

PiperOrigin-RevId: 323455948
Change-Id: I65ed255768d45150b373b959d2ae882690b29133
This commit is contained in:
TensorFlower Gardener 2020-07-27 15:44:08 -07:00
commit c7e7f49228
7 changed files with 57 additions and 0 deletions

View File

@ -56,6 +56,9 @@
corresponding bitwise ops. `bool` arguments continue to be supported and
dispatch to logical ops. This brings them more in line with Python and NumPy
benavior.
* Added `tf.SparseTensor.with_values`. This returns a new SparseTensor with
the same sparsity pattern, but with new provided values. It is similar to
the `with_values` function of `RaggedTensor`.
* `tf.data`:
* Added new `tf.data.experimental.service.register_dataset` and
`tf.data.experimental.service.from_dataset_id` APIs to enable one process

View File

@ -178,6 +178,31 @@ class SparseTensor(internal.NativeObject, composite_tensor.CompositeTensor):
"""
return self._values
def with_values(self, new_values):
"""Returns a copy of `self` with `values` replaced by `new_values`.
This method produces a new `SparseTensor` that has the same nonzero
`indices` and same `dense_shape`, but updated values.
Args:
new_values: The values of the new `SparseTensor`. Needs to have the same
shape as the current `.values` `Tensor`. May have a different type than
the current `values`.
Returns:
A `SparseTensor` with identical indices and shape but updated values.
Example usage:
>>> st = tf.sparse.from_dense([[1, 0, 2, 0], [3, 0, 0, 4]])
>>> tf.sparse.to_dense(st.with_values([10, 20, 30, 40])) # 4 nonzero values
<tf.Tensor: shape=(2, 4), dtype=int32, numpy=
array([[10, 0, 20, 0],
[30, 0, 0, 40]], dtype=int32)>
"""
return SparseTensor(self._indices, new_values, self._dense_shape)
@property
def op(self):
"""The `Operation` that produces `values` as an output."""

View File

@ -24,6 +24,7 @@ import numpy as np
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
@ -97,6 +98,18 @@ class SparseTensorTest(test_util.TensorFlowTestCase):
self.assertIn(dense.op, sp.consumers())
self.assertIn(out.op, sp.consumers())
def testWithValues(self):
source = sparse_tensor.SparseTensor(
indices=[[0, 0], [1, 2]], values=[1., 2], dense_shape=[3, 4])
new_tensor = source.with_values([5.0, 1.0])
self.assertAllEqual(new_tensor.indices, source.indices)
self.assertAllEqual(new_tensor.values, [5.0, 1.0])
self.assertAllEqual(new_tensor.dense_shape, source.dense_shape)
# ensure new value's shape is checked
with self.assertRaises((errors.InvalidArgumentError, ValueError)):
source.with_values([[5.0, 1.0]])
class ConvertToTensorOrSparseTensorTest(test_util.TensorFlowTestCase):

View File

@ -52,4 +52,8 @@ tf_class {
name: "get_shape"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "with_values"
argspec: "args=[\'self\', \'new_values\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -52,4 +52,8 @@ tf_class {
name: "get_shape"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "with_values"
argspec: "args=[\'self\', \'new_values\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -52,4 +52,8 @@ tf_class {
name: "get_shape"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "with_values"
argspec: "args=[\'self\', \'new_values\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -52,4 +52,8 @@ tf_class {
name: "get_shape"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "with_values"
argspec: "args=[\'self\', \'new_values\'], varargs=None, keywords=None, defaults=None"
}
}