diff --git a/tensorflow/core/kernels/deserialize_sparse_string_op.cc b/tensorflow/core/kernels/deserialize_sparse_string_op.cc
index 2e1510785b3..3acd86ef1a9 100644
--- a/tensorflow/core/kernels/deserialize_sparse_string_op.cc
+++ b/tensorflow/core/kernels/deserialize_sparse_string_op.cc
@@ -35,6 +35,8 @@ limitations under the License.
 
 namespace tensorflow {
 
+using CPUDevice = Eigen::ThreadPoolDevice;
+
 namespace {
 
 using sparse::SparseTensor;
@@ -204,9 +206,9 @@ class DeserializeSparseOp : public OpKernel {
       target_shape.vec<int64>()(i + ndims - 1) = output.shape().data()[i + 1];
     }
 
-    ReshapeSparseTensor(context, output.indices(), input_shape, target_shape,
-                        0 /* output indices index */,
-                        2 /* output shape index */);
+    ReshapeSparseTensor<CPUDevice>(context, output.indices(), input_shape,
+                                   target_shape, 0 /* output indices index */,
+                                   2 /* output shape index */);
 
     context->set_output(1, output.values());
   }
diff --git a/tensorflow/core/kernels/reshape_util.cc b/tensorflow/core/kernels/reshape_util.cc
index 1fce80f7970..72fd0ebb871 100644
--- a/tensorflow/core/kernels/reshape_util.cc
+++ b/tensorflow/core/kernels/reshape_util.cc
@@ -31,6 +31,53 @@ limitations under the License.
 
 namespace tensorflow {
 
+using CPUDevice = Eigen::ThreadPoolDevice;
+
+namespace functor {
+
+template <>
+struct ReshapeSparseTensorFunctor<CPUDevice> {
+  Status operator()(const TensorShape &input_shape,
+                    const TensorShape &output_shape,
+                    typename TTypes<int64>::ConstMatrix input_indices,
+                    typename TTypes<int64>::Matrix output_indices) const {
+    const int64 input_rank = input_shape.dims();
+    const int64 output_rank = output_shape.dims();
+    const int64 nnz = input_indices.dimension(0);
+    gtl::InlinedVector<int64, 8> input_strides(input_rank);
+    if (input_rank > 0) {
+      input_strides[input_rank - 1] = 1;
+      for (int d = input_rank - 2; d >= 0; --d) {
+        input_strides[d] = input_strides[d + 1] * input_shape.dim_size(d + 1);
+      }
+    }
+
+    gtl::InlinedVector<int64, 8> output_strides(output_rank);
+    if (output_rank > 0) {
+      output_strides[output_rank - 1] = 1;
+      for (int d = output_rank - 2; d >= 0; --d) {
+        output_strides[d] =
+            output_strides[d + 1] * output_shape.dim_size(d + 1);
+      }
+    }
+
+    for (int i = 0; i < nnz; ++i) {
+      int64 id = 0;
+      for (int j = 0; j < input_rank; ++j) {
+        id += input_indices(i, j) * input_strides[j];
+      }
+      for (int j = 0; j < output_rank; ++j) {
+        output_indices(i, j) = id / output_strides[j];
+        id %= output_strides[j];
+      }
+    }
+    return Status::OK();
+  }
+};
+
+}  // namespace functor
+
+template <typename Device>
 void ReshapeSparseTensor(OpKernelContext *context,
                          const Tensor &input_indices_in,
                          const Tensor &input_shape_in,
@@ -49,7 +96,6 @@ void ReshapeSparseTensor(OpKernelContext *context,
                   "Target shape should be a vector but received shape ",
                   target_shape_in.shape().DebugString()));
 
-  const int64 input_rank = input_shape_in.NumElements();
   const int64 output_rank = target_shape_in.NumElements();
   const TensorShape input_shape(input_shape_in.vec<int64>());
   const int64 dense_size = input_shape.num_elements();
@@ -111,40 +157,6 @@ void ReshapeSparseTensor(OpKernelContext *context,
     return;
   }
 
-  gtl::InlinedVector<int64, 8> input_strides(input_rank);
-  if (input_rank > 0) {
-    input_strides[input_rank - 1] = 1;
-    for (int d = input_rank - 2; d >= 0; --d) {
-      input_strides[d] = input_strides[d + 1] * input_shape.dim_size(d + 1);
-    }
-  }
-
-  gtl::InlinedVector<int64, 8> output_strides(output_rank);
-  if (output_rank > 0) {
-    output_strides[output_rank - 1] = 1;
-    for (int d = output_rank - 2; d >= 0; --d) {
-      output_strides[d] = output_strides[d + 1] * output_shape.dim_size(d + 1);
-    }
-  }
-
-  Tensor *result_indices = nullptr;
-  OP_REQUIRES_OK(context,
-                 context->allocate_output(output_indices_idx,
-                                          TensorShape({nnz, output_rank}),
-                                          &result_indices));
-  auto input_ind = input_indices_in.matrix<int64>();
-  auto output_ind = result_indices->matrix<int64>();
-  for (int i = 0; i < nnz; ++i) {
-    int64 id = 0;
-    for (int j = 0; j < input_rank; ++j) {
-      id += input_ind(i, j) * input_strides[j];
-    }
-    for (int j = 0; j < output_rank; ++j) {
-      output_ind(i, j) = id / output_strides[j];
-      id %= output_strides[j];
-    }
-  }
-
   Tensor *result_shape = nullptr;
   OP_REQUIRES_OK(context, context->allocate_output(output_shape_idx,
                                                    TensorShape({output_rank}),
@@ -153,6 +165,26 @@ void ReshapeSparseTensor(OpKernelContext *context,
   for (int j = 0; j < output_shape.dims(); ++j) {
     output_shape_vec(j) = output_shape.dim_size(j);
   }
+
+  Tensor *result_indices = nullptr;
+  OP_REQUIRES_OK(context,
+                 context->allocate_output(output_indices_idx,
+                                          TensorShape({nnz, output_rank}),
+                                          &result_indices));
+  if (nnz > 0) {
+    OP_REQUIRES_OK(context, functor::ReshapeSparseTensorFunctor<Device>()(
+                                input_shape, output_shape,
+                                input_indices_in.matrix<int64>(),
+                                result_indices->matrix<int64>()));
+  }
 }
 
+#define EXPLICITLY_INSTANTIATE_FUNCTION(Device)                     \
+  template void ReshapeSparseTensor<Device>(                        \
+      OpKernelContext * context, const Tensor &input_indices_in,    \
+      const Tensor &input_shape_in, const Tensor &target_shape_in,  \
+      int output_indices_idx, int output_shape_idx)
+EXPLICITLY_INSTANTIATE_FUNCTION(CPUDevice);
+#undef EXPLICITLY_INSTANTIATE_FUNCTION
+
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/reshape_util.h b/tensorflow/core/kernels/reshape_util.h
index 7e1809e8ca8..b3a35651e63 100644
--- a/tensorflow/core/kernels/reshape_util.h
+++ b/tensorflow/core/kernels/reshape_util.h
@@ -16,18 +16,36 @@ limitations under the License.
 #ifndef TENSORFLOW_CORE_KERNELS_RESHAPE_UTIL_H_
 #define TENSORFLOW_CORE_KERNELS_RESHAPE_UTIL_H_
 
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/lib/core/status.h"
+
 namespace tensorflow {
 
 class OpKernelContext;
 class Tensor;
 
 // Reshapes the input indices and input shape to the target shape.
+// Note: This template is explicitly instantiated for CPU device only.
+template <typename Device>
 void ReshapeSparseTensor(OpKernelContext *context,
                          const Tensor &input_indices_in,
                          const Tensor &input_shape_in,
                          const Tensor &target_shape_in, int output_indices_idx,
                          int output_shape_idx);
 
+namespace functor {
+
+template <typename Device>
+struct ReshapeSparseTensorFunctor {
+  Status operator()(const TensorShape &input_shape,
+                    const TensorShape &output_shape,
+                    typename TTypes<int64>::ConstMatrix input_indices,
+                    typename TTypes<int64>::Matrix output_indices) const;
+};
+
+}  // namespace functor
+
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_CORE_KERNELS_RESHAPE_UTIL_H_
diff --git a/tensorflow/core/kernels/sparse_reshape_op.cc b/tensorflow/core/kernels/sparse_reshape_op.cc
index 6eb5f0af635..472a7a270a5 100644
--- a/tensorflow/core/kernels/sparse_reshape_op.cc
+++ b/tensorflow/core/kernels/sparse_reshape_op.cc
@@ -29,17 +29,21 @@ limitations under the License.
 
 namespace tensorflow {
 
+using CPUDevice = Eigen::ThreadPoolDevice;
+
+template <typename Device>
 class SparseReshapeOp : public OpKernel {
  public:
   explicit SparseReshapeOp(OpKernelConstruction* context) : OpKernel(context) {}
 
   void Compute(OpKernelContext* context) override {
-    ReshapeSparseTensor(context, context->input(0), context->input(1),
-                        context->input(2), 0 /* output indices index */,
-                        1 /* output shape index */);
+    ReshapeSparseTensor<Device>(context, context->input(0), context->input(1),
+                                context->input(2), 0 /* output indices index */,
+                                1 /* output shape index */);
   }
 };
 
 REGISTER_KERNEL_BUILDER(Name("SparseReshape").Device(DEVICE_CPU),
-                        SparseReshapeOp)
+                        SparseReshapeOp<CPUDevice>)
+
 }  // namespace tensorflow
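// ---------------------------------------------------------------------------
// Illustration (not part of the patch above): the index remapping that the new
// ReshapeSparseTensorFunctor<CPUDevice> performs. Each sparse index is
// linearized with the row-major strides of the input shape and then
// re-expanded with the strides of the output shape. The standalone program
// below is a minimal sketch of that algorithm using plain std::vector; names
// and types are simplified stand-ins, not TensorFlow APIs.
// ---------------------------------------------------------------------------
#include <cstdint>
#include <iostream>
#include <vector>

// Row-major strides of a dense shape, e.g. {2, 3, 4} -> {12, 4, 1}.
static std::vector<int64_t> RowMajorStrides(const std::vector<int64_t>& shape) {
  std::vector<int64_t> strides(shape.size());
  if (!shape.empty()) {
    strides.back() = 1;
    for (int d = static_cast<int>(shape.size()) - 2; d >= 0; --d) {
      strides[d] = strides[d + 1] * shape[d + 1];
    }
  }
  return strides;
}

int main() {
  const std::vector<int64_t> input_shape{2, 3, 4};  // reshape {2,3,4} -> {6,4}
  const std::vector<int64_t> output_shape{6, 4};
  const auto input_strides = RowMajorStrides(input_shape);
  const auto output_strides = RowMajorStrides(output_shape);

  // One nonzero element at input index (1, 2, 3).
  const std::vector<int64_t> input_index{1, 2, 3};

  // Linearize: id = 1*12 + 2*4 + 3*1 = 23.
  int64_t id = 0;
  for (size_t j = 0; j < input_index.size(); ++j) {
    id += input_index[j] * input_strides[j];
  }

  // Re-expand into the output shape: 23 -> (5, 3).
  std::vector<int64_t> output_index(output_shape.size());
  for (size_t j = 0; j < output_shape.size(); ++j) {
    output_index[j] = id / output_strides[j];
    id %= output_strides[j];
  }

  std::cout << "(" << output_index[0] << ", " << output_index[1] << ")\n";
  return 0;
}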