Add tf.sparse.from_dense utility function.
PiperOrigin-RevId: 254124938
This commit is contained in:
parent
6112891368
commit
0c66f29942
@ -100,6 +100,17 @@ def _make_int64_tensor(value, name):
|
||||
return math_ops.cast(value, dtypes.int64)
|
||||
|
||||
|
||||
@tf_export("sparse.from_dense")
|
||||
def from_dense(tensor, name=None):
|
||||
with ops.name_scope(name, "dense_to_sparse"):
|
||||
tensor = ops.convert_to_tensor(tensor)
|
||||
indices = array_ops.where(
|
||||
math_ops.not_equal(tensor, array_ops.constant(0, tensor.dtype)))
|
||||
values = array_ops.gather_nd(tensor, indices)
|
||||
shape = array_ops.shape(tensor, out_type=dtypes.int64)
|
||||
return sparse_tensor.SparseTensor(indices, values, shape)
|
||||
|
||||
|
||||
@tf_export("sparse.expand_dims")
|
||||
def sparse_expand_dims(sp_input, axis=None, name=None):
|
||||
"""Inserts a dimension of 1 into a tensor's shape.
|
||||
|
@ -54,6 +54,15 @@ class SparseOpsTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||
test_one(n, m, True)
|
||||
test_one(n, m, False)
|
||||
|
||||
def testDenseFromConstantToSparse(self):
|
||||
expected_constant = np.reshape(np.arange(24, dtype=np.int64), (3, 4, 2))
|
||||
tensor = constant_op.constant(expected_constant)
|
||||
sparse = sparse_ops.from_dense(tensor)
|
||||
dense = sparse_ops.sparse_to_dense(sparse.indices, sparse.dense_shape,
|
||||
sparse.values)
|
||||
constant = self.evaluate(dense)
|
||||
self.assertAllEqual(expected_constant, constant)
|
||||
|
||||
def testSparseExpandDims(self):
|
||||
for rank in range(1, 4):
|
||||
# Create a dummy input. When rank=3, shape=[2, 4, 6].
|
||||
|
@ -36,6 +36,10 @@ tf_module {
|
||||
name: "fill_empty_rows"
|
||||
argspec: "args=[\'sp_input\', \'default_value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "from_dense"
|
||||
argspec: "args=[\'tensor\', 'name'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "mask"
|
||||
argspec: "args=[\'a\', \'mask_indices\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
|
@ -32,6 +32,10 @@ tf_module {
|
||||
name: "fill_empty_rows"
|
||||
argspec: "args=[\'sp_input\', \'default_value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "from_dense"
|
||||
argspec: "args=[\'tensor\', 'name'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "mask"
|
||||
argspec: "args=[\'a\', \'mask_indices\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
|
Loading…
Reference in New Issue
Block a user