Fix SparseDenseCwise's broadcasting issue
This fix tries to address the issue raised in 24072. In `sparse_dense_cwise_mul/add` operations the broadcasting only support dense to sparse, though the validation was not captured. This fix fixes the validation in SparseDenseBinaryOpShared so that error could be thrown correctly. This fix fixes 24072. Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
This commit is contained in:
parent
f23efde510
commit
8584f21392
@ -88,12 +88,11 @@ class SparseDenseBinaryOpShared : public OpKernel {
|
|||||||
const auto rhs_dims = BCast::FromShape(dense_t->shape());
|
const auto rhs_dims = BCast::FromShape(dense_t->shape());
|
||||||
BCast b(lhs_dims, rhs_dims, false); // false for keeping the same num dims.
|
BCast b(lhs_dims, rhs_dims, false); // false for keeping the same num dims.
|
||||||
|
|
||||||
// True iff (size(lhs) > size(rhs)), or (sizes equal, lhs cwise rhs).
|
// True iff (size(lhs) >= size(rhs)) and all dims in lhs is smaller or equal to dims in rhs (from right to left).
|
||||||
auto VecGreaterEq = [](ArraySlice<int64> lhs, ArraySlice<int64> rhs) {
|
auto VecGreaterEq = [](ArraySlice<int64> lhs, ArraySlice<int64> rhs) {
|
||||||
if (lhs.size() > rhs.size()) return true;
|
|
||||||
if (lhs.size() < rhs.size()) return false;
|
if (lhs.size() < rhs.size()) return false;
|
||||||
for (size_t i = 0; i < lhs.size(); ++i) {
|
for (size_t i = 0; i < rhs.size(); ++i) {
|
||||||
if (lhs[i] < rhs[i]) return false;
|
if (lhs[lhs.size() - 1 - i] < rhs[rhs.size() - 1 - i]) return false;
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
};
|
};
|
||||||
|
Loading…
Reference in New Issue
Block a user