Merge pull request #24098 from yongtang:24072-sparse_dense_cwise-broadcasting
PiperOrigin-RevId: 224386209
This commit is contained in:
commit
6427950eae
@ -88,12 +88,12 @@ class SparseDenseBinaryOpShared : public OpKernel {
|
|||||||
const auto rhs_dims = BCast::FromShape(dense_t->shape());
|
const auto rhs_dims = BCast::FromShape(dense_t->shape());
|
||||||
BCast b(lhs_dims, rhs_dims, false); // false for keeping the same num dims.
|
BCast b(lhs_dims, rhs_dims, false); // false for keeping the same num dims.
|
||||||
|
|
||||||
// True iff (size(lhs) > size(rhs)), or (sizes equal, lhs cwise rhs).
|
// True iff (size(lhs) >= size(rhs)) and all dims in lhs is greater or equal
|
||||||
|
// to dims in rhs (from right to left).
|
||||||
auto VecGreaterEq = [](ArraySlice<int64> lhs, ArraySlice<int64> rhs) {
|
auto VecGreaterEq = [](ArraySlice<int64> lhs, ArraySlice<int64> rhs) {
|
||||||
if (lhs.size() > rhs.size()) return true;
|
|
||||||
if (lhs.size() < rhs.size()) return false;
|
if (lhs.size() < rhs.size()) return false;
|
||||||
for (size_t i = 0; i < lhs.size(); ++i) {
|
for (size_t i = 0; i < rhs.size(); ++i) {
|
||||||
if (lhs[i] < rhs[i]) return false;
|
if (lhs[lhs.size() - 1 - i] < rhs[rhs.size() - 1 - i]) return false;
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
};
|
};
|
||||||
|
@ -22,6 +22,7 @@ import numpy as np
|
|||||||
|
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
|
from tensorflow.python.framework import errors
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import sparse_tensor
|
from tensorflow.python.framework import sparse_tensor
|
||||||
from tensorflow.python.framework import test_util
|
from tensorflow.python.framework import test_util
|
||||||
@ -798,6 +799,19 @@ class SparseMathOpsTest(test_util.TensorFlowTestCase):
|
|||||||
result_tensor.values).eval()
|
result_tensor.values).eval()
|
||||||
self.assertAllEqual(result_np, res_densified)
|
self.assertAllEqual(result_np, res_densified)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
|
def testCwiseShapeValidation(self):
|
||||||
|
# Test case for GitHub 24072.
|
||||||
|
with self.session(use_gpu=False):
|
||||||
|
a = array_ops.ones([3, 4, 1], dtype=dtypes.int32)
|
||||||
|
b = sparse_tensor.SparseTensor([[0, 0, 1, 0], [0, 0, 3, 0]], [10, 20],
|
||||||
|
[1, 1, 4, 2])
|
||||||
|
c = a * b
|
||||||
|
with self.assertRaisesRegexp(
|
||||||
|
errors.InvalidArgumentError,
|
||||||
|
"broadcasts dense to sparse only; got incompatible shapes"):
|
||||||
|
c.eval()
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
@test_util.run_deprecated_v1
|
||||||
def testCwiseDivAndMul(self):
|
def testCwiseDivAndMul(self):
|
||||||
np.random.seed(1618)
|
np.random.seed(1618)
|
||||||
|
Loading…
Reference in New Issue
Block a user