Merge pull request #46275 from benbarsdell:gpu-SparseReshape-cpu-refactor

PiperOrigin-RevId: 357024042
Change-Id: I63ec2724c86e1def68962a40e375c152dda8fcaa
This commit is contained in:
TensorFlower Gardener 2021-02-11 11:52:33 -08:00
commit 5eeb7c7338
4 changed files with 98 additions and 42 deletions

View File

@ -35,6 +35,8 @@ limitations under the License.
namespace tensorflow {
using CPUDevice = Eigen::ThreadPoolDevice;
namespace {
using sparse::SparseTensor;
@ -204,9 +206,9 @@ class DeserializeSparseOp : public OpKernel {
target_shape.vec<int64>()(i + ndims - 1) = output.shape().data()[i + 1];
}
ReshapeSparseTensor(context, output.indices(), input_shape, target_shape,
0 /* output indices index */,
2 /* output shape index */);
ReshapeSparseTensor<CPUDevice>(context, output.indices(), input_shape,
target_shape, 0 /* output indices index */,
2 /* output shape index */);
context->set_output(1, output.values());
}

View File

@ -31,6 +31,53 @@ limitations under the License.
namespace tensorflow {
using CPUDevice = Eigen::ThreadPoolDevice;
namespace functor {
template <>
struct ReshapeSparseTensorFunctor<CPUDevice> {
Status operator()(const TensorShape &input_shape,
const TensorShape &output_shape,
typename TTypes<int64>::ConstMatrix input_indices,
typename TTypes<int64>::Matrix output_indices) const {
const int64 input_rank = input_shape.dims();
const int64 output_rank = output_shape.dims();
const int64 nnz = input_indices.dimension(0);
gtl::InlinedVector<int64, 8> input_strides(input_rank);
if (input_rank > 0) {
input_strides[input_rank - 1] = 1;
for (int d = input_rank - 2; d >= 0; --d) {
input_strides[d] = input_strides[d + 1] * input_shape.dim_size(d + 1);
}
}
gtl::InlinedVector<int64, 8> output_strides(output_rank);
if (output_rank > 0) {
output_strides[output_rank - 1] = 1;
for (int d = output_rank - 2; d >= 0; --d) {
output_strides[d] =
output_strides[d + 1] * output_shape.dim_size(d + 1);
}
}
for (int i = 0; i < nnz; ++i) {
int64 id = 0;
for (int j = 0; j < input_rank; ++j) {
id += input_indices(i, j) * input_strides[j];
}
for (int j = 0; j < output_rank; ++j) {
output_indices(i, j) = id / output_strides[j];
id %= output_strides[j];
}
}
return Status::OK();
}
};
} // namespace functor
template <typename Device>
void ReshapeSparseTensor(OpKernelContext *context,
const Tensor &input_indices_in,
const Tensor &input_shape_in,
@ -49,7 +96,6 @@ void ReshapeSparseTensor(OpKernelContext *context,
"Target shape should be a vector but received shape ",
target_shape_in.shape().DebugString()));
const int64 input_rank = input_shape_in.NumElements();
const int64 output_rank = target_shape_in.NumElements();
const TensorShape input_shape(input_shape_in.vec<int64>());
const int64 dense_size = input_shape.num_elements();
@ -111,40 +157,6 @@ void ReshapeSparseTensor(OpKernelContext *context,
return;
}
gtl::InlinedVector<int64, 8> input_strides(input_rank);
if (input_rank > 0) {
input_strides[input_rank - 1] = 1;
for (int d = input_rank - 2; d >= 0; --d) {
input_strides[d] = input_strides[d + 1] * input_shape.dim_size(d + 1);
}
}
gtl::InlinedVector<int64, 8> output_strides(output_rank);
if (output_rank > 0) {
output_strides[output_rank - 1] = 1;
for (int d = output_rank - 2; d >= 0; --d) {
output_strides[d] = output_strides[d + 1] * output_shape.dim_size(d + 1);
}
}
Tensor *result_indices = nullptr;
OP_REQUIRES_OK(context,
context->allocate_output(output_indices_idx,
TensorShape({nnz, output_rank}),
&result_indices));
auto input_ind = input_indices_in.matrix<int64>();
auto output_ind = result_indices->matrix<int64>();
for (int i = 0; i < nnz; ++i) {
int64 id = 0;
for (int j = 0; j < input_rank; ++j) {
id += input_ind(i, j) * input_strides[j];
}
for (int j = 0; j < output_rank; ++j) {
output_ind(i, j) = id / output_strides[j];
id %= output_strides[j];
}
}
Tensor *result_shape = nullptr;
OP_REQUIRES_OK(context, context->allocate_output(output_shape_idx,
TensorShape({output_rank}),
@ -153,6 +165,26 @@ void ReshapeSparseTensor(OpKernelContext *context,
for (int j = 0; j < output_shape.dims(); ++j) {
output_shape_vec(j) = output_shape.dim_size(j);
}
Tensor *result_indices = nullptr;
OP_REQUIRES_OK(context,
context->allocate_output(output_indices_idx,
TensorShape({nnz, output_rank}),
&result_indices));
if (nnz > 0) {
OP_REQUIRES_OK(context, functor::ReshapeSparseTensorFunctor<Device>()(
input_shape, output_shape,
input_indices_in.matrix<int64>(),
result_indices->matrix<int64>()));
}
}
#define EXPLICITLY_INSTANTIATE_FUNCTION(Device) \
template void ReshapeSparseTensor<Device>( \
OpKernelContext * context, const Tensor &input_indices_in, \
const Tensor &input_shape_in, const Tensor &target_shape_in, \
int output_indices_idx, int output_shape_idx)
EXPLICITLY_INSTANTIATE_FUNCTION(CPUDevice);
#undef EXPLICITLY_INSTANTIATE_FUNCTION
} // namespace tensorflow

View File

@ -16,18 +16,36 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_KERNELS_RESHAPE_UTIL_H_
#define TENSORFLOW_CORE_KERNELS_RESHAPE_UTIL_H_
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/lib/core/status.h"
namespace tensorflow {
class OpKernelContext;
class Tensor;
// Reshapes the input indices and input shape to the target shape.
// Note: This template is explicitly instantiated for CPU device only.
template <typename Device>
void ReshapeSparseTensor(OpKernelContext *context,
const Tensor &input_indices_in,
const Tensor &input_shape_in,
const Tensor &target_shape_in, int output_indices_idx,
int output_shape_idx);
namespace functor {
template <typename Device>
struct ReshapeSparseTensorFunctor {
Status operator()(const TensorShape &input_shape,
const TensorShape &output_shape,
typename TTypes<int64>::ConstMatrix input_indices,
typename TTypes<int64>::Matrix output_indices) const;
};
} // namespace functor
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_RESHAPE_UTIL_H_

View File

@ -29,17 +29,21 @@ limitations under the License.
namespace tensorflow {
using CPUDevice = Eigen::ThreadPoolDevice;
template <typename Device>
class SparseReshapeOp : public OpKernel {
public:
explicit SparseReshapeOp(OpKernelConstruction* context) : OpKernel(context) {}
void Compute(OpKernelContext* context) override {
ReshapeSparseTensor(context, context->input(0), context->input(1),
context->input(2), 0 /* output indices index */,
1 /* output shape index */);
ReshapeSparseTensor<Device>(context, context->input(0), context->input(1),
context->input(2), 0 /* output indices index */,
1 /* output shape index */);
}
};
REGISTER_KERNEL_BUILDER(Name("SparseReshape").Device(DEVICE_CPU),
SparseReshapeOp)
SparseReshapeOp<CPUDevice>)
} // namespace tensorflow