Support Sparse-Sparse cwise ops; use them for tf.sparse_{minimum,maximum}().
This change adds the CPU kernel and the Python interfaces. For now, it assumes both operands have the same shape (no broadcasting). Change: 126348349
commit d04c05def5
parent 379df09118
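For context, a minimal sketch of the user-facing behavior this change enables (illustrative values only): the output carries the union of the two operands' non-zero index sets, with implicit zeros filled in for whichever side is missing an index.

```python
import tensorflow as tf

# Two 1-D SparseTensors over the same dense shape [7].
sp_a = tf.SparseTensor([[0], [3]], [1.0, -2.0], [7])
sp_b = tf.SparseTensor([[0], [5]], [0.5, 4.0], [7])

with tf.Session() as sess:
    # Output indices: [[0], [3], [5]] -- the union of the operands' indices.
    print(sess.run(tf.sparse_maximum(sp_a, sp_b)).values)  # [1.0, 0.0, 4.0]
    print(sess.run(tf.sparse_minimum(sp_a, sp_b)).values)  # [0.5, -2.0, 0.0]
```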
@@ -1562,6 +1562,7 @@ tf_kernel_libraries(
         "sparse_concat_op",
         "sparse_reduce_sum_op",
         "sparse_dense_binary_op_shared",
+        "sparse_sparse_binary_op_shared",
         "sparse_reorder_op",
         "sparse_reshape_op",
         "sparse_softmax",
@@ -54,31 +54,32 @@ class SparseAddOp : public OpKernel {
                         b_values_t->shape().DebugString()));
     auto a_values = ctx->input(1).vec<T>();
     auto b_values = ctx->input(4).vec<T>();
-
-    OP_REQUIRES_OK(ctx, ctx->input("a_shape", &a_shape));
-    OP_REQUIRES_OK(ctx, ctx->input("b_shape", &b_shape));
-    OP_REQUIRES(ctx, TensorShapeUtils::IsVector(a_shape->shape()) &&
-                         TensorShapeUtils::IsVector(b_shape->shape()),
-                errors::InvalidArgument(
-                    "Input shape should be a vector but received shapes ",
-                    a_shape->shape().DebugString(), " and ",
-                    b_shape->shape().DebugString()));
-
     OP_REQUIRES(
         ctx, a_values.size() == a_nnz && b_values.size() == b_nnz,
         errors::InvalidArgument("Expected ", a_nnz, " and ", b_nnz,
                                 " non-empty input values, got ",
                                 a_values.size(), " and ", b_values.size()));

-    OP_REQUIRES(ctx, a_shape->dims() == b_shape->dims(),
+    OP_REQUIRES_OK(ctx, ctx->input("a_shape", &a_shape));
+    OP_REQUIRES_OK(ctx, ctx->input("b_shape", &b_shape));
+    OP_REQUIRES(ctx, TensorShapeUtils::IsVector(a_shape->shape()) &&
+                         TensorShapeUtils::IsVector(b_shape->shape()),
                 errors::InvalidArgument(
-                    "Ranks of input tensors must match, but saw ranks: ",
-                    a_shape->dims(), " and ", b_shape->dims()));
-    for (int i = 0; i < a_shape->dims(); ++i) {
-      OP_REQUIRES(ctx, a_shape->dim_size(i) == b_shape->dim_size(i),
+                    "Input shapes should be a vector but received shapes ",
+                    a_shape->shape().DebugString(), " and ",
+                    b_shape->shape().DebugString()));
+    OP_REQUIRES(
+        ctx, a_shape->IsSameSize(*b_shape),
+        errors::InvalidArgument(
+            "Operands do not have the same ranks; got shapes: ",
+            a_shape->SummarizeValue(10), " and ", b_shape->SummarizeValue(10)));
+    const auto a_shape_flat = a_shape->flat<int64>();
+    const auto b_shape_flat = b_shape->flat<int64>();
+    for (int i = 0; i < a_shape->NumElements(); ++i) {
+      OP_REQUIRES(ctx, a_shape_flat(i) == b_shape_flat(i),
                   errors::InvalidArgument(
-                      "Input shapes must match: got ", a_shape->dim_size(i),
-                      " and ", b_shape->dim_size(i), " for dimension ", i));
+                      "Operands' shapes do not match: got ", a_shape_flat(i),
+                      " and ", b_shape_flat(i), " for dimension ", i));
     }

     OP_REQUIRES_OK(ctx, ctx->input("thresh", &thresh_t));
tensorflow/core/kernels/sparse_sparse_binary_op_shared.cc (new file, 230 lines)
@@ -0,0 +1,230 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// SparseSparseBinaryOpShared is the shared code for binary coefficient-wise
// (cwise) operations of the following form:
//
//   sparse_t <binary cwise op> sparse_t -> new sparse_t
//
// The output SparseTensor may store up to "a_nnz + b_nnz" elements.

// IMPLEMENTATION DETAILS (not part of the interface specification).
//
// This kernel implements the "union" semantics on the non-zeros: namely, any
// non-zero from either side participates in the calculation, and any resultant
// zeros will NOT be excluded from the output storage.
//
// (In the future, we could always add a pruning op that prunes away the zeros,
// if desirable.)

// See docs of all registered ops in ../ops/sparse_ops.cc.

#define EIGEN_USE_THREADS

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/cwise_ops.h"
#include "tensorflow/core/kernels/cwise_ops_common.h"
#include "tensorflow/core/util/sparse/sparse_tensor.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;

namespace {
// Unions the sparse indices and outputs corresponding values: namely, if a
// non-zero appears on one side, it will participate in the calculation, where
// the counterpart on the other side is either a value or an implicit zero.
//
// On exit, outputs the augmented values in "{a,b}_augmented_values", and fills
// "entries_to_copy" with "(from_a?, index)" pairs.  All three vectors have the
// same size.
//
// The input and output sparse tensors are assumed ordered in the canonical
// row-major order.
template <typename T>
void UnionSparseIndicesAndValues(
    typename TTypes<int64>::ConstMatrix a_indices_mat,
    typename TTypes<T>::ConstFlat a_values, int64 a_nnz,
    typename TTypes<int64>::ConstMatrix b_indices_mat,
    typename TTypes<T>::ConstFlat b_values, int64 b_nnz, int num_dims,
    std::vector<T> *a_augmented_values, std::vector<T> *b_augmented_values,
    std::vector<std::pair<bool, int64>> *entries_to_copy) {
  entries_to_copy->reserve(a_nnz + b_nnz);
  a_augmented_values->reserve(a_nnz);
  b_augmented_values->reserve(b_nnz);

  int64 i = 0, j = 0;
  const T kZero = T(0);
  while (i < a_nnz && j < b_nnz) {
    switch (sparse::DimComparator::cmp(a_indices_mat, b_indices_mat, i, j,
                                       num_dims)) {
      case -1:
        entries_to_copy->emplace_back(true, i);
        a_augmented_values->push_back(a_values(i));
        b_augmented_values->push_back(kZero);
        ++i;
        break;
      case 0:
        entries_to_copy->emplace_back(true, i);
        a_augmented_values->push_back(a_values(i));
        b_augmented_values->push_back(b_values(j));
        ++i;
        ++j;
        break;
      case 1:
        entries_to_copy->emplace_back(false, j);
        a_augmented_values->push_back(kZero);
        b_augmented_values->push_back(b_values(j));
        ++j;
        break;
    }
  }
  // Handles leftovers; at most one loop runs.
  while (i < a_nnz) {
    entries_to_copy->emplace_back(/* is_a */ true, i);
    a_augmented_values->push_back(a_values(i++));
    b_augmented_values->push_back(kZero);
  }
  while (j < b_nnz) {
    entries_to_copy->emplace_back(/* is_a */ false, j);
    a_augmented_values->push_back(kZero);
    b_augmented_values->push_back(b_values(j++));
  }
}
}  // anonymous namespace

// Device: CPUDevice. GPU kernel is not supported currently.
// T: dtype of the SparseTensors.
// Functor: binary cwise operation to perform on the corresponding operand
//          values. See cwise_ops.h for a list of possible functors to register with.
template <typename Device, typename T, typename Functor>
class SparseSparseBinaryOpShared : public OpKernel {
 public:
  explicit SparseSparseBinaryOpShared(OpKernelConstruction *ctx)
      : OpKernel(ctx) {}

  void Compute(OpKernelContext *ctx) override {
    const Tensor *a_indices_t, *a_values_t, *a_shape_t, *b_indices_t,
        *b_values_t, *b_shape_t;
    OP_REQUIRES_OK(ctx, ctx->input("a_indices", &a_indices_t));
    OP_REQUIRES_OK(ctx, ctx->input("a_values", &a_values_t));
    OP_REQUIRES_OK(ctx, ctx->input("a_shape", &a_shape_t));
    OP_REQUIRES_OK(ctx, ctx->input("b_indices", &b_indices_t));
    OP_REQUIRES_OK(ctx, ctx->input("b_values", &b_values_t));
    OP_REQUIRES_OK(ctx, ctx->input("b_shape", &b_shape_t));

    // Validations.
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsMatrix(a_indices_t->shape()) &&
                 TensorShapeUtils::IsMatrix(b_indices_t->shape()),
        errors::InvalidArgument("Inputs a_indices and b_indices should be "
                                "matrices but received shapes: ",
                                a_indices_t->shape().DebugString(), ", ",
                                b_indices_t->shape().DebugString()));
    OP_REQUIRES(ctx, TensorShapeUtils::IsVector(a_values_t->shape()) &&
                         TensorShapeUtils::IsVector(b_values_t->shape()),
                errors::InvalidArgument(
                    "Inputs a_values and b_values should be vectors "
                    "but received shapes: ",
                    a_values_t->shape().DebugString(), " and ",
                    b_values_t->shape().DebugString()));

    const int64 a_nnz = a_indices_t->dim_size(0);
    const int64 b_nnz = b_indices_t->dim_size(0);
    const auto a_values = a_values_t->vec<T>();
    const auto b_values = b_values_t->vec<T>();

    OP_REQUIRES(
        ctx, a_values.size() == a_nnz && b_values.size() == b_nnz,
        errors::InvalidArgument("Expected ", a_nnz, " and ", b_nnz,
                                " non-empty input values, got ",
                                a_values.size(), " and ", b_values.size()));

    OP_REQUIRES(ctx, TensorShapeUtils::IsVector(a_shape_t->shape()) &&
                         TensorShapeUtils::IsVector(b_shape_t->shape()),
                errors::InvalidArgument(
                    "Input shapes should be a vector but received shapes ",
                    a_shape_t->shape().DebugString(), " and ",
                    b_shape_t->shape().DebugString()));
    OP_REQUIRES(ctx, a_shape_t->IsSameSize(*b_shape_t),
                errors::InvalidArgument(
                    "Operands do not have the same ranks; got shapes: ",
                    a_shape_t->SummarizeValue(10), " and ",
                    b_shape_t->SummarizeValue(10)));
    const auto a_shape = a_shape_t->flat<int64>();
    const auto b_shape = b_shape_t->flat<int64>();
    for (int i = 0; i < a_shape_t->NumElements(); ++i) {
      OP_REQUIRES(ctx, a_shape(i) == b_shape(i),
                  errors::InvalidArgument("Operands' shapes do not match: got ",
                                          a_shape(i), " and ", b_shape(i),
                                          " for dimension ", i));
    }

    const int num_dims = a_indices_t->dim_size(1);
    const auto a_indices_mat = a_indices_t->matrix<int64>();
    const auto b_indices_mat = b_indices_t->matrix<int64>();
    std::vector<T> a_augmented_values, b_augmented_values;
    std::vector<std::pair<bool, int64>> entries_to_copy;  // from_a?, idx
    UnionSparseIndicesAndValues(a_indices_mat, a_values, a_nnz, b_indices_mat,
                                b_values, b_nnz, num_dims, &a_augmented_values,
                                &b_augmented_values, &entries_to_copy);

    // Allocates and fills output tensors.
    const int64 sum_nnz = a_augmented_values.size();
    Tensor *output_indices_t, *output_values_t;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_output(0, TensorShape({sum_nnz, num_dims}),
                                        &output_indices_t));
    OP_REQUIRES_OK(
        ctx, ctx->allocate_output(1, TensorShape({sum_nnz}), &output_values_t));
    auto output_indices_mat = output_indices_t->matrix<int64>();

    for (int64 i = 0; i < sum_nnz; ++i) {
      const bool from_a = entries_to_copy[i].first;
      const int64 idx = entries_to_copy[i].second;
      output_indices_mat.chip<0>(i) =
          from_a ? a_indices_mat.chip<0>(idx) : b_indices_mat.chip<0>(idx);
    }

    // Performs the functor operation using Eigen.
    using TensorMap =
        Eigen::TensorMap<Eigen::Tensor<const T, 1, Eigen::RowMajor>,
                         Eigen::Aligned>;
    auto a_augmented_values_t = TensorMap(a_augmented_values.data(), sum_nnz);
    auto b_augmented_values_t = TensorMap(b_augmented_values.data(), sum_nnz);
    output_values_t->flat<T>().device(ctx->eigen_device<Device>()) =
        a_augmented_values_t.binaryExpr(b_augmented_values_t,
                                        typename Functor::func());
  }
};

#define REGISTER_KERNELS(T)                                                  \
  REGISTER_KERNEL_BUILDER(                                                   \
      Name("SparseSparseMinimum").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      SparseSparseBinaryOpShared<CPUDevice, T, functor::minimum<T>>)         \
                                                                             \
  REGISTER_KERNEL_BUILDER(                                                   \
      Name("SparseSparseMaximum").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      SparseSparseBinaryOpShared<CPUDevice, T, functor::maximum<T>>)

TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS

}  // namespace tensorflow
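To make the union semantics above concrete outside of C++, here is a small pure-Python sketch of the same two-pointer merge that `UnionSparseIndicesAndValues` performs (a hypothetical helper for illustration only, not part of this change; it assumes both index lists are already in canonical row-major order):

```python
def union_indices_and_values(a_indices, a_values, b_indices, b_values):
    """Two-pointer merge over lexicographically sorted index lists."""
    out_indices, a_aug, b_aug = [], [], []
    i, j = 0, 0
    while i < len(a_indices) and j < len(b_indices):
        if a_indices[i] < b_indices[j]:      # non-zero only in `a`
            out_indices.append(a_indices[i])
            a_aug.append(a_values[i])
            b_aug.append(0)
            i += 1
        elif a_indices[i] == b_indices[j]:   # non-zero in both operands
            out_indices.append(a_indices[i])
            a_aug.append(a_values[i])
            b_aug.append(b_values[j])
            i += 1
            j += 1
        else:                                # non-zero only in `b`
            out_indices.append(b_indices[j])
            a_aug.append(0)
            b_aug.append(b_values[j])
            j += 1
    # Leftovers: at most one of these two loops does any work.
    for k in range(i, len(a_indices)):
        out_indices.append(a_indices[k])
        a_aug.append(a_values[k])
        b_aug.append(0)
    for k in range(j, len(b_indices)):
        out_indices.append(b_indices[k])
        a_aug.append(0)
        b_aug.append(b_values[k])
    return out_indices, a_aug, b_aug


# Applying the cwise functor (here: max) to the augmented value vectors
# reproduces the kernel's output for the tf.sparse_maximum example above.
idx, a_aug, b_aug = union_indices_and_values(
    [[0], [3]], [1.0, -2.0], [[0], [5]], [0.5, 4.0])
out_values = [max(x, y) for x, y in zip(a_aug, b_aug)]  # [1.0, 0.0, 4.0]
```

Note that zeros produced by the functor (e.g. max(-2.0, 0) at index [3]) stay in the output rather than being pruned, exactly as the implementation comment in the kernel states.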
@@ -570,4 +570,58 @@ sp_shape: 1-D. Shape of the input SparseTensor.
output: 1-D. The `NNZ` values for the result `SparseTensor`.
)doc");

REGISTER_OP("SparseSparseMaximum")
    .Input("a_indices: int64")
    .Input("a_values: T")
    .Input("a_shape: int64")
    .Input("b_indices: int64")
    .Input("b_values: T")
    .Input("b_shape: int64")
    .Output("output_indices: int64")
    .Output("output_values: T")
    .Attr("T: realnumbertype")
    .Doc(R"doc(
Returns the element-wise max of two SparseTensors.

Assumes the two SparseTensors have the same shape, i.e., no broadcasting.

a_indices: 2-D. `N x R` matrix with the indices of non-empty values in a
    SparseTensor, in the canonical lexicographic ordering.
a_values: 1-D. `N` non-empty values corresponding to `a_indices`.
a_shape: 1-D. Shape of the input SparseTensor.
b_indices: counterpart to `a_indices` for the other operand.
b_values: counterpart to `a_values` for the other operand; must be of the same dtype.
b_shape: counterpart to `a_shape` for the other operand; the two shapes must be equal.

output_indices: 2-D. The indices of the output SparseTensor.
output_values: 1-D. The values of the output SparseTensor.
)doc");

REGISTER_OP("SparseSparseMinimum")
    .Input("a_indices: int64")
    .Input("a_values: T")
    .Input("a_shape: int64")
    .Input("b_indices: int64")
    .Input("b_values: T")
    .Input("b_shape: int64")
    .Output("output_indices: int64")
    .Output("output_values: T")
    .Attr("T: numbertype")
    .Doc(R"doc(
Returns the element-wise min of two SparseTensors.

Assumes the two SparseTensors have the same shape, i.e., no broadcasting.

a_indices: 2-D. `N x R` matrix with the indices of non-empty values in a
    SparseTensor, in the canonical lexicographic ordering.
a_values: 1-D. `N` non-empty values corresponding to `a_indices`.
a_shape: 1-D. Shape of the input SparseTensor.
b_indices: counterpart to `a_indices` for the other operand.
b_values: counterpart to `a_values` for the other operand; must be of the same dtype.
b_shape: counterpart to `a_shape` for the other operand; the two shapes must be equal.

output_indices: 2-D. The indices of the output SparseTensor.
output_values: 1-D. The values of the output SparseTensor.
)doc");

}  // namespace tensorflow
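Both ops assume the operands' indices are already in the canonical row-major (lexicographic) ordering mentioned in the docs above. A short NumPy sketch (illustrative values only) of what that ordering means; within a graph, `tf.sparse_reorder` produces the same canonical ordering:

```python
import numpy as np

indices = np.array([[1, 0], [0, 2], [0, 1]], dtype=np.int64)  # unordered N x R index matrix
values = np.array([10.0, 20.0, 30.0])

# Row-major order: dimension 0 is the most significant sort key.
order = np.lexsort(indices.T[::-1])
canonical_indices = indices[order]  # [[0, 1], [0, 2], [1, 0]]
canonical_values = values[order]    # [30.0, 20.0, 10.0]
```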
@@ -649,5 +649,68 @@ class SparseSoftmaxTest(test_util.TensorFlowTestCase):
      self.assertLess(err, 1e-4)


class SparseMinimumMaximumTest(test_util.TensorFlowTestCase):

  def _assertSparseTensorValueEqual(self, a, b):
    self.assertAllEqual(a.indices, b.indices)
    self.assertAllEqual(a.values, b.values)
    self.assertAllEqual(a.shape, b.shape)

  def testBasic(self):
    with self.test_session(use_gpu=False):
      # 1-D, values at index 0.
      sp_zero = ops.SparseTensor([[0]], [0], [7])
      sp_one = ops.SparseTensor([[0]], [1], [7])
      max_tf = tf.sparse_maximum(sp_zero, sp_one).eval()
      min_tf = tf.sparse_minimum(sp_zero, sp_one).eval()
      self._assertSparseTensorValueEqual(sp_one.eval(), max_tf)
      self._assertSparseTensorValueEqual(sp_zero.eval(), min_tf)

      # Values at different indices.
      sp_zero = ops.SparseTensor([[0]], [0], [7])
      sp_zero_2 = ops.SparseTensor([[1]], [0], [7])
      expected = ops.SparseTensor([[0], [1]], [0, 0], [7])
      max_tf = tf.sparse_maximum(sp_zero, sp_zero_2).eval()
      min_tf = tf.sparse_minimum(sp_zero, sp_zero_2).eval()
      self._assertSparseTensorValueEqual(expected.eval(), max_tf)
      self._assertSparseTensorValueEqual(expected.eval(), min_tf)

  def testRandom(self):
    np.random.seed(1618)
    shapes = [(13,), (6, 8), (1, 7, 1)]
    for shape in shapes:
      for dtype in [np.int32, np.int64, np.float16, np.float32, np.float64]:
        a_np = np.random.randn(*shape).astype(dtype)
        b_np = np.random.randn(*shape).astype(dtype)
        sp_a, unused_a_nnz = _sparsify(a_np, thresh=-.5)
        sp_b, unused_b_nnz = _sparsify(b_np, thresh=-.5)

        with self.test_session(use_gpu=False):
          maximum_tf = tf.sparse_maximum(sp_a, sp_b)
          maximum_tf_densified = tf.sparse_tensor_to_dense(maximum_tf).eval()
          minimum_tf = tf.sparse_minimum(sp_a, sp_b)
          minimum_tf_densified = tf.sparse_tensor_to_dense(minimum_tf).eval()

          a_densified = tf.sparse_tensor_to_dense(sp_a).eval()
          b_densified = tf.sparse_tensor_to_dense(sp_b).eval()

          self.assertAllEqual(np.maximum(a_densified, b_densified),
                              maximum_tf_densified)
          self.assertAllEqual(np.minimum(a_densified, b_densified),
                              minimum_tf_densified)

  def testMismatchedShapes(self):
    with self.test_session(use_gpu=False):
      sp_zero = ops.SparseTensor([[0, 0]], [0], [1, 1])
      sp_one = ops.SparseTensor([[0]], [1], [2])
      with self.assertRaisesOpError("Operands do not have the same ranks"):
        tf.sparse_maximum(sp_zero, sp_one).eval()

      sp_zero = ops.SparseTensor([[0]], [0], [1])
      sp_one = ops.SparseTensor([[0]], [1], [2])
      with self.assertRaisesOpError("Operands' shapes do not match"):
        tf.sparse_maximum(sp_zero, sp_one).eval()


if __name__ == "__main__":
  googletest.main()
@@ -256,3 +256,15 @@ def _SparseSoftmaxGrad(op, grad):
  grad_x = sp_sum.values * sp_output.values
  return [None, grad_x, None]


@ops.RegisterGradient("SparseSparseMaximum")
def _SparseSparseMaximumGrad(unused_op, unused_grad):
  raise NotImplementedError("Gradient for SparseSparseMaximum is currently not"
                            " implemented yet.")


@ops.RegisterGradient("SparseSparseMinimum")
def _SparseSparseMinimumGrad(unused_op, unused_grad):
  raise NotImplementedError("Gradient for SparseSparseMinimum is currently not"
                            " implemented yet.")
@@ -48,6 +48,8 @@ dimension, and dense along all other dimensions.
@@sparse_add
@@sparse_softmax
@@sparse_tensor_dense_matmul
@@sparse_maximum
@@sparse_minimum
"""
from __future__ import absolute_import
from __future__ import division
@@ -1487,3 +1489,84 @@ def _SparseSoftmaxShape(op):  # pylint: disable=invalid-name
  unused_shape_shape = op.inputs[2].get_shape().with_rank(1)
  nnz = values_shape[0]
  return [tensor_shape.vector(nnz)]


def sparse_maximum(sp_a, sp_b, name=None):
  """Returns the element-wise max of two SparseTensors.

  Assumes the two SparseTensors have the same shape, i.e., no broadcasting.
  Example:

  ```python
  sp_zero = ops.SparseTensor([[0]], [0], [7])
  sp_one = ops.SparseTensor([[1]], [1], [7])
  res = tf.sparse_maximum(sp_zero, sp_one).eval()
  # "res" should be equal to SparseTensor([[0], [1]], [0, 1], [7]).
  ```

  Args:
    sp_a: a `SparseTensor` operand whose dtype is real, and indices
      lexicographically ordered.
    sp_b: the other `SparseTensor` operand with the same requirements (and the
      same shape).
    name: optional name of the operation.
  Returns:
    output: the output SparseTensor.
  """
  with ops.op_scope([sp_a.indices, sp_a.values, sp_b.indices, sp_b.values],
                    name, "SparseSparseMaximum") as name:
    out_indices, out_values = gen_sparse_ops.sparse_sparse_maximum(sp_a.indices,
                                                                   sp_a.values,
                                                                   sp_a.shape,
                                                                   sp_b.indices,
                                                                   sp_b.values,
                                                                   sp_b.shape,
                                                                   name=name)
  return ops.SparseTensor(out_indices, out_values, sp_a.shape)


def sparse_minimum(sp_a, sp_b, name=None):
  """Returns the element-wise min of two SparseTensors.

  Assumes the two SparseTensors have the same shape, i.e., no broadcasting.
  Example:

  ```python
  sp_zero = ops.SparseTensor([[0]], [0], [7])
  sp_one = ops.SparseTensor([[1]], [1], [7])
  res = tf.sparse_minimum(sp_zero, sp_one).eval()
  # "res" should be equal to SparseTensor([[0], [1]], [0, 0], [7]).
  ```

  Args:
    sp_a: a `SparseTensor` operand whose dtype is real, and indices
      lexicographically ordered.
    sp_b: the other `SparseTensor` operand with the same requirements (and the
      same shape).
    name: optional name of the operation.
  Returns:
    output: the output SparseTensor.
  """
  with ops.op_scope([sp_a.indices, sp_a.values, sp_b.indices, sp_b.values],
                    name, "SparseSparseMinimum") as name:
    out_indices, out_values = gen_sparse_ops.sparse_sparse_minimum(sp_a.indices,
                                                                   sp_a.values,
                                                                   sp_a.shape,
                                                                   sp_b.indices,
                                                                   sp_b.values,
                                                                   sp_b.shape,
                                                                   name=name)
  return ops.SparseTensor(out_indices, out_values, sp_a.shape)


@ops.RegisterShape("SparseSparseMaximum")
@ops.RegisterShape("SparseSparseMinimum")
def _SparseSparseMaximumMinimumShape(op):  # pylint: disable=invalid-name
  """Shape function for SparseSparseMaximum and SparseSparseMinimum."""
  op.inputs[0].get_shape().assert_has_rank(2)  # a_indices
  op.inputs[1].get_shape().assert_has_rank(1)  # a_values
  op.inputs[2].get_shape().assert_has_rank(1)  # a_shape
  op.inputs[3].get_shape().assert_has_rank(2)  # b_indices
  op.inputs[4].get_shape().assert_has_rank(1)  # b_values
  op.inputs[5].get_shape().assert_has_rank(1)  # b_shape
  return [tensor_shape.unknown_shape(2), tensor_shape.unknown_shape(1)]