Merge pull request #40368 from ngc92:map_values
PiperOrigin-RevId: 322222577 Change-Id: I300afb3b55417201b229a0ee00d59efc2f831bd1
This commit is contained in:
commit
8245cb693b
@ -38,6 +38,7 @@
|
||||
* Calling ops with a python constants or numpy values is now consistent with
|
||||
tf.convert_to_tensor behavior. This avoids operations like tf.reshape
|
||||
truncating inputs such as from int64 to int32.
|
||||
* Added `tf.sparse.map_values` to apply a function to the `.value`s of `SparseTensror` arguments.
|
||||
* `tf.data`:
|
||||
* Added optional `exclude_cols` parameter to CsvDataset. This parameter is
|
||||
the complement of `select_cols`; at most one of these should be specified.
|
||||
|
@ -47,6 +47,7 @@ from tensorflow.python.ops.gen_sparse_ops import *
|
||||
from tensorflow.python.util import compat
|
||||
from tensorflow.python.util import deprecation
|
||||
from tensorflow.python.util import dispatch
|
||||
from tensorflow.python.util import nest
|
||||
from tensorflow.python.util import tf_inspect
|
||||
from tensorflow.python.util.compat import collections_abc
|
||||
from tensorflow.python.util.tf_export import get_canonical_name_for_symbol
|
||||
@ -2772,6 +2773,119 @@ def sparse_transpose(sp_input, perm=None, name=None):
|
||||
return transposed_st
|
||||
|
||||
|
||||
@tf_export("sparse.map_values", v1=[])
|
||||
@dispatch.add_dispatch_support
|
||||
def map_values(op, *args, **kwargs):
|
||||
"""Applies `op` to the `.values` tensor of one or more `SparseTensor`s.
|
||||
|
||||
Replaces any `SparseTensor` in `args` or `kwargs` with its `values`
|
||||
tensor (which contains the non-default values for the SparseTensor),
|
||||
and then calls `op`. Returns a `SparseTensor` that is constructed
|
||||
from the input `SparseTensor`s' `indices`, `dense_shape`, and the
|
||||
value returned by the `op`.
|
||||
|
||||
If the input arguments contain multiple `SparseTensor`s, then they must have
|
||||
equal `indices` and dense shapes.
|
||||
|
||||
Examples:
|
||||
|
||||
>>> s = tf.sparse.from_dense([[1, 2, 0],
|
||||
... [0, 4, 0],
|
||||
... [1, 0, 0]])
|
||||
>>> tf.sparse.to_dense(tf.sparse.map_values(tf.ones_like, s)).numpy()
|
||||
array([[1, 1, 0],
|
||||
[0, 1, 0],
|
||||
[1, 0, 0]], dtype=int32)
|
||||
|
||||
>>> tf.sparse.to_dense(tf.sparse.map_values(tf.multiply, s, s)).numpy()
|
||||
array([[ 1, 4, 0],
|
||||
[ 0, 16, 0],
|
||||
[ 1, 0, 0]], dtype=int32)
|
||||
|
||||
>>> tf.sparse.to_dense(tf.sparse.map_values(tf.add, s, 5)).numpy()
|
||||
array([[6, 7, 0],
|
||||
[0, 9, 0],
|
||||
[6, 0, 0]], dtype=int32)
|
||||
|
||||
Note: even though `tf.add(0, 5) != 0`, implicit zeros
|
||||
will remain unchanged. However, if the sparse tensor contains any explict
|
||||
zeros, these will be affected by the mapping!
|
||||
|
||||
Args:
|
||||
op: The operation that should be applied to the SparseTensor `values`. `op`
|
||||
is typically an element-wise operation (such as math_ops.add), but any
|
||||
operation that preserves the shape can be used.
|
||||
*args: Arguments for `op`.
|
||||
**kwargs: Keyword arguments for `op`.
|
||||
|
||||
Returns:
|
||||
A `SparseTensor` whose `indices` and `dense_shape` matches the `indices`
|
||||
and `dense_shape` of all input `SparseTensor`s.
|
||||
Raises:
|
||||
ValueError: If args contains no `SparseTensor`, or if the `indices`
|
||||
or `dense_shape`s of the input `SparseTensor`s are not equal.
|
||||
"""
|
||||
sparse_list = []
|
||||
inner_args = _replace_sparse_with_values(args, sparse_list)
|
||||
inner_kwargs = _replace_sparse_with_values(kwargs, sparse_list)
|
||||
if not sparse_list:
|
||||
raise ValueError("No SparseTensor in argument list of map_values")
|
||||
|
||||
with ops.control_dependencies(_assert_sparse_compatible(sparse_list)):
|
||||
# Delegate to op, and then compose the result from the transformed values
|
||||
# and the known indices/dense shape. Since we ensure that indices and shape
|
||||
# are identical, we can just use the first one.
|
||||
return sparse_tensor.SparseTensor(sparse_list[0].indices,
|
||||
op(*inner_args, **inner_kwargs),
|
||||
sparse_list[0].dense_shape)
|
||||
|
||||
|
||||
def _assert_sparse_compatible(sparse_tensors):
|
||||
"""Check that all of `sparse_tensors` have same `indices` and `dense_shape`.
|
||||
|
||||
Args:
|
||||
sparse_tensors: A list of sparse tensors.
|
||||
|
||||
Returns:
|
||||
An op to be used as a control dependency.
|
||||
"""
|
||||
checks = []
|
||||
first = sparse_tensors[0]
|
||||
for t in sparse_tensors[1:]:
|
||||
checks.append(
|
||||
check_ops.assert_equal(
|
||||
first.dense_shape, t.dense_shape, message="Mismatched shapes!"))
|
||||
checks.append(
|
||||
check_ops.assert_equal(
|
||||
first.indices, t.indices, message="Mismatched indices!"))
|
||||
return checks
|
||||
|
||||
|
||||
def _replace_sparse_with_values(value, sparse_list):
|
||||
"""Replace `SparseTensor`s with their values in `value`
|
||||
|
||||
Each `SparseTensor` in `value` is replaced by its `values` tensor, and
|
||||
collects all `SparseTensor`s in `sparse_list`.
|
||||
|
||||
Args:
|
||||
value: A structure of `Tensor`s and `SparseTensor`s
|
||||
sparse_list: A list. Output parameter that collects all `SparseTensor`s in
|
||||
`value`.
|
||||
|
||||
Returns:
|
||||
`value` with each SparseTensor replaced by its `.value` attribute.
|
||||
"""
|
||||
flat_vals = nest.flatten(value, expand_composites=False)
|
||||
new_vals = []
|
||||
for v in flat_vals:
|
||||
if isinstance(v, sparse_tensor.SparseTensor):
|
||||
sparse_list.append(v)
|
||||
new_vals.append(v.values)
|
||||
else:
|
||||
new_vals.append(v)
|
||||
return nest.pack_sequence_as(value, new_vals, expand_composites=False)
|
||||
|
||||
|
||||
def _add_sparse_to_tensors_map(sp_input,
|
||||
container=None,
|
||||
shared_name=None,
|
||||
|
@ -23,6 +23,7 @@ import numpy as np
|
||||
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.framework import test_util
|
||||
@ -180,6 +181,49 @@ class SparseOpsTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||
array_ops.transpose(dense_of_sparse))
|
||||
self.assertAllEqual(expected, result)
|
||||
|
||||
def testMapValues(self):
|
||||
# supplying no sparse tensor should result in ValueError
|
||||
with self.assertRaises(ValueError):
|
||||
sparse_ops.map_values(math_ops.abs, 0.0)
|
||||
|
||||
sp = sparse_ops.from_dense([[0.0, 1.0, 0.0], [-2.0, 1.0, 0.0]])
|
||||
|
||||
# helper function to check equality of sparse tensor
|
||||
def assert_sparse_equal(expected, result):
|
||||
self.assertAllEqual(expected.values, result.values, msg='Values differ')
|
||||
self.assertAllEqual(
|
||||
expected.indices, result.indices, msg='Indices differ')
|
||||
self.assertAllEqual(
|
||||
expected.dense_shape, result.dense_shape, msg='Shapes differ')
|
||||
|
||||
# check for a single sparse argument
|
||||
expected = sparse_ops.from_dense([[0.0, 1.0, 0.0], [2.0, 1.0, 0.0]])
|
||||
result = sparse_ops.map_values(math_ops.abs, sp)
|
||||
assert_sparse_equal(expected, result)
|
||||
|
||||
# check correct passing of keyword argument, and handling of two sparse
|
||||
# arguments at the same time
|
||||
def mapping(arg1, arg2, kwarg):
|
||||
self.assertEqual(kwarg, 'kwarg')
|
||||
return arg1 + arg2
|
||||
|
||||
result = sparse_ops.map_values(mapping, sp, sp, kwarg='kwarg')
|
||||
expected = sparse_ops.from_dense([[0.0, 2.0, 0.0], [-4.0, 2.0, 0.0]])
|
||||
assert_sparse_equal(expected, result)
|
||||
|
||||
# check that index mismatches are correctly detected even if the `value`s
|
||||
# have compatible shape
|
||||
sp_incomp = sparse_ops.from_dense([[0.0, 1.0, 0.0], [-2.0, 0.0, 1.0]])
|
||||
with self.assertRaises((errors.InvalidArgumentError, ValueError)):
|
||||
result = sparse_ops.map_values(mapping, sp, sp_incomp, kwarg='kwarg')
|
||||
self.evaluate(result)
|
||||
|
||||
# check that shape mismatches are correctly detected
|
||||
sp_incomp = sparse_tensor.SparseTensor(sp.indices, sp.values, (25, 25))
|
||||
with self.assertRaises((errors.InvalidArgumentError, ValueError)):
|
||||
result = sparse_ops.map_values(mapping, sp, sp_incomp, kwarg='kwarg')
|
||||
self.evaluate(result)
|
||||
|
||||
def testConstantStringToSparse(self):
|
||||
# Test case for GitHub issue 40633.
|
||||
tensor = constant_op.constant(list('ababa'))
|
||||
|
@ -40,6 +40,10 @@ tf_module {
|
||||
name: "from_dense"
|
||||
argspec: "args=[\'tensor\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "map_values"
|
||||
argspec: "args=[\'op\'], varargs=args, keywords=kwargs, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "mask"
|
||||
argspec: "args=[\'a\', \'mask_indices\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
|
Loading…
Reference in New Issue
Block a user