Merge pull request #39508 from ngc92:with_values
PiperOrigin-RevId: 323455948 Change-Id: I65ed255768d45150b373b959d2ae882690b29133
This commit is contained in:
commit
c7e7f49228
@ -56,6 +56,9 @@
|
||||
corresponding bitwise ops. `bool` arguments continue to be supported and
|
||||
dispatch to logical ops. This brings them more in line with Python and NumPy
|
||||
benavior.
|
||||
* Added `tf.SparseTensor.with_values`. This returns a new SparseTensor with
|
||||
the same sparsity pattern, but with new provided values. It is similar to
|
||||
the `with_values` function of `RaggedTensor`.
|
||||
* `tf.data`:
|
||||
* Added new `tf.data.experimental.service.register_dataset` and
|
||||
`tf.data.experimental.service.from_dataset_id` APIs to enable one process
|
||||
|
@ -178,6 +178,31 @@ class SparseTensor(internal.NativeObject, composite_tensor.CompositeTensor):
|
||||
"""
|
||||
return self._values
|
||||
|
||||
def with_values(self, new_values):
|
||||
"""Returns a copy of `self` with `values` replaced by `new_values`.
|
||||
|
||||
This method produces a new `SparseTensor` that has the same nonzero
|
||||
`indices` and same `dense_shape`, but updated values.
|
||||
|
||||
Args:
|
||||
new_values: The values of the new `SparseTensor`. Needs to have the same
|
||||
shape as the current `.values` `Tensor`. May have a different type than
|
||||
the current `values`.
|
||||
|
||||
Returns:
|
||||
A `SparseTensor` with identical indices and shape but updated values.
|
||||
|
||||
Example usage:
|
||||
|
||||
>>> st = tf.sparse.from_dense([[1, 0, 2, 0], [3, 0, 0, 4]])
|
||||
>>> tf.sparse.to_dense(st.with_values([10, 20, 30, 40])) # 4 nonzero values
|
||||
<tf.Tensor: shape=(2, 4), dtype=int32, numpy=
|
||||
array([[10, 0, 20, 0],
|
||||
[30, 0, 0, 40]], dtype=int32)>
|
||||
|
||||
"""
|
||||
return SparseTensor(self._indices, new_values, self._dense_shape)
|
||||
|
||||
@property
|
||||
def op(self):
|
||||
"""The `Operation` that produces `values` as an output."""
|
||||
|
@ -24,6 +24,7 @@ import numpy as np
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
@ -97,6 +98,18 @@ class SparseTensorTest(test_util.TensorFlowTestCase):
|
||||
self.assertIn(dense.op, sp.consumers())
|
||||
self.assertIn(out.op, sp.consumers())
|
||||
|
||||
def testWithValues(self):
|
||||
source = sparse_tensor.SparseTensor(
|
||||
indices=[[0, 0], [1, 2]], values=[1., 2], dense_shape=[3, 4])
|
||||
new_tensor = source.with_values([5.0, 1.0])
|
||||
self.assertAllEqual(new_tensor.indices, source.indices)
|
||||
self.assertAllEqual(new_tensor.values, [5.0, 1.0])
|
||||
self.assertAllEqual(new_tensor.dense_shape, source.dense_shape)
|
||||
|
||||
# ensure new value's shape is checked
|
||||
with self.assertRaises((errors.InvalidArgumentError, ValueError)):
|
||||
source.with_values([[5.0, 1.0]])
|
||||
|
||||
|
||||
class ConvertToTensorOrSparseTensorTest(test_util.TensorFlowTestCase):
|
||||
|
||||
|
@ -52,4 +52,8 @@ tf_class {
|
||||
name: "get_shape"
|
||||
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "with_values"
|
||||
argspec: "args=[\'self\', \'new_values\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
}
|
||||
|
@ -52,4 +52,8 @@ tf_class {
|
||||
name: "get_shape"
|
||||
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "with_values"
|
||||
argspec: "args=[\'self\', \'new_values\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
}
|
||||
|
@ -52,4 +52,8 @@ tf_class {
|
||||
name: "get_shape"
|
||||
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "with_values"
|
||||
argspec: "args=[\'self\', \'new_values\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
}
|
||||
|
@ -52,4 +52,8 @@ tf_class {
|
||||
name: "get_shape"
|
||||
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "with_values"
|
||||
argspec: "args=[\'self\', \'new_values\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user