added sparse.map_values
This commit is contained in:
parent
26dc5fc653
commit
5e261399de
@ -47,6 +47,7 @@ from tensorflow.python.ops.gen_sparse_ops import *
|
||||
from tensorflow.python.util import compat
|
||||
from tensorflow.python.util import deprecation
|
||||
from tensorflow.python.util import dispatch
|
||||
from tensorflow.python.util import nest
|
||||
from tensorflow.python.util import tf_inspect
|
||||
from tensorflow.python.util.compat import collections_abc
|
||||
from tensorflow.python.util.tf_export import get_canonical_name_for_symbol
|
||||
@ -2733,6 +2734,99 @@ def sparse_transpose(sp_input, perm=None, name=None):
|
||||
return transposed_st
|
||||
|
||||
|
||||
@tf_export("sparse.map_values")
|
||||
@dispatch.add_dispatch_support
|
||||
def map_values(op, *args, **kwargs):
|
||||
"""Applies `op` to the values of one or more `SparseTensor`s.
|
||||
|
||||
Replaces any `SparseTensor` in `args` or `kwargs` with its `values`
|
||||
tensor, and then calls `op`. Returns a `SparseTensor` that is constructed
|
||||
from the input `SparseTensor`s' `indices` and the value returned by
|
||||
the `op`.
|
||||
|
||||
If the input arguments contain multiple `SparseTensor`s, then they must have
|
||||
identical `indices`.
|
||||
|
||||
Examples:
|
||||
|
||||
>>> st = tf.sparse.from_dense([[1, 2, 0], [0, 4, 0], [1, 0, 0]])
|
||||
>>> tf.sparse.to_dense(tf.sparse.map_values(tf.ones_like, st)).numpy().to_list()
|
||||
[[1, 1, 0], [0, 1, 0], [1, 0, 0]]
|
||||
>>> tf.sparse.to_dense(tf.sparse.map_values(tf.multiply, st, st)).numpy().to_list()
|
||||
[[1, 4, 0], [0, 16, 0], [1, 0, 0]]
|
||||
>>> tf.sparse.to_dense(tf.sparse.map_values(tf.add, st, 5)).numpy().to_list()
|
||||
[[5, 7, 0], [0, 9, 0], [6, 0, 0]]
|
||||
|
||||
Note in particular that even though `tf.add(0, 5) != 0`, implicit zeros
|
||||
will remain unchanged. However, if the sparse tensor contains any explict
|
||||
zeros, these will be affected by the mapping!
|
||||
|
||||
Args:
|
||||
op: The operation that should be applied to the SparseTensor `values`.
|
||||
`op` is typically an element-wise operation (such as math_ops.add), but
|
||||
any operation that preserves the shape can be used.
|
||||
*args: Arguments for `op`.
|
||||
**kwargs: Keyword arguments for `op`.
|
||||
|
||||
Returns:
|
||||
A `SparseTensor` whose `indices` matches the `indices` of all
|
||||
input `SparseTensor`s.
|
||||
Raises:
|
||||
ValueError: If args contains no `SparseTensor`, or if the `indices`
|
||||
of the input `SparseTensor`s are not identical.
|
||||
"""
|
||||
sparse_list = []
|
||||
inner_args = _replace_sparse_with_values(args, sparse_list)
|
||||
inner_kwargs = _replace_sparse_with_values(kwargs, sparse_list)
|
||||
if not sparse_list:
|
||||
raise ValueError("No SparseTensor in argument list of map_values")
|
||||
|
||||
with ops.control_dependencies(_assert_sparse_compatible(sparse_list)):
|
||||
# Delegate to op, and then compose the result from the transformed values
|
||||
# and the known indices/dense shape. Since we ensure that indices and shape
|
||||
# are identical, we can just use the firs tone.
|
||||
return sparse_tensor.SparseTensor(sparse_list[0].indices,
|
||||
op(*inner_args, **inner_kwargs), sparse_list[0].dense_shape)
|
||||
|
||||
|
||||
def _assert_sparse_compatible(sparse_tensors):
|
||||
"""Check that all of `sparse_tensors` have same `indices` and `dense_shape`
|
||||
|
||||
Returns: An op to be used as a control dependency.
|
||||
"""
|
||||
checks = []
|
||||
first = sparse_tensors[0]
|
||||
for t in sparse_tensors[1:]:
|
||||
checks.append(check_ops.assert_equal(first.dense_shape, t.dense_shape,
|
||||
message="Mismatched shapes!"))
|
||||
checks.append(check_ops.assert_equal(first.indices, t.indices,
|
||||
message="Mismatched indices!"))
|
||||
return checks
|
||||
|
||||
|
||||
def _replace_sparse_with_values(value, sparse_list):
|
||||
"""Replace `SparseTensor`s with their values in `value`
|
||||
|
||||
Each `SparseTensor` in `value` is replaced by its `values` tensor, and
|
||||
collects all `SparseTensor`s in `sparse_list`.
|
||||
|
||||
Args:
|
||||
value: A structure of `Tensor`s and `SparseTensor`s
|
||||
sparse_list: A list. Output parameter that collects all `SparseTensor`s
|
||||
in `value`.
|
||||
"""
|
||||
flat_vals = nest.flatten(value, expand_composites=False)
|
||||
new_vals = []
|
||||
for v in flat_vals:
|
||||
if isinstance(v, sparse_tensor.SparseTensor):
|
||||
sparse_list.append(v)
|
||||
new_vals.append(v.values)
|
||||
else:
|
||||
new_vals.append(v)
|
||||
return nest.pack_sequence_as(value, new_vals, expand_composites=False)
|
||||
|
||||
|
||||
|
||||
def _add_sparse_to_tensors_map(sp_input,
|
||||
container=None,
|
||||
shared_name=None,
|
||||
|
@ -23,6 +23,7 @@ import numpy as np
|
||||
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.framework import test_util
|
||||
@ -180,6 +181,46 @@ class SparseOpsTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||
array_ops.transpose(dense_of_sparse))
|
||||
self.assertAllEqual(expected, result)
|
||||
|
||||
def testMapValues(self):
|
||||
# supplying no sparse tensor should result in ValueError
|
||||
with self.assertRaises(ValueError):
|
||||
sparse_ops.map_values(math_ops.abs, 0.0)
|
||||
|
||||
sp = sparse_ops.from_dense([[0.0, 1.0, 0.0], [-2.0, 1.0, 0.0]])
|
||||
|
||||
# helper function to check equality of sparse tensor
|
||||
def assert_sparse_equal(expected, result):
|
||||
self.assertAllEqual(expected.values, result.values, msg="Values differ")
|
||||
self.assertAllEqual(expected.indices, result.indices, msg="Indices differ")
|
||||
self.assertAllEqual(expected.dense_shape, result.dense_shape, msg="Shapes differ")
|
||||
|
||||
# check for a single sparse argument
|
||||
expected = sparse_ops.from_dense([[0.0, 1.0, 0.0], [2.0, 1.0, 0.0]])
|
||||
result = sparse_ops.map_values(math_ops.abs, sp)
|
||||
assert_sparse_equal(expected, result)
|
||||
|
||||
# check correct passing of keyword argument, and handling of two sparse
|
||||
# arguments at the same time
|
||||
def mapping(arg1, arg2, kwarg):
|
||||
self.assertTrue(kwarg == "kwarg")
|
||||
return arg1 + arg2
|
||||
result = sparse_ops.map_values(mapping, sp, sp, kwarg="kwarg")
|
||||
expected = sparse_ops.from_dense([[0.0, 2.0, 0.0], [-4.0, 2.0, 0.0]])
|
||||
assert_sparse_equal(expected, result)
|
||||
|
||||
# check that index mismatches are correctly detected even if the `value`s
|
||||
# have compatible shape
|
||||
sp_incomp = sparse_ops.from_dense([[0.0, 1.0, 0.0], [-2.0, 0.0, 1.0]])
|
||||
with self.assertRaises((errors.InvalidArgumentError, ValueError)):
|
||||
result = sparse_ops.map_values(mapping, sp, sp_incomp, kwarg="kwarg")
|
||||
self.evaluate(result)
|
||||
|
||||
# check that shape mismatches are correctly detected
|
||||
sp_incomp = sparse_tensor.SparseTensor(sp.indices, sp.values, (25, 25))
|
||||
with self.assertRaises((errors.InvalidArgumentError, ValueError)):
|
||||
result = sparse_ops.map_values(mapping, sp, sp_incomp, kwarg="kwarg")
|
||||
self.evaluate(result)
|
||||
|
||||
def testConstantStringToSparse(self):
|
||||
# Test case for GitHub issue 40633.
|
||||
tensor = constant_op.constant(list('ababa'))
|
||||
|
@ -44,6 +44,10 @@ tf_module {
|
||||
name: "from_dense"
|
||||
argspec: "args=[\'tensor\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "map_values"
|
||||
argspec: "args=[\'op\'], varargs=args, keywords=kwargs, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "mask"
|
||||
argspec: "args=[\'a\', \'mask_indices\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
|
@ -40,6 +40,10 @@ tf_module {
|
||||
name: "from_dense"
|
||||
argspec: "args=[\'tensor\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "map_values"
|
||||
argspec: "args=[\'op\'], varargs=args, keywords=kwargs, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "mask"
|
||||
argspec: "args=[\'a\', \'mask_indices\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
|
Loading…
Reference in New Issue
Block a user