added sparse.map_values

This commit is contained in:
ngc92 2020-06-09 18:51:16 +03:00
parent 26dc5fc653
commit 5e261399de
4 changed files with 143 additions and 0 deletions

View File

@ -47,6 +47,7 @@ from tensorflow.python.ops.gen_sparse_ops import *
from tensorflow.python.util import compat
from tensorflow.python.util import deprecation
from tensorflow.python.util import dispatch
from tensorflow.python.util import nest
from tensorflow.python.util import tf_inspect
from tensorflow.python.util.compat import collections_abc
from tensorflow.python.util.tf_export import get_canonical_name_for_symbol
@ -2733,6 +2734,99 @@ def sparse_transpose(sp_input, perm=None, name=None):
return transposed_st
@tf_export("sparse.map_values")
@dispatch.add_dispatch_support
def map_values(op, *args, **kwargs):
"""Applies `op` to the values of one or more `SparseTensor`s.
Replaces any `SparseTensor` in `args` or `kwargs` with its `values`
tensor, and then calls `op`. Returns a `SparseTensor` that is constructed
from the input `SparseTensor`s' `indices` and the value returned by
the `op`.
If the input arguments contain multiple `SparseTensor`s, then they must have
identical `indices`.
Examples:
>>> st = tf.sparse.from_dense([[1, 2, 0], [0, 4, 0], [1, 0, 0]])
>>> tf.sparse.to_dense(tf.sparse.map_values(tf.ones_like, st)).numpy().to_list()
[[1, 1, 0], [0, 1, 0], [1, 0, 0]]
>>> tf.sparse.to_dense(tf.sparse.map_values(tf.multiply, st, st)).numpy().to_list()
[[1, 4, 0], [0, 16, 0], [1, 0, 0]]
>>> tf.sparse.to_dense(tf.sparse.map_values(tf.add, st, 5)).numpy().to_list()
[[5, 7, 0], [0, 9, 0], [6, 0, 0]]
Note in particular that even though `tf.add(0, 5) != 0`, implicit zeros
will remain unchanged. However, if the sparse tensor contains any explict
zeros, these will be affected by the mapping!
Args:
op: The operation that should be applied to the SparseTensor `values`.
`op` is typically an element-wise operation (such as math_ops.add), but
any operation that preserves the shape can be used.
*args: Arguments for `op`.
**kwargs: Keyword arguments for `op`.
Returns:
A `SparseTensor` whose `indices` matches the `indices` of all
input `SparseTensor`s.
Raises:
ValueError: If args contains no `SparseTensor`, or if the `indices`
of the input `SparseTensor`s are not identical.
"""
sparse_list = []
inner_args = _replace_sparse_with_values(args, sparse_list)
inner_kwargs = _replace_sparse_with_values(kwargs, sparse_list)
if not sparse_list:
raise ValueError("No SparseTensor in argument list of map_values")
with ops.control_dependencies(_assert_sparse_compatible(sparse_list)):
# Delegate to op, and then compose the result from the transformed values
# and the known indices/dense shape. Since we ensure that indices and shape
# are identical, we can just use the firs tone.
return sparse_tensor.SparseTensor(sparse_list[0].indices,
op(*inner_args, **inner_kwargs), sparse_list[0].dense_shape)
def _assert_sparse_compatible(sparse_tensors):
"""Check that all of `sparse_tensors` have same `indices` and `dense_shape`
Returns: An op to be used as a control dependency.
"""
checks = []
first = sparse_tensors[0]
for t in sparse_tensors[1:]:
checks.append(check_ops.assert_equal(first.dense_shape, t.dense_shape,
message="Mismatched shapes!"))
checks.append(check_ops.assert_equal(first.indices, t.indices,
message="Mismatched indices!"))
return checks
def _replace_sparse_with_values(value, sparse_list):
"""Replace `SparseTensor`s with their values in `value`
Each `SparseTensor` in `value` is replaced by its `values` tensor, and
collects all `SparseTensor`s in `sparse_list`.
Args:
value: A structure of `Tensor`s and `SparseTensor`s
sparse_list: A list. Output parameter that collects all `SparseTensor`s
in `value`.
"""
flat_vals = nest.flatten(value, expand_composites=False)
new_vals = []
for v in flat_vals:
if isinstance(v, sparse_tensor.SparseTensor):
sparse_list.append(v)
new_vals.append(v.values)
else:
new_vals.append(v)
return nest.pack_sequence_as(value, new_vals, expand_composites=False)
def _add_sparse_to_tensors_map(sp_input,
container=None,
shared_name=None,

View File

@ -23,6 +23,7 @@ import numpy as np
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
@ -180,6 +181,46 @@ class SparseOpsTest(test_util.TensorFlowTestCase, parameterized.TestCase):
array_ops.transpose(dense_of_sparse))
self.assertAllEqual(expected, result)
def testMapValues(self):
# supplying no sparse tensor should result in ValueError
with self.assertRaises(ValueError):
sparse_ops.map_values(math_ops.abs, 0.0)
sp = sparse_ops.from_dense([[0.0, 1.0, 0.0], [-2.0, 1.0, 0.0]])
# helper function to check equality of sparse tensor
def assert_sparse_equal(expected, result):
self.assertAllEqual(expected.values, result.values, msg="Values differ")
self.assertAllEqual(expected.indices, result.indices, msg="Indices differ")
self.assertAllEqual(expected.dense_shape, result.dense_shape, msg="Shapes differ")
# check for a single sparse argument
expected = sparse_ops.from_dense([[0.0, 1.0, 0.0], [2.0, 1.0, 0.0]])
result = sparse_ops.map_values(math_ops.abs, sp)
assert_sparse_equal(expected, result)
# check correct passing of keyword argument, and handling of two sparse
# arguments at the same time
def mapping(arg1, arg2, kwarg):
self.assertTrue(kwarg == "kwarg")
return arg1 + arg2
result = sparse_ops.map_values(mapping, sp, sp, kwarg="kwarg")
expected = sparse_ops.from_dense([[0.0, 2.0, 0.0], [-4.0, 2.0, 0.0]])
assert_sparse_equal(expected, result)
# check that index mismatches are correctly detected even if the `value`s
# have compatible shape
sp_incomp = sparse_ops.from_dense([[0.0, 1.0, 0.0], [-2.0, 0.0, 1.0]])
with self.assertRaises((errors.InvalidArgumentError, ValueError)):
result = sparse_ops.map_values(mapping, sp, sp_incomp, kwarg="kwarg")
self.evaluate(result)
# check that shape mismatches are correctly detected
sp_incomp = sparse_tensor.SparseTensor(sp.indices, sp.values, (25, 25))
with self.assertRaises((errors.InvalidArgumentError, ValueError)):
result = sparse_ops.map_values(mapping, sp, sp_incomp, kwarg="kwarg")
self.evaluate(result)
def testConstantStringToSparse(self):
# Test case for GitHub issue 40633.
tensor = constant_op.constant(list('ababa'))

View File

@ -44,6 +44,10 @@ tf_module {
name: "from_dense"
argspec: "args=[\'tensor\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "map_values"
argspec: "args=[\'op\'], varargs=args, keywords=kwargs, defaults=None"
}
member_method {
name: "mask"
argspec: "args=[\'a\', \'mask_indices\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "

View File

@ -40,6 +40,10 @@ tf_module {
name: "from_dense"
argspec: "args=[\'tensor\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "map_values"
argspec: "args=[\'op\'], varargs=args, keywords=kwargs, defaults=None"
}
member_method {
name: "mask"
argspec: "args=[\'a\', \'mask_indices\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "