From d04c05def547a84bdc78f04e44e69c6f634b6df0 Mon Sep 17 00:00:00 2001
From: Zongheng Yang <zongheng@google.com>
Date: Thu, 30 Jun 2016 14:04:12 -0800
Subject: [PATCH] Support Sparse-Sparse cwise ops; use for
 tf.sparse_{minimum,maximum}().

This change adds the CPU kernel and Python interfaces.  For now, it assumes
both operands have the same shape.
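
A minimal usage sketch of the new Python interface (mirroring the docstring
examples added to sparse_ops.py below; assumes a build that includes this
change and uses the 2016-era API, e.g. `shape` rather than `dense_shape`):

```python
import tensorflow as tf

sp_zero = tf.SparseTensor(indices=[[0]], values=[0], shape=[7])
sp_one = tf.SparseTensor(indices=[[1]], values=[1], shape=[7])

with tf.Session() as sess:
  # Non-zeros are unioned; a missing entry on one side acts as an implicit 0.
  res = sess.run(tf.sparse_maximum(sp_zero, sp_one))
  # res.indices == [[0], [1]]; res.values == [0, 1]; res.shape == [7].
```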
Change: 126348349
---
 tensorflow/core/kernels/BUILD                 |   1 +
 tensorflow/core/kernels/sparse_add_op.cc      |  35 +--
 .../kernels/sparse_sparse_binary_op_shared.cc | 230 ++++++++++++++++++
 tensorflow/core/ops/sparse_ops.cc             |  54 ++++
 .../python/kernel_tests/sparse_ops_test.py    |  63 +++++
 tensorflow/python/ops/sparse_grad.py          |  12 +
 tensorflow/python/ops/sparse_ops.py           |  83 +++++++
 7 files changed, 461 insertions(+), 17 deletions(-)
 create mode 100644 tensorflow/core/kernels/sparse_sparse_binary_op_shared.cc
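
For reference, a minimal pure-Python sketch (illustration only, not part of the
patch) of the "union" merge that the new kernel's UnionSparseIndicesAndValues
helper performs over two canonically ordered index lists; `sparse_union_apply`
and `op` are hypothetical names used only here:

```python
def sparse_union_apply(a_idx, a_val, b_idx, b_val, op):
  """a_idx/b_idx: sorted lists of index tuples; a_val/b_val: parallel values."""
  i = j = 0
  out_idx, out_val = [], []
  while i < len(a_idx) and j < len(b_idx):
    if a_idx[i] < b_idx[j]:    # only in a; b contributes an implicit zero
      out_idx.append(a_idx[i]); out_val.append(op(a_val[i], 0)); i += 1
    elif a_idx[i] > b_idx[j]:  # only in b; a contributes an implicit zero
      out_idx.append(b_idx[j]); out_val.append(op(0, b_val[j])); j += 1
    else:                      # index present on both sides
      out_idx.append(a_idx[i]); out_val.append(op(a_val[i], b_val[j]))
      i += 1; j += 1
  while i < len(a_idx):        # leftovers from a
    out_idx.append(a_idx[i]); out_val.append(op(a_val[i], 0)); i += 1
  while j < len(b_idx):        # leftovers from b
    out_idx.append(b_idx[j]); out_val.append(op(0, b_val[j])); j += 1
  return out_idx, out_val

# max of {[0]: 0} and {[1]: 1} over shape [7] gives indices [(0,), (1,)] and
# values [0, 1], matching the SparseSparseMaximum docstring example.
print(sparse_union_apply([(0,)], [0], [(1,)], [1], max))
```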

diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 142f63c6b47..06c3c86c673 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -1562,6 +1562,7 @@ tf_kernel_libraries(
         "sparse_concat_op",
         "sparse_reduce_sum_op",
         "sparse_dense_binary_op_shared",
+        "sparse_sparse_binary_op_shared",
         "sparse_reorder_op",
         "sparse_reshape_op",
         "sparse_softmax",
diff --git a/tensorflow/core/kernels/sparse_add_op.cc b/tensorflow/core/kernels/sparse_add_op.cc
index 0cb77d785ad..bd91dfdce64 100644
--- a/tensorflow/core/kernels/sparse_add_op.cc
+++ b/tensorflow/core/kernels/sparse_add_op.cc
@@ -54,31 +54,32 @@ class SparseAddOp : public OpKernel {
                     b_values_t->shape().DebugString()));
     auto a_values = ctx->input(1).vec<T>();
     auto b_values = ctx->input(4).vec<T>();
-
-    OP_REQUIRES_OK(ctx, ctx->input("a_shape", &a_shape));
-    OP_REQUIRES_OK(ctx, ctx->input("b_shape", &b_shape));
-    OP_REQUIRES(ctx, TensorShapeUtils::IsVector(a_shape->shape()) &&
-                         TensorShapeUtils::IsVector(b_shape->shape()),
-                errors::InvalidArgument(
-                    "Input shape should be a vector but received shapes ",
-                    a_shape->shape().DebugString(), " and ",
-                    b_shape->shape().DebugString()));
-
     OP_REQUIRES(
         ctx, a_values.size() == a_nnz && b_values.size() == b_nnz,
         errors::InvalidArgument("Expected ", a_nnz, " and ", b_nnz,
                                 " non-empty input values, got ",
                                 a_values.size(), " and ", b_values.size()));
 
-    OP_REQUIRES(ctx, a_shape->dims() == b_shape->dims(),
+    OP_REQUIRES_OK(ctx, ctx->input("a_shape", &a_shape));
+    OP_REQUIRES_OK(ctx, ctx->input("b_shape", &b_shape));
+    OP_REQUIRES(ctx, TensorShapeUtils::IsVector(a_shape->shape()) &&
+                         TensorShapeUtils::IsVector(b_shape->shape()),
                 errors::InvalidArgument(
-                    "Ranks of input tensors must match, but saw ranks: ",
-                    a_shape->dims(), " and ", b_shape->dims()));
-    for (int i = 0; i < a_shape->dims(); ++i) {
-      OP_REQUIRES(ctx, a_shape->dim_size(i) == b_shape->dim_size(i),
+                    "Input shapes should be vectors but received shapes ",
+                    a_shape->shape().DebugString(), " and ",
+                    b_shape->shape().DebugString()));
+    OP_REQUIRES(
+        ctx, a_shape->IsSameSize(*b_shape),
+        errors::InvalidArgument(
+            "Operands do not have the same ranks; got shapes: ",
+            a_shape->SummarizeValue(10), " and ", b_shape->SummarizeValue(10)));
+    const auto a_shape_flat = a_shape->flat<int64>();
+    const auto b_shape_flat = b_shape->flat<int64>();
+    for (int i = 0; i < a_shape->NumElements(); ++i) {
+      OP_REQUIRES(ctx, a_shape_flat(i) == b_shape_flat(i),
                   errors::InvalidArgument(
-                      "Input shapes must match: got ", a_shape->dim_size(i),
-                      " and ", b_shape->dim_size(i), " for dimension ", i));
+                      "Operands' shapes do not match: got ", a_shape_flat(i),
+                      " and ", b_shape_flat(i), " for dimension ", i));
     }
 
     OP_REQUIRES_OK(ctx, ctx->input("thresh", &thresh_t));
diff --git a/tensorflow/core/kernels/sparse_sparse_binary_op_shared.cc b/tensorflow/core/kernels/sparse_sparse_binary_op_shared.cc
new file mode 100644
index 00000000000..a8a116ebf58
--- /dev/null
+++ b/tensorflow/core/kernels/sparse_sparse_binary_op_shared.cc
@@ -0,0 +1,230 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// SparseSparseBinaryOpShared is the shared code for binary coefficient-wise
+// (cwise) operations of the following form:
+//
+//   sparse_t <binary cwise op> sparse_t -> new sparse_t
+//
+// The output SparseTensor may store up to "a_nnz + b_nnz" elements.
+
+// IMPLEMENTATION DETAILS (not part of the interface specification).
+//
+// This kernel implements the "union" semantics on the non-zeros: namely, any
+// non-zero from either side participates in the calculation, and any
+// resultant zeros will NOT be excluded from the output storage.
+//
+// (In the future, we could always add a pruning op that prunes away the
+// zeros, if desirable.)
+
+// See docs of all registered ops in ../ops/sparse_ops.cc.
+
+#define EIGEN_USE_THREADS
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_util.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/cwise_ops.h"
+#include "tensorflow/core/kernels/cwise_ops_common.h"
+#include "tensorflow/core/util/sparse/sparse_tensor.h"
+
+namespace tensorflow {
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+
+namespace {
+// Unions the sparse indices and outputs corresponding values: namely, if a
+// non-zero appears on one side, it participates in the calculation, with the
+// counterpart on the other side being a stored value or an implicit zero.
+//
+// On exit, outputs the augmented values in "{a,b}_augmented_values", and fills
+// "entries_to_copy" with "(from_a?, index)" pairs.  All three vectors have the
+// same size.
+//
+// The input and output sparse tensors are assumed ordered in the canonical
+// row-major order.
+template <typename T>
+void UnionSparseIndicesAndValues(
+    typename TTypes<int64>::ConstMatrix a_indices_mat,
+    typename TTypes<T>::ConstFlat a_values, int64 a_nnz,
+    typename TTypes<int64>::ConstMatrix b_indices_mat,
+    typename TTypes<T>::ConstFlat b_values, int64 b_nnz, int num_dims,
+    std::vector<T> *a_augmented_values, std::vector<T> *b_augmented_values,
+    std::vector<std::pair<bool, int64>> *entries_to_copy) {
+  entries_to_copy->reserve(a_nnz + b_nnz);
+  a_augmented_values->reserve(a_nnz);
+  b_augmented_values->reserve(b_nnz);
+
+  int64 i = 0, j = 0;
+  const T kZero = T(0);
+  while (i < a_nnz && j < b_nnz) {
+    switch (sparse::DimComparator::cmp(a_indices_mat, b_indices_mat, i, j,
+                                       num_dims)) {
+      case -1:
+        entries_to_copy->emplace_back(true, i);
+        a_augmented_values->push_back(a_values(i));
+        b_augmented_values->push_back(kZero);
+        ++i;
+        break;
+      case 0:
+        entries_to_copy->emplace_back(true, i);
+        a_augmented_values->push_back(a_values(i));
+        b_augmented_values->push_back(b_values(j));
+        ++i;
+        ++j;
+        break;
+      case 1:
+        entries_to_copy->emplace_back(false, j);
+        a_augmented_values->push_back(kZero);
+        b_augmented_values->push_back(b_values(j));
+        ++j;
+        break;
+    }
+  }
+  // Handles leftovers; at most one loop runs.
+  while (i < a_nnz) {
+    entries_to_copy->emplace_back(/* is_a */ true, i);
+    a_augmented_values->push_back(a_values(i++));
+    b_augmented_values->push_back(kZero);
+  }
+  while (j < b_nnz) {
+    entries_to_copy->emplace_back(/* is_a */ false, j);
+    a_augmented_values->push_back(kZero);
+    b_augmented_values->push_back(b_values(j++));
+  }
+}
+}  // anonymous namespace
+
+// Device: CPUDevice.  A GPU kernel is not currently supported.
+// T: dtype of the SparseTensors' values.
+// Functor: binary cwise operation to perform on the corresponding operand
+// values.  See cwise_ops.h for a list of possible functors to register with.
+template <typename Device, typename T, typename Functor>
+class SparseSparseBinaryOpShared : public OpKernel {
+ public:
+  explicit SparseSparseBinaryOpShared(OpKernelConstruction *ctx)
+      : OpKernel(ctx) {}
+
+  void Compute(OpKernelContext *ctx) override {
+    const Tensor *a_indices_t, *a_values_t, *a_shape_t, *b_indices_t,
+        *b_values_t, *b_shape_t;
+    OP_REQUIRES_OK(ctx, ctx->input("a_indices", &a_indices_t));
+    OP_REQUIRES_OK(ctx, ctx->input("a_values", &a_values_t));
+    OP_REQUIRES_OK(ctx, ctx->input("a_shape", &a_shape_t));
+    OP_REQUIRES_OK(ctx, ctx->input("b_indices", &b_indices_t));
+    OP_REQUIRES_OK(ctx, ctx->input("b_values", &b_values_t));
+    OP_REQUIRES_OK(ctx, ctx->input("b_shape", &b_shape_t));
+
+    // Validations.
+    OP_REQUIRES(
+        ctx, TensorShapeUtils::IsMatrix(a_indices_t->shape()) &&
+                 TensorShapeUtils::IsMatrix(b_indices_t->shape()),
+        errors::InvalidArgument("Inputs a_indices and b_indices should be "
+                                "matrices but received shapes: ",
+                                a_indices_t->shape().DebugString(), ", ",
+                                b_indices_t->shape().DebugString()));
+    OP_REQUIRES(ctx, TensorShapeUtils::IsVector(a_values_t->shape()) &&
+                         TensorShapeUtils::IsVector(b_values_t->shape()),
+                errors::InvalidArgument(
+                    "Inputs a_values and b_values should be vectors "
+                    "but received shapes: ",
+                    a_values_t->shape().DebugString(), " and ",
+                    b_values_t->shape().DebugString()));
+
+    const int64 a_nnz = a_indices_t->dim_size(0);
+    const int64 b_nnz = b_indices_t->dim_size(0);
+    const auto a_values = a_values_t->vec<T>();
+    const auto b_values = b_values_t->vec<T>();
+
+    OP_REQUIRES(
+        ctx, a_values.size() == a_nnz && b_values.size() == b_nnz,
+        errors::InvalidArgument("Expected ", a_nnz, " and ", b_nnz,
+                                " non-empty input values, got ",
+                                a_values.size(), " and ", b_values.size()));
+
+    OP_REQUIRES(ctx, TensorShapeUtils::IsVector(a_shape_t->shape()) &&
+                         TensorShapeUtils::IsVector(b_shape_t->shape()),
+                errors::InvalidArgument(
+                    "Input shapes should be vectors but received shapes ",
+                    a_shape_t->shape().DebugString(), " and ",
+                    b_shape_t->shape().DebugString()));
+    OP_REQUIRES(ctx, a_shape_t->IsSameSize(*b_shape_t),
+                errors::InvalidArgument(
+                    "Operands do not have the same ranks; got shapes: ",
+                    a_shape_t->SummarizeValue(10), " and ",
+                    b_shape_t->SummarizeValue(10)));
+    const auto a_shape = a_shape_t->flat<int64>();
+    const auto b_shape = b_shape_t->flat<int64>();
+    for (int i = 0; i < a_shape_t->NumElements(); ++i) {
+      OP_REQUIRES(ctx, a_shape(i) == b_shape(i),
+                  errors::InvalidArgument("Operands' shapes do not match: got ",
+                                          a_shape(i), " and ", b_shape(i),
+                                          " for dimension ", i));
+    }
+
+    const int num_dims = a_indices_t->dim_size(1);
+    const auto a_indices_mat = a_indices_t->matrix<int64>();
+    const auto b_indices_mat = b_indices_t->matrix<int64>();
+    std::vector<T> a_augmented_values, b_augmented_values;
+    std::vector<std::pair<bool, int64>> entries_to_copy;  // from_a?, idx
+    UnionSparseIndicesAndValues(a_indices_mat, a_values, a_nnz, b_indices_mat,
+                                b_values, b_nnz, num_dims, &a_augmented_values,
+                                &b_augmented_values, &entries_to_copy);
+
+    // Allocates and fills output tensors.
+    const int64 sum_nnz = a_augmented_values.size();
+    Tensor *output_indices_t, *output_values_t;
+    OP_REQUIRES_OK(ctx,
+                   ctx->allocate_output(0, TensorShape({sum_nnz, num_dims}),
+                                        &output_indices_t));
+    OP_REQUIRES_OK(
+        ctx, ctx->allocate_output(1, TensorShape({sum_nnz}), &output_values_t));
+    auto output_indices_mat = output_indices_t->matrix<int64>();
+
+    for (int64 i = 0; i < sum_nnz; ++i) {
+      const bool from_a = entries_to_copy[i].first;
+      const int64 idx = entries_to_copy[i].second;
+      output_indices_mat.chip<0>(i) =
+          from_a ? a_indices_mat.chip<0>(idx) : b_indices_mat.chip<0>(idx);
+    }
+
+    // Performs the functor operation using Eigen.
+    using TensorMap =
+        Eigen::TensorMap<Eigen::Tensor<const T, 1, Eigen::RowMajor>,
+                         Eigen::Aligned>;
+    auto a_augmented_values_t = TensorMap(a_augmented_values.data(), sum_nnz);
+    auto b_augmented_values_t = TensorMap(b_augmented_values.data(), sum_nnz);
+    output_values_t->flat<T>().device(ctx->eigen_device<Device>()) =
+        a_augmented_values_t.binaryExpr(b_augmented_values_t,
+                                        typename Functor::func());
+  }
+};
+
+#define REGISTER_KERNELS(T)                                                  \
+  REGISTER_KERNEL_BUILDER(                                                   \
+      Name("SparseSparseMinimum").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
+      SparseSparseBinaryOpShared<CPUDevice, T, functor::minimum<T>>)         \
+                                                                             \
+  REGISTER_KERNEL_BUILDER(                                                   \
+      Name("SparseSparseMaximum").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
+      SparseSparseBinaryOpShared<CPUDevice, T, functor::maximum<T>>)
+
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS);
+#undef REGISTER_KERNELS
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/ops/sparse_ops.cc b/tensorflow/core/ops/sparse_ops.cc
index d80a5b1f9bc..a39f2f70cba 100644
--- a/tensorflow/core/ops/sparse_ops.cc
+++ b/tensorflow/core/ops/sparse_ops.cc
@@ -570,4 +570,58 @@ sp_shape: 1-D.  Shape of the input SparseTensor.
 output: 1-D.  The `NNZ` values for the result `SparseTensor`.
 )doc");
 
+REGISTER_OP("SparseSparseMaximum")
+    .Input("a_indices: int64")
+    .Input("a_values: T")
+    .Input("a_shape: int64")
+    .Input("b_indices: int64")
+    .Input("b_values: T")
+    .Input("b_shape: int64")
+    .Output("output_indices: int64")
+    .Output("output_values: T")
+    .Attr("T: realnumbertype")
+    .Doc(R"doc(
+Returns the element-wise max of two SparseTensors.
+
+Assumes the two SparseTensors have the same shape, i.e., no broadcasting.
+
+a_indices: 2-D.  `N x R` matrix with the indices of non-empty values in a
+  SparseTensor, in the canonical lexicographic ordering.
+a_values: 1-D.  `N` non-empty values corresponding to `a_indices`.
+a_shape: 1-D.  Shape of the input SparseTensor.
+b_indices: counterpart to `a_indices` for the other operand.
+b_values: counterpart to `a_values` for the other operand; must be of the same dtype.
+b_shape: counterpart to `a_shape` for the other operand; the two shapes must be equal.
+
+output_indices: 2-D.  The indices of the output SparseTensor.
+output_values: 1-D.  The values of the output SparseTensor.
+)doc");
+
+REGISTER_OP("SparseSparseMinimum")
+    .Input("a_indices: int64")
+    .Input("a_values: T")
+    .Input("a_shape: int64")
+    .Input("b_indices: int64")
+    .Input("b_values: T")
+    .Input("b_shape: int64")
+    .Output("output_indices: int64")
+    .Output("output_values: T")
+    .Attr("T: numbertype")
+    .Doc(R"doc(
+Returns the element-wise min of two SparseTensors.
+
+Assumes the two SparseTensors have the same shape, i.e., no broadcasting.
+
+a_indices: 2-D.  `N x R` matrix with the indices of non-empty values in a
+  SparseTensor, in the canonical lexicographic ordering.
+a_values: 1-D.  `N` non-empty values corresponding to `a_indices`.
+a_shape: 1-D.  Shape of the input SparseTensor.
+b_indices: counterpart to `a_indices` for the other operand.
+b_values: counterpart to `a_values` for the other operand; must be of the same dtype.
+b_shape: counterpart to `a_shape` for the other operand; the two shapes must be equal.
+
+output_indices: 2-D.  The indices of the output SparseTensor.
+output_values: 1-D.  The values of the output SparseTensor.
+)doc");
+
 }  // namespace tensorflow
diff --git a/tensorflow/python/kernel_tests/sparse_ops_test.py b/tensorflow/python/kernel_tests/sparse_ops_test.py
index 867bfc5b369..a0394851d11 100644
--- a/tensorflow/python/kernel_tests/sparse_ops_test.py
+++ b/tensorflow/python/kernel_tests/sparse_ops_test.py
@@ -649,5 +649,68 @@ class SparseSoftmaxTest(test_util.TensorFlowTestCase):
         self.assertLess(err, 1e-4)
 
 
+class SparseMinimumMaximumTest(test_util.TensorFlowTestCase):
+
+  def _assertSparseTensorValueEqual(self, a, b):
+    self.assertAllEqual(a.indices, b.indices)
+    self.assertAllEqual(a.values, b.values)
+    self.assertAllEqual(a.shape, b.shape)
+
+  def testBasic(self):
+    with self.test_session(use_gpu=False):
+      # 1-D, values at index 0.
+      sp_zero = ops.SparseTensor([[0]], [0], [7])
+      sp_one = ops.SparseTensor([[0]], [1], [7])
+      max_tf = tf.sparse_maximum(sp_zero, sp_one).eval()
+      min_tf = tf.sparse_minimum(sp_zero, sp_one).eval()
+      self._assertSparseTensorValueEqual(sp_one.eval(), max_tf)
+      self._assertSparseTensorValueEqual(sp_zero.eval(), min_tf)
+
+      # Values at different indices.
+      sp_zero = ops.SparseTensor([[0]], [0], [7])
+      sp_zero_2 = ops.SparseTensor([[1]], [0], [7])
+      expected = ops.SparseTensor([[0], [1]], [0, 0], [7])
+      max_tf = tf.sparse_maximum(sp_zero, sp_zero_2).eval()
+      min_tf = tf.sparse_minimum(sp_zero, sp_zero_2).eval()
+      self._assertSparseTensorValueEqual(expected.eval(), max_tf)
+      self._assertSparseTensorValueEqual(expected.eval(), min_tf)
+
+  def testRandom(self):
+    np.random.seed(1618)
+    shapes = [(13,), (6, 8), (1, 7, 1)]
+    for shape in shapes:
+      for dtype in [np.int32, np.int64, np.float16, np.float32, np.float64]:
+        a_np = np.random.randn(*shape).astype(dtype)
+        b_np = np.random.randn(*shape).astype(dtype)
+        sp_a, unused_a_nnz = _sparsify(a_np, thresh=-.5)
+        sp_b, unused_b_nnz = _sparsify(b_np, thresh=-.5)
+
+        with self.test_session(use_gpu=False):
+          maximum_tf = tf.sparse_maximum(sp_a, sp_b)
+          maximum_tf_densified = tf.sparse_tensor_to_dense(maximum_tf).eval()
+          minimum_tf = tf.sparse_minimum(sp_a, sp_b)
+          minimum_tf_densified = tf.sparse_tensor_to_dense(minimum_tf).eval()
+
+          a_densified = tf.sparse_tensor_to_dense(sp_a).eval()
+          b_densified = tf.sparse_tensor_to_dense(sp_b).eval()
+
+        self.assertAllEqual(np.maximum(a_densified, b_densified),
+                            maximum_tf_densified)
+        self.assertAllEqual(np.minimum(a_densified, b_densified),
+                            minimum_tf_densified)
+
+  def testMismatchedShapes(self):
+    with self.test_session(use_gpu=False):
+      sp_zero = ops.SparseTensor([[0, 0]], [0], [1, 1])
+      sp_one = ops.SparseTensor([[0]], [1], [2])
+      with self.assertRaisesOpError("Operands do not have the same ranks"):
+        tf.sparse_maximum(sp_zero, sp_one).eval()
+
+      sp_zero = ops.SparseTensor([[0]], [0], [1])
+      sp_one = ops.SparseTensor([[0]], [1], [2])
+      with self.assertRaisesOpError("Operands' shapes do not match"):
+        tf.sparse_maximum(sp_zero, sp_one).eval()
+
+
 if __name__ == "__main__":
   googletest.main()
diff --git a/tensorflow/python/ops/sparse_grad.py b/tensorflow/python/ops/sparse_grad.py
index 93c026f2471..57350253ab8 100644
--- a/tensorflow/python/ops/sparse_grad.py
+++ b/tensorflow/python/ops/sparse_grad.py
@@ -256,3 +256,15 @@ def _SparseSoftmaxGrad(op, grad):
 
   grad_x = sp_sum.values * sp_output.values
   return [None, grad_x, None]
+
+
+@ops.RegisterGradient("SparseSparseMaximum")
+def _SparseSparseMaximumGrad(unused_op, unused_grad):
+  raise NotImplementedError("Gradient for SparseSparseMaximum is not yet"
+                            " implemented.")
+
+
+@ops.RegisterGradient("SparseSparseMinimum")
+def _SparseSparseMinimumGrad(unused_op, unused_grad):
+  raise NotImplementedError("Gradient for SparseSparseMinimum is not yet"
+                            " implemented.")
diff --git a/tensorflow/python/ops/sparse_ops.py b/tensorflow/python/ops/sparse_ops.py
index e0a922739bc..b5877d27423 100644
--- a/tensorflow/python/ops/sparse_ops.py
+++ b/tensorflow/python/ops/sparse_ops.py
@@ -48,6 +48,8 @@ dimension, and dense along all other dimensions.
 @@sparse_add
 @@sparse_softmax
 @@sparse_tensor_dense_matmul
+@@sparse_maximum
+@@sparse_minimum
 """
 from __future__ import absolute_import
 from __future__ import division
@@ -1487,3 +1489,84 @@ def _SparseSoftmaxShape(op):  # pylint: disable=invalid-name
   unused_shape_shape = op.inputs[2].get_shape().with_rank(1)
   nnz = values_shape[0]
   return [tensor_shape.vector(nnz)]
+
+
+def sparse_maximum(sp_a, sp_b, name=None):
+  """Returns the element-wise max of two SparseTensors.
+
+  Assumes the two SparseTensors have the same shape, i.e., no broadcasting.
+  Example:
+
+  ```python
+  sp_zero = ops.SparseTensor([[0]], [0], [7])
+  sp_one = ops.SparseTensor([[1]], [1], [7])
+  res = tf.sparse_maximum(sp_zero, sp_one).eval()
+  # "res" should be equal to SparseTensor([[0], [1]], [0, 1], [7]).
+  ```
+
+  Args:
+    sp_a: a `SparseTensor` operand whose dtype is real and whose indices are
+      lexicographically ordered.
+    sp_b: the other `SparseTensor` operand with the same requirements (and the
+      same shape).
+    name: optional name of the operation.
+  Returns:
+    output: the output SparseTensor.
+  """
+  with ops.op_scope([sp_a.indices, sp_a.values, sp_b.indices, sp_b.values],
+                    name, "SparseSparseMaximum") as name:
+    out_indices, out_values = gen_sparse_ops.sparse_sparse_maximum(sp_a.indices,
+                                                                   sp_a.values,
+                                                                   sp_a.shape,
+                                                                   sp_b.indices,
+                                                                   sp_b.values,
+                                                                   sp_b.shape,
+                                                                   name=name)
+  return ops.SparseTensor(out_indices, out_values, sp_a.shape)
+
+
+def sparse_minimum(sp_a, sp_b, name=None):
+  """Returns the element-wise min of two SparseTensors.
+
+  Assumes the two SparseTensors have the same shape, i.e., no broadcasting.
+  Example:
+
+  ```python
+  sp_zero = ops.SparseTensor([[0]], [0], [7])
+  sp_one = ops.SparseTensor([[1]], [1], [7])
+  res = tf.sparse_minimum(sp_zero, sp_one).eval()
+  # "res" should be equal to SparseTensor([[0], [1]], [0, 0], [7]).
+  ```
+
+  Args:
+    sp_a: a `SparseTensor` operand whose dtype is real and whose indices are
+      lexicographically ordered.
+    sp_b: the other `SparseTensor` operand with the same requirements (and the
+      same shape).
+    name: optional name of the operation.
+  Returns:
+    output: the output SparseTensor.
+  """
+  with ops.op_scope([sp_a.indices, sp_a.values, sp_b.indices, sp_b.values],
+                    name, "SparseSparseMinimum") as name:
+    out_indices, out_values = gen_sparse_ops.sparse_sparse_minimum(sp_a.indices,
+                                                                   sp_a.values,
+                                                                   sp_a.shape,
+                                                                   sp_b.indices,
+                                                                   sp_b.values,
+                                                                   sp_b.shape,
+                                                                   name=name)
+  return ops.SparseTensor(out_indices, out_values, sp_a.shape)
+
+
+@ops.RegisterShape("SparseSparseMaximum")
+@ops.RegisterShape("SparseSparseMinimum")
+def _SparseSparseMaximumMinimumShape(op):  # pylint: disable=invalid-name
+  """Shape function for SparseSparseMaximum and SparseSparseMinimum."""
+  op.inputs[0].get_shape().assert_has_rank(2)  # a_indices
+  op.inputs[1].get_shape().assert_has_rank(1)  # a_values
+  op.inputs[2].get_shape().assert_has_rank(1)  # a_shape
+  op.inputs[3].get_shape().assert_has_rank(2)  # b_indices
+  op.inputs[4].get_shape().assert_has_rank(1)  # b_values
+  op.inputs[5].get_shape().assert_has_rank(1)  # b_shape
+  return [tensor_shape.unknown_shape(2), tensor_shape.unknown_shape(1)]