Merge pull request #46275 from benbarsdell:gpu-SparseReshape-cpu-refactor
PiperOrigin-RevId: 357024042 Change-Id: I63ec2724c86e1def68962a40e375c152dda8fcaa
This commit is contained in:
commit
5eeb7c7338
@ -35,6 +35,8 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
using CPUDevice = Eigen::ThreadPoolDevice;
|
||||
|
||||
namespace {
|
||||
|
||||
using sparse::SparseTensor;
|
||||
@ -204,9 +206,9 @@ class DeserializeSparseOp : public OpKernel {
|
||||
target_shape.vec<int64>()(i + ndims - 1) = output.shape().data()[i + 1];
|
||||
}
|
||||
|
||||
ReshapeSparseTensor(context, output.indices(), input_shape, target_shape,
|
||||
0 /* output indices index */,
|
||||
2 /* output shape index */);
|
||||
ReshapeSparseTensor<CPUDevice>(context, output.indices(), input_shape,
|
||||
target_shape, 0 /* output indices index */,
|
||||
2 /* output shape index */);
|
||||
context->set_output(1, output.values());
|
||||
}
|
||||
|
||||
|
||||
@ -31,6 +31,53 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
using CPUDevice = Eigen::ThreadPoolDevice;
|
||||
|
||||
namespace functor {
|
||||
|
||||
template <>
|
||||
struct ReshapeSparseTensorFunctor<CPUDevice> {
|
||||
Status operator()(const TensorShape &input_shape,
|
||||
const TensorShape &output_shape,
|
||||
typename TTypes<int64>::ConstMatrix input_indices,
|
||||
typename TTypes<int64>::Matrix output_indices) const {
|
||||
const int64 input_rank = input_shape.dims();
|
||||
const int64 output_rank = output_shape.dims();
|
||||
const int64 nnz = input_indices.dimension(0);
|
||||
gtl::InlinedVector<int64, 8> input_strides(input_rank);
|
||||
if (input_rank > 0) {
|
||||
input_strides[input_rank - 1] = 1;
|
||||
for (int d = input_rank - 2; d >= 0; --d) {
|
||||
input_strides[d] = input_strides[d + 1] * input_shape.dim_size(d + 1);
|
||||
}
|
||||
}
|
||||
|
||||
gtl::InlinedVector<int64, 8> output_strides(output_rank);
|
||||
if (output_rank > 0) {
|
||||
output_strides[output_rank - 1] = 1;
|
||||
for (int d = output_rank - 2; d >= 0; --d) {
|
||||
output_strides[d] =
|
||||
output_strides[d + 1] * output_shape.dim_size(d + 1);
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < nnz; ++i) {
|
||||
int64 id = 0;
|
||||
for (int j = 0; j < input_rank; ++j) {
|
||||
id += input_indices(i, j) * input_strides[j];
|
||||
}
|
||||
for (int j = 0; j < output_rank; ++j) {
|
||||
output_indices(i, j) = id / output_strides[j];
|
||||
id %= output_strides[j];
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace functor
|
||||
|
||||
template <typename Device>
|
||||
void ReshapeSparseTensor(OpKernelContext *context,
|
||||
const Tensor &input_indices_in,
|
||||
const Tensor &input_shape_in,
|
||||
@ -49,7 +96,6 @@ void ReshapeSparseTensor(OpKernelContext *context,
|
||||
"Target shape should be a vector but received shape ",
|
||||
target_shape_in.shape().DebugString()));
|
||||
|
||||
const int64 input_rank = input_shape_in.NumElements();
|
||||
const int64 output_rank = target_shape_in.NumElements();
|
||||
const TensorShape input_shape(input_shape_in.vec<int64>());
|
||||
const int64 dense_size = input_shape.num_elements();
|
||||
@ -111,40 +157,6 @@ void ReshapeSparseTensor(OpKernelContext *context,
|
||||
return;
|
||||
}
|
||||
|
||||
gtl::InlinedVector<int64, 8> input_strides(input_rank);
|
||||
if (input_rank > 0) {
|
||||
input_strides[input_rank - 1] = 1;
|
||||
for (int d = input_rank - 2; d >= 0; --d) {
|
||||
input_strides[d] = input_strides[d + 1] * input_shape.dim_size(d + 1);
|
||||
}
|
||||
}
|
||||
|
||||
gtl::InlinedVector<int64, 8> output_strides(output_rank);
|
||||
if (output_rank > 0) {
|
||||
output_strides[output_rank - 1] = 1;
|
||||
for (int d = output_rank - 2; d >= 0; --d) {
|
||||
output_strides[d] = output_strides[d + 1] * output_shape.dim_size(d + 1);
|
||||
}
|
||||
}
|
||||
|
||||
Tensor *result_indices = nullptr;
|
||||
OP_REQUIRES_OK(context,
|
||||
context->allocate_output(output_indices_idx,
|
||||
TensorShape({nnz, output_rank}),
|
||||
&result_indices));
|
||||
auto input_ind = input_indices_in.matrix<int64>();
|
||||
auto output_ind = result_indices->matrix<int64>();
|
||||
for (int i = 0; i < nnz; ++i) {
|
||||
int64 id = 0;
|
||||
for (int j = 0; j < input_rank; ++j) {
|
||||
id += input_ind(i, j) * input_strides[j];
|
||||
}
|
||||
for (int j = 0; j < output_rank; ++j) {
|
||||
output_ind(i, j) = id / output_strides[j];
|
||||
id %= output_strides[j];
|
||||
}
|
||||
}
|
||||
|
||||
Tensor *result_shape = nullptr;
|
||||
OP_REQUIRES_OK(context, context->allocate_output(output_shape_idx,
|
||||
TensorShape({output_rank}),
|
||||
@ -153,6 +165,26 @@ void ReshapeSparseTensor(OpKernelContext *context,
|
||||
for (int j = 0; j < output_shape.dims(); ++j) {
|
||||
output_shape_vec(j) = output_shape.dim_size(j);
|
||||
}
|
||||
|
||||
Tensor *result_indices = nullptr;
|
||||
OP_REQUIRES_OK(context,
|
||||
context->allocate_output(output_indices_idx,
|
||||
TensorShape({nnz, output_rank}),
|
||||
&result_indices));
|
||||
if (nnz > 0) {
|
||||
OP_REQUIRES_OK(context, functor::ReshapeSparseTensorFunctor<Device>()(
|
||||
input_shape, output_shape,
|
||||
input_indices_in.matrix<int64>(),
|
||||
result_indices->matrix<int64>()));
|
||||
}
|
||||
}
|
||||
|
||||
#define EXPLICITLY_INSTANTIATE_FUNCTION(Device) \
|
||||
template void ReshapeSparseTensor<Device>( \
|
||||
OpKernelContext * context, const Tensor &input_indices_in, \
|
||||
const Tensor &input_shape_in, const Tensor &target_shape_in, \
|
||||
int output_indices_idx, int output_shape_idx)
|
||||
EXPLICITLY_INSTANTIATE_FUNCTION(CPUDevice);
|
||||
#undef EXPLICITLY_INSTANTIATE_FUNCTION
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
@ -16,18 +16,36 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_CORE_KERNELS_RESHAPE_UTIL_H_
|
||||
#define TENSORFLOW_CORE_KERNELS_RESHAPE_UTIL_H_
|
||||
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/framework/tensor_types.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class OpKernelContext;
|
||||
class Tensor;
|
||||
|
||||
// Reshapes the input indices and input shape to the target shape.
|
||||
// Note: This template is explicitly instantiated for CPU device only.
|
||||
template <typename Device>
|
||||
void ReshapeSparseTensor(OpKernelContext *context,
|
||||
const Tensor &input_indices_in,
|
||||
const Tensor &input_shape_in,
|
||||
const Tensor &target_shape_in, int output_indices_idx,
|
||||
int output_shape_idx);
|
||||
|
||||
namespace functor {
|
||||
|
||||
template <typename Device>
|
||||
struct ReshapeSparseTensorFunctor {
|
||||
Status operator()(const TensorShape &input_shape,
|
||||
const TensorShape &output_shape,
|
||||
typename TTypes<int64>::ConstMatrix input_indices,
|
||||
typename TTypes<int64>::Matrix output_indices) const;
|
||||
};
|
||||
|
||||
} // namespace functor
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_KERNELS_RESHAPE_UTIL_H_
|
||||
|
||||
@ -29,17 +29,21 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
using CPUDevice = Eigen::ThreadPoolDevice;
|
||||
|
||||
template <typename Device>
|
||||
class SparseReshapeOp : public OpKernel {
|
||||
public:
|
||||
explicit SparseReshapeOp(OpKernelConstruction* context) : OpKernel(context) {}
|
||||
|
||||
void Compute(OpKernelContext* context) override {
|
||||
ReshapeSparseTensor(context, context->input(0), context->input(1),
|
||||
context->input(2), 0 /* output indices index */,
|
||||
1 /* output shape index */);
|
||||
ReshapeSparseTensor<Device>(context, context->input(0), context->input(1),
|
||||
context->input(2), 0 /* output indices index */,
|
||||
1 /* output shape index */);
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("SparseReshape").Device(DEVICE_CPU),
|
||||
SparseReshapeOp)
|
||||
SparseReshapeOp<CPUDevice>)
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user