Allow sparse tensor with reorder work on potentially large dims,
Update: review comments addressed. Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
This commit is contained in:
parent
f137ea2a23
commit
63690f568c
@ -27,7 +27,6 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/types.h"
|
#include "tensorflow/core/framework/types.h"
|
||||||
#include "tensorflow/core/lib/gtl/inlined_vector.h"
|
#include "tensorflow/core/lib/gtl/inlined_vector.h"
|
||||||
#include "tensorflow/core/util/sparse/sparse_tensor.h"
|
#include "tensorflow/core/util/sparse/sparse_tensor.h"
|
||||||
#include "tensorflow/core/util/overflow.h"
|
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
@ -55,21 +54,10 @@ class SparseReorderOp : public OpKernel {
|
|||||||
"Input shape should be a vector but received shape ",
|
"Input shape should be a vector but received shape ",
|
||||||
input_shape_in.shape().DebugString()));
|
input_shape_in.shape().DebugString()));
|
||||||
|
|
||||||
// Check if the sparse tensor input shape is valid
|
gtl::ArraySlice<int64> input_shape(
|
||||||
int64 total = 1;
|
input_shape_in.vec<int64>().data(), input_shape_in.NumElements());
|
||||||
for (int64 i = 0; i < input_shape_in.NumElements(); ++i) {
|
|
||||||
int dim = input_shape_in.vec<int64>()(i);
|
|
||||||
OP_REQUIRES(context, (dim >= 0),
|
|
||||||
errors::InvalidArgument("Dimension ", dim, " must be >= 0"));
|
|
||||||
total = MultiplyWithoutOverflow(total, dim);
|
|
||||||
OP_REQUIRES(context, (total > 0),
|
|
||||||
errors::InvalidArgument(
|
|
||||||
"Shape would have more than 2**63 - 1 elements"));
|
|
||||||
}
|
|
||||||
|
|
||||||
const TensorShape input_shape(input_shape_in.vec<int64>());
|
gtl::InlinedVector<int64, 8> std_order(input_shape.size());
|
||||||
|
|
||||||
gtl::InlinedVector<int64, 8> std_order(input_shape.dims());
|
|
||||||
std::iota(std_order.begin(), std_order.end(), 0);
|
std::iota(std_order.begin(), std_order.end(), 0);
|
||||||
|
|
||||||
// Check if the sparse tensor is already ordered correctly
|
// Check if the sparse tensor is already ordered correctly
|
||||||
|
@ -21,7 +21,6 @@ from __future__ import print_function
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import errors_impl
|
|
||||||
from tensorflow.python.framework import sparse_tensor
|
from tensorflow.python.framework import sparse_tensor
|
||||||
from tensorflow.python.framework import test_util
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
@ -132,9 +131,9 @@ class SparseReorderTest(test.TestCase):
|
|||||||
dense_shape=[4096, 4096, 4096, 4096, 4096, 4096])
|
dense_shape=[4096, 4096, 4096, 4096, 4096, 4096])
|
||||||
self.assertAllEqual(
|
self.assertAllEqual(
|
||||||
(4096, 4096, 4096, 4096, 4096, 4096), sp_input.get_shape())
|
(4096, 4096, 4096, 4096, 4096, 4096), sp_input.get_shape())
|
||||||
with self.assertRaisesRegex(errors_impl.InvalidArgumentError,
|
sp_output = sparse_ops.sparse_reorder(sp_input)
|
||||||
"Shape would have more than"):
|
self.assertAllEqual(
|
||||||
sp_output = sparse_ops.sparse_reorder(sp_input)
|
(4096, 4096, 4096, 4096, 4096, 4096), sp_output.get_shape())
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
Loading…
Reference in New Issue
Block a user