From 1d8218f155c1d22c21afda8bf28e36e4094d9e88 Mon Sep 17 00:00:00 2001 From: Ben Barsdell Date: Fri, 8 Jan 2021 11:04:37 +1100 Subject: [PATCH 1/2] Refactor ReshapeSparseTensor into a template+class - This is in preparation for adding a GPU implementation. - No functional change. --- .../kernels/deserialize_sparse_string_op.cc | 8 +- tensorflow/core/kernels/reshape_util.cc | 102 ++++++++++++------ tensorflow/core/kernels/reshape_util.h | 18 ++++ tensorflow/core/kernels/sparse_reshape_op.cc | 12 ++- 4 files changed, 99 insertions(+), 41 deletions(-) diff --git a/tensorflow/core/kernels/deserialize_sparse_string_op.cc b/tensorflow/core/kernels/deserialize_sparse_string_op.cc index 2e1510785b3..3acd86ef1a9 100644 --- a/tensorflow/core/kernels/deserialize_sparse_string_op.cc +++ b/tensorflow/core/kernels/deserialize_sparse_string_op.cc @@ -35,6 +35,8 @@ limitations under the License. namespace tensorflow { +using CPUDevice = Eigen::ThreadPoolDevice; + namespace { using sparse::SparseTensor; @@ -204,9 +206,9 @@ class DeserializeSparseOp : public OpKernel { target_shape.vec()(i + ndims - 1) = output.shape().data()[i + 1]; } - ReshapeSparseTensor(context, output.indices(), input_shape, target_shape, - 0 /* output indices index */, - 2 /* output shape index */); + ReshapeSparseTensor(context, output.indices(), input_shape, + target_shape, 0 /* output indices index */, + 2 /* output shape index */); context->set_output(1, output.values()); } diff --git a/tensorflow/core/kernels/reshape_util.cc b/tensorflow/core/kernels/reshape_util.cc index 1fce80f7970..6c3e7169f3f 100644 --- a/tensorflow/core/kernels/reshape_util.cc +++ b/tensorflow/core/kernels/reshape_util.cc @@ -31,6 +31,54 @@ limitations under the License. namespace tensorflow { +using CPUDevice = Eigen::ThreadPoolDevice; + +namespace functor { + +template <> +struct ReshapeSparseTensor { + Status operator()(OpKernelContext *context, const TensorShape &input_shape, + const TensorShape &output_shape, + typename TTypes::ConstMatrix input_indices, + typename TTypes::Matrix output_indices) const { + (void)context; + int64 input_rank = input_shape.dims(); + int64 output_rank = output_shape.dims(); + int64 nnz = input_indices.dimension(0); + gtl::InlinedVector input_strides(input_rank); + if (input_rank > 0) { + input_strides[input_rank - 1] = 1; + for (int d = input_rank - 2; d >= 0; --d) { + input_strides[d] = input_strides[d + 1] * input_shape.dim_size(d + 1); + } + } + + gtl::InlinedVector output_strides(output_rank); + if (output_rank > 0) { + output_strides[output_rank - 1] = 1; + for (int d = output_rank - 2; d >= 0; --d) { + output_strides[d] = + output_strides[d + 1] * output_shape.dim_size(d + 1); + } + } + + for (int i = 0; i < nnz; ++i) { + int64 id = 0; + for (int j = 0; j < input_rank; ++j) { + id += input_indices(i, j) * input_strides[j]; + } + for (int j = 0; j < output_rank; ++j) { + output_indices(i, j) = id / output_strides[j]; + id %= output_strides[j]; + } + } + return Status::OK(); + } +}; + +} // namespace functor + +template void ReshapeSparseTensor(OpKernelContext *context, const Tensor &input_indices_in, const Tensor &input_shape_in, @@ -111,40 +159,6 @@ void ReshapeSparseTensor(OpKernelContext *context, return; } - gtl::InlinedVector input_strides(input_rank); - if (input_rank > 0) { - input_strides[input_rank - 1] = 1; - for (int d = input_rank - 2; d >= 0; --d) { - input_strides[d] = input_strides[d + 1] * input_shape.dim_size(d + 1); - } - } - - gtl::InlinedVector output_strides(output_rank); - if (output_rank > 0) { - output_strides[output_rank - 1] = 1; - for (int d = output_rank - 2; d >= 0; --d) { - output_strides[d] = output_strides[d + 1] * output_shape.dim_size(d + 1); - } - } - - Tensor *result_indices = nullptr; - OP_REQUIRES_OK(context, - context->allocate_output(output_indices_idx, - TensorShape({nnz, output_rank}), - &result_indices)); - auto input_ind = input_indices_in.matrix(); - auto output_ind = result_indices->matrix(); - for (int i = 0; i < nnz; ++i) { - int64 id = 0; - for (int j = 0; j < input_rank; ++j) { - id += input_ind(i, j) * input_strides[j]; - } - for (int j = 0; j < output_rank; ++j) { - output_ind(i, j) = id / output_strides[j]; - id %= output_strides[j]; - } - } - Tensor *result_shape = nullptr; OP_REQUIRES_OK(context, context->allocate_output(output_shape_idx, TensorShape({output_rank}), @@ -153,6 +167,26 @@ void ReshapeSparseTensor(OpKernelContext *context, for (int j = 0; j < output_shape.dims(); ++j) { output_shape_vec(j) = output_shape.dim_size(j); } + + Tensor *result_indices = nullptr; + OP_REQUIRES_OK(context, + context->allocate_output(output_indices_idx, + TensorShape({nnz, output_rank}), + &result_indices)); + if (nnz > 0) { + OP_REQUIRES_OK(context, functor::ReshapeSparseTensor()( + context, input_shape, output_shape, + input_indices_in.matrix(), + result_indices->matrix())); + } } +#define EXPLICITLY_INSTANTIATE_FUNCTION(Device) \ + template void ReshapeSparseTensor( \ + OpKernelContext *context, const Tensor &input_indices_in, \ + const Tensor &input_shape_in, const Tensor &target_shape_in, \ + int output_indices_idx, int output_shape_idx) +EXPLICITLY_INSTANTIATE_FUNCTION(CPUDevice); +#undef EXPLICITLY_INSTANTIATE_FUNCTION + } // namespace tensorflow diff --git a/tensorflow/core/kernels/reshape_util.h b/tensorflow/core/kernels/reshape_util.h index 7e1809e8ca8..6471b2daf2d 100644 --- a/tensorflow/core/kernels/reshape_util.h +++ b/tensorflow/core/kernels/reshape_util.h @@ -16,18 +16,36 @@ limitations under the License. #ifndef TENSORFLOW_CORE_KERNELS_RESHAPE_UTIL_H_ #define TENSORFLOW_CORE_KERNELS_RESHAPE_UTIL_H_ +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/lib/core/status.h" + namespace tensorflow { class OpKernelContext; class Tensor; // Reshapes the input indices and input shape to the target shape. +// Note: This template is explicitly instantiated for CPU device only. +template void ReshapeSparseTensor(OpKernelContext *context, const Tensor &input_indices_in, const Tensor &input_shape_in, const Tensor &target_shape_in, int output_indices_idx, int output_shape_idx); +namespace functor { + +template +struct ReshapeSparseTensor { + Status operator()(OpKernelContext *context, const TensorShape &input_shape, + const TensorShape &output_shape, + typename TTypes::ConstMatrix input_indices, + typename TTypes::Matrix output_indices) const; +}; + +} // namespace functor + } // namespace tensorflow #endif // TENSORFLOW_CORE_KERNELS_RESHAPE_UTIL_H_ diff --git a/tensorflow/core/kernels/sparse_reshape_op.cc b/tensorflow/core/kernels/sparse_reshape_op.cc index 6eb5f0af635..7782b10f3db 100644 --- a/tensorflow/core/kernels/sparse_reshape_op.cc +++ b/tensorflow/core/kernels/sparse_reshape_op.cc @@ -29,17 +29,21 @@ limitations under the License. namespace tensorflow { +using CPUDevice = Eigen::ThreadPoolDevice; + +template class SparseReshapeOp : public OpKernel { public: explicit SparseReshapeOp(OpKernelConstruction* context) : OpKernel(context) {} void Compute(OpKernelContext* context) override { - ReshapeSparseTensor(context, context->input(0), context->input(1), - context->input(2), 0 /* output indices index */, - 1 /* output shape index */); + ReshapeSparseTensor( + context, context->input(0), context->input(1), context->input(2), + 0 /* output indices index */, 1 /* output shape index */); } }; REGISTER_KERNEL_BUILDER(Name("SparseReshape").Device(DEVICE_CPU), - SparseReshapeOp) + SparseReshapeOp) + } // namespace tensorflow From b1a0dbdf4cd6a943a19d45537b2fd77aa682fc99 Mon Sep 17 00:00:00 2001 From: Ben Barsdell Date: Mon, 25 Jan 2021 17:33:33 +1100 Subject: [PATCH 2/2] Address minor review comments for PR 46275 - Remove unused context parameter. - Rename functor::ReshapeSparseTensor to functor::ReshapeSparseTensorFunctor. - Make integer constants const. --- tensorflow/core/kernels/reshape_util.cc | 15 +++++++-------- tensorflow/core/kernels/reshape_util.h | 4 ++-- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/tensorflow/core/kernels/reshape_util.cc b/tensorflow/core/kernels/reshape_util.cc index 6c3e7169f3f..d0d54738b27 100644 --- a/tensorflow/core/kernels/reshape_util.cc +++ b/tensorflow/core/kernels/reshape_util.cc @@ -36,15 +36,14 @@ using CPUDevice = Eigen::ThreadPoolDevice; namespace functor { template <> -struct ReshapeSparseTensor { - Status operator()(OpKernelContext *context, const TensorShape &input_shape, +struct ReshapeSparseTensorFunctor { + Status operator()(const TensorShape &input_shape, const TensorShape &output_shape, typename TTypes::ConstMatrix input_indices, typename TTypes::Matrix output_indices) const { - (void)context; - int64 input_rank = input_shape.dims(); - int64 output_rank = output_shape.dims(); - int64 nnz = input_indices.dimension(0); + const int64 input_rank = input_shape.dims(); + const int64 output_rank = output_shape.dims(); + const int64 nnz = input_indices.dimension(0); gtl::InlinedVector input_strides(input_rank); if (input_rank > 0) { input_strides[input_rank - 1] = 1; @@ -174,8 +173,8 @@ void ReshapeSparseTensor(OpKernelContext *context, TensorShape({nnz, output_rank}), &result_indices)); if (nnz > 0) { - OP_REQUIRES_OK(context, functor::ReshapeSparseTensor()( - context, input_shape, output_shape, + OP_REQUIRES_OK(context, functor::ReshapeSparseTensorFunctor()( + input_shape, output_shape, input_indices_in.matrix(), result_indices->matrix())); } diff --git a/tensorflow/core/kernels/reshape_util.h b/tensorflow/core/kernels/reshape_util.h index 6471b2daf2d..b3a35651e63 100644 --- a/tensorflow/core/kernels/reshape_util.h +++ b/tensorflow/core/kernels/reshape_util.h @@ -37,8 +37,8 @@ void ReshapeSparseTensor(OpKernelContext *context, namespace functor { template -struct ReshapeSparseTensor { - Status operator()(OpKernelContext *context, const TensorShape &input_shape, +struct ReshapeSparseTensorFunctor { + Status operator()(const TensorShape &input_shape, const TensorShape &output_shape, typename TTypes::ConstMatrix input_indices, typename TTypes::Matrix output_indices) const;