Address minor review comments for PR 46275
- Remove unused context parameter. - Rename functor::ReshapeSparseTensor to functor::ReshapeSparseTensorFunctor. - Make integer constants const.
This commit is contained in:
parent
1d8218f155
commit
b1a0dbdf4c
@ -36,15 +36,14 @@ using CPUDevice = Eigen::ThreadPoolDevice;
|
||||
namespace functor {
|
||||
|
||||
template <>
|
||||
struct ReshapeSparseTensor<CPUDevice> {
|
||||
Status operator()(OpKernelContext *context, const TensorShape &input_shape,
|
||||
struct ReshapeSparseTensorFunctor<CPUDevice> {
|
||||
Status operator()(const TensorShape &input_shape,
|
||||
const TensorShape &output_shape,
|
||||
typename TTypes<int64>::ConstMatrix input_indices,
|
||||
typename TTypes<int64>::Matrix output_indices) const {
|
||||
(void)context;
|
||||
int64 input_rank = input_shape.dims();
|
||||
int64 output_rank = output_shape.dims();
|
||||
int64 nnz = input_indices.dimension(0);
|
||||
const int64 input_rank = input_shape.dims();
|
||||
const int64 output_rank = output_shape.dims();
|
||||
const int64 nnz = input_indices.dimension(0);
|
||||
gtl::InlinedVector<int64, 8> input_strides(input_rank);
|
||||
if (input_rank > 0) {
|
||||
input_strides[input_rank - 1] = 1;
|
||||
@ -174,8 +173,8 @@ void ReshapeSparseTensor(OpKernelContext *context,
|
||||
TensorShape({nnz, output_rank}),
|
||||
&result_indices));
|
||||
if (nnz > 0) {
|
||||
OP_REQUIRES_OK(context, functor::ReshapeSparseTensor<Device>()(
|
||||
context, input_shape, output_shape,
|
||||
OP_REQUIRES_OK(context, functor::ReshapeSparseTensorFunctor<Device>()(
|
||||
input_shape, output_shape,
|
||||
input_indices_in.matrix<int64>(),
|
||||
result_indices->matrix<int64>()));
|
||||
}
|
||||
|
@ -37,8 +37,8 @@ void ReshapeSparseTensor(OpKernelContext *context,
|
||||
namespace functor {
|
||||
|
||||
template <typename Device>
|
||||
struct ReshapeSparseTensor {
|
||||
Status operator()(OpKernelContext *context, const TensorShape &input_shape,
|
||||
struct ReshapeSparseTensorFunctor {
|
||||
Status operator()(const TensorShape &input_shape,
|
||||
const TensorShape &output_shape,
|
||||
typename TTypes<int64>::ConstMatrix input_indices,
|
||||
typename TTypes<int64>::Matrix output_indices) const;
|
||||
|
Loading…
x
Reference in New Issue
Block a user