Add tf.sparse.from_dense utility function.

PiperOrigin-RevId: 254124938
This commit is contained in:
Martin Wicke 2019-06-19 20:36:12 -07:00 committed by TensorFlower Gardener
parent 6112891368
commit 0c66f29942
4 changed files with 28 additions and 0 deletions

View File

@ -100,6 +100,17 @@ def _make_int64_tensor(value, name):
return math_ops.cast(value, dtypes.int64)
@tf_export("sparse.from_dense")
def from_dense(tensor, name=None):
with ops.name_scope(name, "dense_to_sparse"):
tensor = ops.convert_to_tensor(tensor)
indices = array_ops.where(
math_ops.not_equal(tensor, array_ops.constant(0, tensor.dtype)))
values = array_ops.gather_nd(tensor, indices)
shape = array_ops.shape(tensor, out_type=dtypes.int64)
return sparse_tensor.SparseTensor(indices, values, shape)
@tf_export("sparse.expand_dims")
def sparse_expand_dims(sp_input, axis=None, name=None):
"""Inserts a dimension of 1 into a tensor's shape.

View File

@ -54,6 +54,15 @@ class SparseOpsTest(test_util.TensorFlowTestCase, parameterized.TestCase):
test_one(n, m, True)
test_one(n, m, False)
def testDenseFromConstantToSparse(self):
expected_constant = np.reshape(np.arange(24, dtype=np.int64), (3, 4, 2))
tensor = constant_op.constant(expected_constant)
sparse = sparse_ops.from_dense(tensor)
dense = sparse_ops.sparse_to_dense(sparse.indices, sparse.dense_shape,
sparse.values)
constant = self.evaluate(dense)
self.assertAllEqual(expected_constant, constant)
def testSparseExpandDims(self):
for rank in range(1, 4):
# Create a dummy input. When rank=3, shape=[2, 4, 6].

View File

@ -36,6 +36,10 @@ tf_module {
name: "fill_empty_rows"
argspec: "args=[\'sp_input\', \'default_value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "from_dense"
argspec: "args=[\'tensor\', 'name'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "mask"
argspec: "args=[\'a\', \'mask_indices\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "

View File

@ -32,6 +32,10 @@ tf_module {
name: "fill_empty_rows"
argspec: "args=[\'sp_input\', \'default_value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "from_dense"
argspec: "args=[\'tensor\', 'name'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "mask"
argspec: "args=[\'a\', \'mask_indices\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "