diff --git a/tensorflow/core/kernels/sparse_dense_binary_op_shared.cc b/tensorflow/core/kernels/sparse_dense_binary_op_shared.cc index ac48202ada2..a4e89f439ed 100644 --- a/tensorflow/core/kernels/sparse_dense_binary_op_shared.cc +++ b/tensorflow/core/kernels/sparse_dense_binary_op_shared.cc @@ -88,12 +88,12 @@ class SparseDenseBinaryOpShared : public OpKernel { const auto rhs_dims = BCast::FromShape(dense_t->shape()); BCast b(lhs_dims, rhs_dims, false); // false for keeping the same num dims. - // True iff (size(lhs) > size(rhs)), or (sizes equal, lhs cwise rhs). + // True iff (size(lhs) >= size(rhs)) and all dims in lhs is greater or equal + // to dims in rhs (from right to left). auto VecGreaterEq = [](ArraySlice lhs, ArraySlice rhs) { - if (lhs.size() > rhs.size()) return true; if (lhs.size() < rhs.size()) return false; - for (size_t i = 0; i < lhs.size(); ++i) { - if (lhs[i] < rhs[i]) return false; + for (size_t i = 0; i < rhs.size(); ++i) { + if (lhs[lhs.size() - 1 - i] < rhs[rhs.size() - 1 - i]) return false; } return true; }; diff --git a/tensorflow/python/kernel_tests/sparse_ops_test.py b/tensorflow/python/kernel_tests/sparse_ops_test.py index 75f65e62517..7598991489c 100644 --- a/tensorflow/python/kernel_tests/sparse_ops_test.py +++ b/tensorflow/python/kernel_tests/sparse_ops_test.py @@ -22,6 +22,7 @@ import numpy as np from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import test_util @@ -798,6 +799,19 @@ class SparseMathOpsTest(test_util.TensorFlowTestCase): result_tensor.values).eval() self.assertAllEqual(result_np, res_densified) + @test_util.run_deprecated_v1 + def testCwiseShapeValidation(self): + # Test case for GitHub 24072. + with self.session(use_gpu=False): + a = array_ops.ones([3, 4, 1], dtype=dtypes.int32) + b = sparse_tensor.SparseTensor([[0, 0, 1, 0], [0, 0, 3, 0]], [10, 20], + [1, 1, 4, 2]) + c = a * b + with self.assertRaisesRegexp( + errors.InvalidArgumentError, + "broadcasts dense to sparse only; got incompatible shapes"): + c.eval() + @test_util.run_deprecated_v1 def testCwiseDivAndMul(self): np.random.seed(1618)