Add bincount to support dense/sparse/ragged inputs.
PiperOrigin-RevId: 310221561
Change-Id: I0b52b452adacc577b79c660ee79d173ecf4c4c56
parent 4a3236c75c
commit d74866f5f0

tensorflow/core/api_def/base_api/api_def_DenseBincount.pbtxt (new file, +46 lines)
@@ -0,0 +1,46 @@
op {
  graph_op_name: "DenseBincount"
  in_arg {
    name: "input"
    description: <<END
1D or 2D int `Tensor`.
END
  }
  in_arg {
    name: "size"
    description: <<END
non-negative int scalar `Tensor`.
END
  }
  in_arg {
    name: "weights"
    description: <<END
is an int32, int64, float32, or float64 `Tensor` with the same
shape as `input`, or a length-0 `Tensor`, in which case it acts as all weights
equal to 1.
END
  }
  out_arg {
    name: "output"
    description: <<END
1D `Tensor` with length equal to `size` or 2D `Tensor` with shape
[batch_size, `size`]. The counts or summed weights for each value in the
range [0, size).
END
  }
  attr {
    name: "binary_count"
    description: <<END
bool; if true, the output stores 1 at index `i` whenever the value `i` occurs
at least once, instead of the number of occurrences.
END
  }
  summary: "Counts the number of occurrences of each value in an integer array."
  description: <<END
Outputs a vector with length `size` and the same dtype as `weights`. If
`weights` are empty, then index `i` stores the number of times the value `i` is
counted in `input`. If `weights` are non-empty, then index `i` stores the sum
of the value in `weights` at each index where the corresponding value in
`input` is `i`.

Values in `input` outside of the range [0, size) are ignored.
END
}
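For orientation, a minimal usage sketch of the new op through the generated
Python wrapper (the same entry point the tests in this change use; output
values are illustrative):

    from tensorflow.python.ops import gen_math_ops

    # 1D input: output[i] counts occurrences of i in [0, size).
    gen_math_ops.dense_bincount(input=[1, 1, 2, 4], weights=[], size=5)
    # -> [0., 2., 1., 0., 1.]

    # With weights, output[i] sums the weights wherever input == i.
    gen_math_ops.dense_bincount(
        input=[1, 1, 2, 4], weights=[0.5, 0.25, 1.0, 2.0], size=5)
    # -> [0.0, 0.75, 1.0, 0.0, 2.0]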
@@ -0,0 +1,52 @@
op {
  graph_op_name: "RaggedBincount"
  in_arg {
    name: "splits"
    description: <<END
1D int64 `Tensor`.
END
  }
  in_arg {
    name: "values"
    description: <<END
1D int `Tensor` (the flat values of the ragged tensor).
END
  }
  in_arg {
    name: "size"
    description: <<END
non-negative int scalar `Tensor`.
END
  }
  in_arg {
    name: "weights"
    description: <<END
is an int32, int64, float32, or float64 `Tensor` with the same
shape as `values`, or a length-0 `Tensor`, in which case it acts as all
weights equal to 1.
END
  }
  out_arg {
    name: "output"
    description: <<END
1D `Tensor` with length equal to `size` or 2D `Tensor` with shape
[batch_size, `size`]. The counts or summed weights for each value in the
range [0, size).
END
  }
  attr {
    name: "binary_count"
    description: <<END
bool; if true, the output stores 1 at index `i` whenever the value `i` occurs
at least once, instead of the number of occurrences.
END
  }
  summary: "Counts the number of occurrences of each value in an integer array."
  description: <<END
Outputs a vector with length `size` and the same dtype as `weights`. If
`weights` are empty, then index `i` stores the number of times the value `i` is
counted in `values`. If `weights` are non-empty, then index `i` stores the sum
of the value in `weights` at each index where the corresponding value in
`values` is `i`.

Entries in `values` outside of the range [0, size) are ignored.
END
}
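A matching sketch for the ragged variant (illustration only; it mirrors the
RaggedBincountOpTest cases below). Each row of the ragged input produces an
independent histogram row:

    from tensorflow.python.ops import gen_math_ops
    from tensorflow.python.ops.ragged import ragged_factory_ops

    x = ragged_factory_ops.constant([[3, 0, 1], [], [5, 0, 4, 4]])
    gen_math_ops.ragged_bincount(
        splits=x.row_splits, values=x.values, weights=[], size=6)
    # -> [[1., 1., 0., 1., 0., 0.],
    #     [0., 0., 0., 0., 0., 0.],
    #     [1., 0., 0., 0., 2., 1.]]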
@@ -0,0 +1,58 @@
op {
  graph_op_name: "SparseBincount"
  in_arg {
    name: "indices"
    description: <<END
2D int64 `Tensor`.
END
  }
  in_arg {
    name: "values"
    description: <<END
1D int `Tensor`.
END
  }
  in_arg {
    name: "dense_shape"
    description: <<END
1D int64 `Tensor`.
END
  }
  in_arg {
    name: "size"
    description: <<END
non-negative int scalar `Tensor`.
END
  }
  in_arg {
    name: "weights"
    description: <<END
is an int32, int64, float32, or float64 `Tensor` with the same
shape as `values`, or a length-0 `Tensor`, in which case it acts as all
weights equal to 1.
END
  }
  out_arg {
    name: "output"
    description: <<END
1D `Tensor` with length equal to `size` or 2D `Tensor` with shape
[batch_size, `size`]. The counts or summed weights for each value in the
range [0, size).
END
  }
  attr {
    name: "binary_count"
    description: <<END
bool; if true, the output stores 1 at index `i` whenever the value `i` occurs
at least once, instead of the number of occurrences.
END
  }
  summary: "Counts the number of occurrences of each value in an integer array."
  description: <<END
Outputs a vector with length `size` and the same dtype as `weights`. If
`weights` are empty, then index `i` stores the number of times the value `i` is
counted in `values`. If `weights` are non-empty, then index `i` stores the sum
of the value in `weights` at each index where the corresponding value in
`values` is `i`.

Entries in `values` outside of the range [0, size) are ignored.
END
}
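And a sketch for the sparse variant (illustration only): column 0 of
`indices` selects the output row, the corresponding entry of `values` selects
the bin.

    from tensorflow.python.ops import gen_math_ops

    gen_math_ops.sparse_bincount(
        indices=[[0, 0], [0, 1], [1, 2]],
        values=[1, 2, 3],
        dense_shape=[2, 4],
        size=4,
        weights=[])
    # -> [[0., 1., 1., 0.],
    #     [0., 0., 0., 1.]]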
@@ -0,0 +1,4 @@
op {
  graph_op_name: "DenseBincount"
  visibility: HIDDEN
}
@@ -0,0 +1,4 @@
op {
  graph_op_name: "RaggedBincount"
  visibility: HIDDEN
}
@@ -0,0 +1,4 @@
op {
  graph_op_name: "SparseBincount"
  visibility: HIDDEN
}
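HIDDEN keeps the raw ops out of the exported tf.* namespace; they remain
callable through the generated wrappers (and, in TF 2.x, tf.raw_ops). A quick
check, illustration only:

    from tensorflow.python.ops import gen_math_ops

    # No tf.dense_bincount symbol is exported, but the wrapper works:
    gen_math_ops.dense_bincount(input=[0, 1, 1], weights=[], size=2)
    # -> [1., 2.]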
@@ -4871,6 +4871,7 @@ tf_kernel_library(
     name = "bincount_op",
     prefix = "bincount_op",
     deps = [
         ":fill_functor",
+        ":gpu_prim_hdrs",
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
@@ -15,12 +15,14 @@ limitations under the License.

 // See docs in ../ops/math_ops.cc.

+#include "tensorflow/core/platform/errors.h"
 #define EIGEN_USE_THREADS

-#include "tensorflow/core/kernels/bincount_op.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/register_types.h"
 #include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/bincount_op.h"
 #include "tensorflow/core/kernels/fill_functor.h"
 #include "tensorflow/core/lib/core/threadpool.h"
 #include "tensorflow/core/platform/types.h"

@@ -33,19 +35,18 @@ typedef Eigen::GpuDevice GPUDevice;

 namespace functor {

-template <typename T>
-struct BincountFunctor<CPUDevice, T> {
+template <typename Tidx, typename T>
+struct BincountFunctor<CPUDevice, Tidx, T, true> {
   static Status Compute(OpKernelContext* context,
-                        const typename TTypes<int32, 1>::ConstTensor& arr,
+                        const typename TTypes<Tidx, 1>::ConstTensor& arr,
                         const typename TTypes<T, 1>::ConstTensor& weights,
-                        typename TTypes<T, 1>::Tensor& output) {
-    int size = output.size();
-
+                        typename TTypes<T, 1>::Tensor& output,
+                        const Tidx num_bins) {
     Tensor all_nonneg_t;
     TF_RETURN_IF_ERROR(context->allocate_temp(
         DT_BOOL, TensorShape({}), &all_nonneg_t, AllocatorAttributes()));
     all_nonneg_t.scalar<bool>().device(context->eigen_cpu_device()) =
-        (arr >= 0).all();
+        (arr >= Tidx(0)).all();
     if (!all_nonneg_t.scalar<bool>()()) {
       return errors::InvalidArgument("Input arr must be non-negative!");
     }
@@ -56,17 +57,62 @@ struct BincountFunctor<CPUDevice, T> {
         context->device()->tensorflow_cpu_worker_threads()->workers;
     const int64 num_threads = thread_pool->NumThreads() + 1;
     Tensor partial_bins_t;
-    TF_RETURN_IF_ERROR(context->allocate_temp(DataTypeToEnum<T>::value,
-                                              TensorShape({num_threads, size}),
-                                              &partial_bins_t));
+    TF_RETURN_IF_ERROR(context->allocate_temp(
+        DT_BOOL, TensorShape({num_threads, num_bins}), &partial_bins_t));
+    auto partial_bins = partial_bins_t.matrix<bool>();
+    partial_bins.setZero();
+    thread_pool->ParallelForWithWorkerId(
+        arr.size(), 8 /* cost */,
+        [&](int64 start_ind, int64 limit_ind, int64 worker_id) {
+          for (int64 i = start_ind; i < limit_ind; i++) {
+            Tidx value = arr(i);
+            if (value < num_bins) {
+              partial_bins(worker_id, value) = true;
+            }
+          }
+        });
+
+    // Sum the partial bins along the 0th axis.
+    Eigen::array<int, 1> reduce_dim({0});
+    output.device(context->eigen_cpu_device()) =
+        partial_bins.any(reduce_dim).cast<T>();
+    return Status::OK();
+  }
+};
+
+template <typename Tidx, typename T>
+struct BincountFunctor<CPUDevice, Tidx, T, false> {
+  static Status Compute(OpKernelContext* context,
+                        const typename TTypes<Tidx, 1>::ConstTensor& arr,
+                        const typename TTypes<T, 1>::ConstTensor& weights,
+                        typename TTypes<T, 1>::Tensor& output,
+                        const Tidx num_bins) {
+    Tensor all_nonneg_t;
+    TF_RETURN_IF_ERROR(context->allocate_temp(
+        DT_BOOL, TensorShape({}), &all_nonneg_t, AllocatorAttributes()));
+    all_nonneg_t.scalar<bool>().device(context->eigen_cpu_device()) =
+        (arr >= Tidx(0)).all();
+    if (!all_nonneg_t.scalar<bool>()()) {
+      return errors::InvalidArgument("Input arr must be non-negative!");
+    }
+
+    // Allocate partial output bin sums for each worker thread. Worker ids in
+    // ParallelForWithWorkerId range from 0 to NumThreads() inclusive.
+    ThreadPool* thread_pool =
+        context->device()->tensorflow_cpu_worker_threads()->workers;
+    const int64 num_threads = thread_pool->NumThreads() + 1;
+    Tensor partial_bins_t;
+    TF_RETURN_IF_ERROR(context->allocate_temp(
+        DataTypeToEnum<T>::value, TensorShape({num_threads, num_bins}),
+        &partial_bins_t));
+    auto partial_bins = partial_bins_t.matrix<T>();
+    partial_bins.setZero();
     thread_pool->ParallelForWithWorkerId(
         arr.size(), 8 /* cost */,
         [&](int64 start_ind, int64 limit_ind, int64 worker_id) {
           for (int64 i = start_ind; i < limit_ind; i++) {
-            int32 value = arr(i);
-            if (value < size) {
+            Tidx value = arr(i);
+            if (value < num_bins) {
               if (weights.size()) {
                 partial_bins(worker_id, value) += weights(i);
               } else {
@@ -78,8 +124,43 @@ struct BincountFunctor<CPUDevice, T> {
         });

     // Sum the partial bins along the 0th axis.
-    Eigen::array<int, 1> reduce_dims({0});
-    output.device(context->eigen_cpu_device()) = partial_bins.sum(reduce_dims);
+    Eigen::array<int, 1> reduce_dim({0});
+    output.device(context->eigen_cpu_device()) = partial_bins.sum(reduce_dim);
     return Status::OK();
   }
 };
+
+template <typename Tidx, typename T, bool binary_count>
+struct BincountReduceFunctor<CPUDevice, Tidx, T, binary_count> {
+  static Status Compute(OpKernelContext* context,
+                        const typename TTypes<Tidx, 2>::ConstTensor& in,
+                        const typename TTypes<T, 2>::ConstTensor& weights,
+                        typename TTypes<T, 2>::Tensor& out,
+                        const Tidx num_bins) {
+    const int num_rows = out.dimension(0);
+    const int num_cols = in.dimension(1);
+    ThreadPool* thread_pool =
+        context->device()->tensorflow_cpu_worker_threads()->workers;
+    thread_pool->ParallelForWithWorkerId(
+        num_rows, 8 /* cost */,
+        [&](int64 start_row, int64 end_row, int64 worker_id) {
+          for (int64 i = start_row; i < end_row; ++i) {
+            for (int64 j = 0; j < num_cols; ++j) {
+              Tidx value = in(i, j);
+              if (value < num_bins) {
+                if (binary_count) {
+                  out(i, value) = T(1);
+                } else {
+                  if (weights.size()) {
+                    out(i, value) += weights(i, j);
+                  } else {
+                    out(i, value) += T(1);
+                  }
+                }
+              }
+            }
+          }
+        });
+    return Status::OK();
+  }
+};
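A NumPy reference sketch of the threading strategy above (illustration only):
each worker fills its own row of partial bins, and the rows are then reduced
across axis 0 -- any() for the binary specialization, sum() for the weighted
one.

    import numpy as np

    arr, num_bins, num_threads = np.array([1, 3, 3, 0]), 5, 2
    partial = np.zeros((num_threads, num_bins))
    for worker, idxs in enumerate(
        np.array_split(np.arange(len(arr)), num_threads)):
      for i in idxs:
        if arr[i] < num_bins:
          partial[worker, arr[i]] += 1  # weighted path: += weights[i]
    counts = partial.sum(axis=0)                     # [1., 1., 0., 2., 0.]
    binary = partial.any(axis=0).astype(np.float32)  # [1., 1., 0., 1., 0.]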
@@ -107,8 +188,9 @@ class BincountOp : public OpKernel {
     OP_REQUIRES_OK(ctx,
                    ctx->allocate_output(0, TensorShape({size}), &output_t));
     auto output = output_t->flat<T>();
-    OP_REQUIRES_OK(ctx, functor::BincountFunctor<Device, T>::Compute(
-                            ctx, arr, weights, output));
+    OP_REQUIRES_OK(ctx,
+                   functor::BincountFunctor<Device, int32, T, false>::Compute(
+                       ctx, arr, weights, output, size));
   }
 };

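The pre-existing `Bincount` op keeps its public behavior; it now simply routes
through the new functor with `Tidx = int32` and `binary_count = false`. Quick
check (illustration only, eager mode):

    import tensorflow as tf

    tf.math.bincount([1, 1, 2])  # -> [0, 2, 1], unchanged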
@@ -135,4 +217,244 @@ TF_CALL_float(REGISTER_KERNELS);

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

template <typename Device, typename Tidx, typename T>
class DenseBincountOp : public OpKernel {
 public:
  explicit DenseBincountOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("binary_count", &binary_count_));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor& data = ctx->input(0);
    const Tensor& size_t = ctx->input(1);
    const Tensor& weights = ctx->input(2);

    Tidx size = size_t.scalar<Tidx>()();
    OP_REQUIRES(
        ctx, size >= 0,
        errors::InvalidArgument("size (", size, ") must be non-negative"));

    Tensor* out_t;
    functor::SetZeroFunctor<Device, T> fill;
    if (data.dims() == 1) {
      OP_REQUIRES_OK(ctx,
                     ctx->allocate_output(0, TensorShape({size}), &out_t));
      auto out = out_t->flat<T>();
      fill(ctx->eigen_device<Device>(), out);
      if (binary_count_) {
        OP_REQUIRES_OK(
            ctx, functor::BincountFunctor<Device, Tidx, T, true>::Compute(
                     ctx, data.flat<Tidx>(), weights.flat<T>(), out, size));
      } else {
        OP_REQUIRES_OK(
            ctx, functor::BincountFunctor<Device, Tidx, T, false>::Compute(
                     ctx, data.flat<Tidx>(), weights.flat<T>(), out, size));
      }
    } else if (data.dims() == 2) {
      const int64 num_rows = data.dim_size(0);
      auto weight_matrix =
          (weights.NumElements() == 0)
              ? weights.shaped<T, 2>(gtl::InlinedVector<int64, 2>(2, 0))
              : weights.matrix<T>();
      OP_REQUIRES_OK(
          ctx, ctx->allocate_output(0, TensorShape({num_rows, size}), &out_t));
      auto out = out_t->matrix<T>();
      fill(ctx->eigen_device<Device>(), out_t->flat<T>());
      if (binary_count_) {
        OP_REQUIRES_OK(
            ctx, functor::BincountReduceFunctor<Device, Tidx, T, true>::Compute(
                     ctx, data.matrix<Tidx>(), weight_matrix, out, size));
      } else {
        OP_REQUIRES_OK(
            ctx,
            functor::BincountReduceFunctor<Device, Tidx, T, false>::Compute(
                ctx, data.matrix<Tidx>(), weight_matrix, out, size));
      }
    }
  }

 private:
  bool binary_count_;
};

#define REGISTER_KERNELS(Tidx, T)                            \
  REGISTER_KERNEL_BUILDER(Name("DenseBincount")              \
                              .Device(DEVICE_CPU)            \
                              .TypeConstraint<T>("T")        \
                              .TypeConstraint<Tidx>("Tidx"), \
                          DenseBincountOp<CPUDevice, Tidx, T>);
#define REGISTER_CPU_KERNELS(T) \
  REGISTER_KERNELS(int32, T);   \
  REGISTER_KERNELS(int64, T);

TF_CALL_NUMBER_TYPES(REGISTER_CPU_KERNELS);
#undef REGISTER_CPU_KERNELS
#undef REGISTER_KERNELS

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define REGISTER_KERNELS(Tidx, T)                            \
  REGISTER_KERNEL_BUILDER(Name("DenseBincount")              \
                              .Device(DEVICE_GPU)            \
                              .HostMemory("size")            \
                              .TypeConstraint<T>("T")        \
                              .TypeConstraint<Tidx>("Tidx"), \
                          DenseBincountOp<GPUDevice, Tidx, T>);
#define REGISTER_GPU_KERNELS(T) \
  REGISTER_KERNELS(int32, T);   \
  REGISTER_KERNELS(int64, T);

TF_CALL_int32(REGISTER_GPU_KERNELS);
TF_CALL_float(REGISTER_GPU_KERNELS);
#undef REGISTER_GPU_KERNELS
#undef REGISTER_KERNELS

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

template <typename Device, typename Tidx, typename T>
class SparseBincountOp : public OpKernel {
 public:
  explicit SparseBincountOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("binary_count", &binary_count_));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor& indices = ctx->input(0);
    const auto values = ctx->input(1).flat<Tidx>();
    const Tensor& dense_shape = ctx->input(2);
    const Tensor& size_t = ctx->input(3);
    const auto weights = ctx->input(4).flat<T>();
    const int64 weights_size = weights.size();

    Tidx size = size_t.scalar<Tidx>()();
    OP_REQUIRES(
        ctx, size >= 0,
        errors::InvalidArgument("size (", size, ") must be non-negative"));

    bool is_1d = dense_shape.NumElements() == 1;

    Tensor* out_t;
    functor::SetZeroFunctor<Device, T> fill;
    if (is_1d) {
      OP_REQUIRES_OK(ctx,
                     ctx->allocate_output(0, TensorShape({size}), &out_t));
      auto out = out_t->flat<T>();
      fill(ctx->eigen_device<Device>(), out);
      if (binary_count_) {
        OP_REQUIRES_OK(ctx,
                       functor::BincountFunctor<Device, Tidx, T, true>::Compute(
                           ctx, values, weights, out, size));
      } else {
        OP_REQUIRES_OK(
            ctx, functor::BincountFunctor<Device, Tidx, T, false>::Compute(
                     ctx, values, weights, out, size));
      }
    } else {
      const auto shape = dense_shape.flat<int64>();
      const int64 num_rows = shape(0);
      OP_REQUIRES_OK(
          ctx, ctx->allocate_output(0, TensorShape({num_rows, size}), &out_t));
      const auto out = out_t->matrix<T>();
      fill(ctx->eigen_device<Device>(), out_t->flat<T>());
      const auto indices_mat = indices.matrix<int64>();
      for (int64 i = 0; i < indices_mat.dimension(0); ++i) {
        const int64 batch = indices_mat(i, 0);
        const Tidx bin = values(i);
        if (bin < size) {
          if (binary_count_) {
            out(batch, bin) = T(1);
          } else {
            if (weights_size) {
              out(batch, bin) += weights(i);
            } else {
              out(batch, bin) += T(1);
            }
          }
        }
      }
    }
  }

 private:
  bool binary_count_;
};

#define REGISTER_KERNELS(Tidx, T)                            \
  REGISTER_KERNEL_BUILDER(Name("SparseBincount")             \
                              .Device(DEVICE_CPU)            \
                              .TypeConstraint<T>("T")        \
                              .TypeConstraint<Tidx>("Tidx"), \
                          SparseBincountOp<CPUDevice, Tidx, T>);
#define REGISTER_CPU_KERNELS(T) \
  REGISTER_KERNELS(int32, T);   \
  REGISTER_KERNELS(int64, T);

TF_CALL_NUMBER_TYPES(REGISTER_CPU_KERNELS);
#undef REGISTER_CPU_KERNELS
#undef REGISTER_KERNELS

template <typename Device, typename Tidx, typename T>
class RaggedBincountOp : public OpKernel {
 public:
  explicit RaggedBincountOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("binary_count", &binary_count_));
  }

  void Compute(OpKernelContext* ctx) override {
    const auto splits = ctx->input(0).flat<int64>();
    const auto values = ctx->input(1).flat<Tidx>();
    const Tensor& size_t = ctx->input(2);
    const auto weights = ctx->input(3).flat<T>();
    const int64 weights_size = weights.size();

    Tidx size = size_t.scalar<Tidx>()();
    OP_REQUIRES(
        ctx, size >= 0,
        errors::InvalidArgument("size (", size, ") must be non-negative"));

    int num_rows = splits.size() - 1;
    int num_values = values.size();
    int batch_idx = 0;

    Tensor* out_t;
    OP_REQUIRES_OK(
        ctx, ctx->allocate_output(0, TensorShape({num_rows, size}), &out_t));
    functor::SetZeroFunctor<Device, T> fill;
    fill(ctx->eigen_device<Device>(), out_t->flat<T>());
    const auto out = out_t->matrix<T>();

    for (int idx = 0; idx < num_values; ++idx) {
      while (idx >= splits(batch_idx)) {
        batch_idx++;
      }
      Tidx bin = values(idx);
      OP_REQUIRES(ctx, bin >= 0,
                  errors::InvalidArgument("Input must be non-negative"));
      if (bin < size) {
        if (binary_count_) {
          out(batch_idx - 1, bin) = T(1);
        } else {
          T value = (weights_size > 0) ? weights(idx) : T(1);
          out(batch_idx - 1, bin) += value;
        }
      }
    }
  }

 private:
  bool binary_count_;
};

#define REGISTER_KERNELS(Tidx, T)                            \
  REGISTER_KERNEL_BUILDER(Name("RaggedBincount")             \
                              .Device(DEVICE_CPU)            \
                              .TypeConstraint<T>("T")        \
                              .TypeConstraint<Tidx>("Tidx"), \
                          RaggedBincountOp<CPUDevice, Tidx, T>);
#define REGISTER_CPU_KERNELS(T) \
  REGISTER_KERNELS(int32, T);   \
  REGISTER_KERNELS(int64, T);

TF_CALL_NUMBER_TYPES(REGISTER_CPU_KERNELS);
#undef REGISTER_CPU_KERNELS
#undef REGISTER_KERNELS

}  // end namespace tensorflow
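Note the `value < num_bins` / `bin < size` guards above: all three kernels
silently skip values at or above `size`, while negative values are rejected
(the 1D functor and the ragged loop check for them explicitly). Illustration
only:

    from tensorflow.python.ops import gen_math_ops

    gen_math_ops.dense_bincount(input=[1, 9], weights=[], size=3)
    # -> [0., 1., 0.]   (9 >= size, so it is ignored)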
@@ -26,12 +26,22 @@ namespace tensorflow {

 namespace functor {

-template <typename Device, typename T>
+template <typename Device, typename Tidx, typename T, bool binary_count>
 struct BincountFunctor {
   static Status Compute(OpKernelContext* context,
-                        const typename TTypes<int32, 1>::ConstTensor& arr,
+                        const typename TTypes<Tidx, 1>::ConstTensor& arr,
                         const typename TTypes<T, 1>::ConstTensor& weights,
-                        typename TTypes<T, 1>::Tensor& output);
+                        typename TTypes<T, 1>::Tensor& output,
+                        const Tidx num_bins);
 };

+template <typename Device, typename Tidx, typename T, bool binary_count>
+struct BincountReduceFunctor {
+  static Status Compute(OpKernelContext* context,
+                        const typename TTypes<Tidx, 2>::ConstTensor& in,
+                        const typename TTypes<T, 2>::ConstTensor& weights,
+                        typename TTypes<T, 2>::Tensor& out,
+                        const Tidx num_bins);
+};
+
 }  // end namespace functor
@@ -33,12 +33,13 @@ typedef Eigen::GpuDevice GPUDevice;

 namespace functor {

-template <typename T>
-struct BincountFunctor<GPUDevice, T> {
+template <typename Tidx, typename T>
+struct BincountFunctor<GPUDevice, Tidx, T, false> {
   static Status Compute(OpKernelContext* context,
-                        const typename TTypes<int32, 1>::ConstTensor& arr,
+                        const typename TTypes<Tidx, 1>::ConstTensor& arr,
                         const typename TTypes<T, 1>::ConstTensor& weights,
-                        typename TTypes<T, 1>::Tensor& output) {
+                        typename TTypes<T, 1>::Tensor& output,
+                        const Tidx num_bins) {
     if (weights.size() != 0) {
       return errors::InvalidArgument(
           "Weights should not be passed as it should be "
@@ -49,11 +50,11 @@ struct BincountFunctor<GPUDevice, T> {
     }
     // In case weight.size() == 0, use CUB
     size_t temp_storage_bytes = 0;
-    const int32* d_samples = arr.data();
+    const Tidx* d_samples = arr.data();
     T* d_histogram = output.data();
     int num_levels = output.size() + 1;
-    int32 lower_level = 0;
-    int32 upper_level = output.size();
+    Tidx lower_level = Tidx(0);
+    Tidx upper_level = num_bins;
     int num_samples = arr.size();
     const gpuStream_t& stream = GetGpuStream(context);

@@ -100,10 +101,142 @@ struct BincountFunctor<GPUDevice, T> {
   }
 };

+template <typename Tidx, typename T>
+__global__ void BincountReduceKernel(const Tidx* in, T* out,
+                                     const int nthreads,
+                                     const Tidx num_bins) {
+  GPU_1D_KERNEL_LOOP(index, nthreads) {
+    Tidx bin = ldg(in + index);
+    if (bin < num_bins) {
+      out[bin] = T(1);
+    }
+  }
+}
+
+template <typename Tidx, typename T>
+struct BincountFunctor<GPUDevice, Tidx, T, true> {
+  static Status Compute(OpKernelContext* context,
+                        const typename TTypes<Tidx, 1>::ConstTensor& arr,
+                        const typename TTypes<T, 1>::ConstTensor& weights,
+                        typename TTypes<T, 1>::Tensor& output,
+                        const Tidx num_bins) {
+    const int nthreads = arr.dimension(0);
+
+    auto d = context->eigen_gpu_device();
+    GpuLaunchConfig config = GetGpuLaunchConfig(nthreads, d);
+    return GpuLaunchKernel(BincountReduceKernel<Tidx, T>, config.block_count,
+                           config.thread_per_block, 0, d.stream(), arr.data(),
+                           output.data(), nthreads, num_bins);
+  }
+};
+
+template <typename Tidx, typename T, bool binary_count>
+__global__ void BincountColReduceKernel(const Tidx* in, const T* weights,
+                                        const int weights_size, T* out,
+                                        const int num_rows, const int num_cols,
+                                        const Tidx num_bins) {
+  const int nthreads = num_rows * num_cols;
+  GPU_1D_KERNEL_LOOP(index, nthreads) {
+    Tidx bin = ldg(in + index);
+    if (bin < num_bins) {
+      int row = index / num_cols;
+      int offset = row * num_bins + bin;
+      if (binary_count) {
+        out[offset] = T(1);
+      } else {
+        T value = (weights_size == 0) ? T(1) : ldg(weights + index);
+        GpuAtomicAdd(out + offset, value);
+      }
+    }
+  }
+}
+
+template <typename Tidx, typename T, bool binary_count>
+__global__ void BincountColReduceSharedKernel(const Tidx* in, const T* weights,
+                                              const int weights_size, T* out,
+                                              const int num_rows,
+                                              const int num_cols,
+                                              const Tidx num_bins) {
+  const int out_size = num_rows * num_bins;
+  GPU_DYNAMIC_SHARED_MEM_DECL(sizeof(T), unsigned char, shared_col_mem);
+  T* shared_col_bins = reinterpret_cast<T*>(shared_col_mem);
+  for (unsigned int binIdx = threadIdx.x; binIdx < out_size;
+       binIdx += blockDim.x) {
+    shared_col_bins[binIdx] = T(0);
+  }
+  __syncthreads();
+  const int nthreads = num_rows * num_cols;
+  GPU_1D_KERNEL_LOOP(index, nthreads) {
+    Tidx bin = ldg(in + index);
+    if (bin < num_bins) {
+      int row = index / num_cols;
+      int offset = row * num_bins + bin;
+      if (binary_count) {
+        shared_col_bins[offset] = T(1);
+      } else {
+        T value = (weights_size == 0) ? T(1) : ldg(weights + index);
+        GpuAtomicAdd(shared_col_bins + offset, value);
+      }
+    }
+  }
+  __syncthreads();
+  for (unsigned int binIdx = threadIdx.x; binIdx < out_size;
+       binIdx += blockDim.x) {
+    if (binary_count) {
+      if (shared_col_bins[binIdx]) {
+        out[binIdx] = shared_col_bins[binIdx];
+      }
+    } else {
+      GpuAtomicAdd(out + binIdx, shared_col_bins[binIdx]);
+    }
+  }
+}
+
+template <typename Tidx, typename T, bool binary_count>
+struct BincountReduceFunctor<GPUDevice, Tidx, T, binary_count> {
+  static Status Compute(OpKernelContext* context,
+                        const typename TTypes<Tidx, 2>::ConstTensor& in,
+                        const typename TTypes<T, 2>::ConstTensor& weights,
+                        typename TTypes<T, 2>::Tensor& out,
+                        const Tidx num_bins) {
+    const int num_rows = in.dimension(0);
+    const int num_cols = in.dimension(1);
+
+    auto d = context->eigen_gpu_device();
+    GpuLaunchConfig config = GetGpuLaunchConfig(num_rows * num_cols, d);
+
+    // Use half of maximum shared memory, approximately 6 * 1024 inputs.
+    int smem_max = d.sharedMemPerBlock() / 2;
+    int smem_usage = out.size() * sizeof(T);
+    if (smem_usage < smem_max) {
+      return GpuLaunchKernel(
+          BincountColReduceSharedKernel<Tidx, T, binary_count>,
+          config.block_count, config.thread_per_block, smem_usage, d.stream(),
+          in.data(), weights.data(), weights.size(), out.data(), num_rows,
+          num_cols, num_bins);
+    } else {
+      return GpuLaunchKernel(
+          BincountColReduceKernel<Tidx, T, binary_count>, config.block_count,
+          config.thread_per_block, 0, d.stream(), in.data(), weights.data(),
+          weights.size(), out.data(), num_rows, num_cols, num_bins);
+    }
+  }
+};
+
 }  // end namespace functor

-#define REGISTER_GPU_SPEC(type) \
-  template struct functor::BincountFunctor<GPUDevice, type>;
+#define REGISTER_GPU_SPEC(T)                                                  \
+  template struct functor::BincountFunctor<GPUDevice, int32, T, true>;        \
+  template struct functor::BincountFunctor<GPUDevice, int64, T, true>;        \
+  template struct functor::BincountFunctor<GPUDevice, int32, T, false>;       \
+  template struct functor::BincountFunctor<GPUDevice, int64, T, false>;       \
+  template struct functor::BincountReduceFunctor<GPUDevice, int32, T, true>;  \
+  template struct functor::BincountReduceFunctor<GPUDevice, int64, T, true>;  \
+  template struct functor::BincountReduceFunctor<GPUDevice, int32, T, false>; \
+  template struct functor::BincountReduceFunctor<GPUDevice, int64, T, false>;

 TF_CALL_int32(REGISTER_GPU_SPEC);
 TF_CALL_float(REGISTER_GPU_SPEC);
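Back-of-envelope for the shared-vs-global choice above (illustration; 48 KB of
shared memory per block is an assumed, device-dependent figure that the
functor actually queries via d.sharedMemPerBlock()):

    smem_per_block = 48 * 1024           # bytes; assumption
    smem_max = smem_per_block // 2       # the functor reserves half
    sizeof_t = 4                         # float32
    print(128 * 10 * sizeof_t < smem_max)    # True  -> shared-memory kernel
    print(128 * 1024 * sizeof_t < smem_max)  # False -> global-memory kernel

These are exactly the shapes the shared-memory and global-memory tests below
choose.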
@@ -1651,6 +1651,116 @@ REGISTER_OP("Bincount")
      return Status::OK();
    });

REGISTER_OP("DenseBincount")
    .Input("input: Tidx")
    .Input("size: Tidx")
    .Input("weights: T")
    .Attr("Tidx: {int32, int64}")
    .Attr("T: {int32, int64, float32, float64}")
    .Attr("binary_count: bool = false")
    .Output("output: T")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle unused;
      // The input `input` must have rank at most 2.
      TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 2, &unused));
      // The input `size` must be a scalar.
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));

      const Tensor* size_tensor = c->input_tensor(1);
      if (size_tensor == nullptr) {
        // Return unknown shape if size is not known.
        c->set_output(0, c->UnknownShape());
        return Status::OK();
      }

      int64 size_val;
      DataType dtype;
      TF_RETURN_IF_ERROR(c->GetAttr("Tidx", &dtype));
      if (dtype == DT_INT32) {
        size_val = static_cast<int64>(size_tensor->scalar<int32>()());
      } else if (dtype == DT_INT64) {
        size_val = size_tensor->scalar<int64>()();
      } else {
        return errors::InvalidArgument("size dtype must be int32 or int64");
      }
      // Return `[size]` shape if size is known.
      if (size_val < 0) {
        return errors::InvalidArgument("size (", size_val,
                                       ") must be non-negative");
      }
      if (c->Rank(c->input(0)) == 1) {
        c->set_output(0, c->MakeShape({size_val}));
      } else if (c->Rank(c->input(0)) == 2) {
        c->set_output(0, c->MakeShape({c->Dim(c->input(0), 0), size_val}));
      }
      return Status::OK();
    });

REGISTER_OP("SparseBincount")
    .Input("indices: int64")
    .Input("values: Tidx")
    .Input("dense_shape: int64")
    .Input("size: Tidx")
    .Input("weights: T")
    .Attr("Tidx: {int32, int64}")
    .Attr("T: {int32, int64, float32, float64}")
    .Attr("binary_count: bool = false")
    .Output("output: T")
    .SetShapeFn([](InferenceContext* c) {
      const Tensor* size_tensor = c->input_tensor(3);
      if (size_tensor == nullptr) {
        // Return unknown shape if size is not known.
        c->set_output(0, c->UnknownShape());
        return Status::OK();
      }

      int64 size_val;
      DataType dtype;
      TF_RETURN_IF_ERROR(c->GetAttr("Tidx", &dtype));
      if (dtype == DT_INT32) {
        size_val = static_cast<int64>(size_tensor->scalar<int32>()());
      } else if (dtype == DT_INT64) {
        size_val = size_tensor->scalar<int64>()();
      } else {
        return errors::InvalidArgument("size dtype must be int32 or int64");
      }
      // Return `[size]` shape if size is known.
      if (size_val < 0) {
        return errors::InvalidArgument("size (", size_val,
                                       ") must be non-negative");
      }

      const Tensor* shape_tensor = c->input_tensor(2);
      if (shape_tensor == nullptr) {
        // Return unknown shape if dense_shape is not known.
        c->set_output(0, c->UnknownShape());
        return Status::OK();
      }
      if (shape_tensor->NumElements() == 1) {
        c->set_output(0, c->MakeShape({size_val}));
      } else if (shape_tensor->NumElements() == 2) {
        c->set_output(0,
                      c->MakeShape({shape_tensor->flat<int64>()(0), size_val}));
      } else {
        return errors::InvalidArgument("Input must have rank at most 2");
      }
      return Status::OK();
    });

REGISTER_OP("RaggedBincount")
    .Input("splits: int64")
    .Input("values: Tidx")
    .Input("size: Tidx")
    .Input("weights: T")
    .Attr("Tidx: {int32, int64}")
    .Attr("T: {int32, int64, float32, float64}")
    .Attr("binary_count: bool = false")
    .Output("output: T")
    .SetShapeFn([](InferenceContext* c) {
      c->set_output(0, c->UnknownShape());
      return Status::OK();
    });

REGISTER_OP("Cumsum")
    .Input("x: T")
    .Input("axis: Tidx")
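The shape functions above produce a static output shape only when `size` is a
graph-time constant. A sketch (TF1-style graph mode, illustration only):

    import tensorflow.compat.v1 as tf
    from tensorflow.python.ops import gen_math_ops

    with tf.Graph().as_default():
      v1 = gen_math_ops.dense_bincount(input=[1, 2, 3], weights=[], size=5)
      v2 = gen_math_ops.dense_bincount(
          input=[[1, 2], [3, 4]], weights=[], size=5)
      print(v1.shape)  # (5,)
      print(v2.shape)  # (2, 5)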
@@ -17,6 +17,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+from absl.testing import parameterized
 import numpy as np

 from tensorflow.python.framework import dtypes
@@ -26,6 +27,9 @@ from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import gen_math_ops
 from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import sparse_ops
+from tensorflow.python.ops.ragged import ragged_factory_ops
+from tensorflow.python.ops.ragged import ragged_tensor
 from tensorflow.python.platform import googletest

@@ -128,5 +132,505 @@ class BincountTest(test_util.TensorFlowTestCase):
    self.assertAllEqual(v2.get_shape().as_list(), [None])


class BincountOpTest(test_util.TensorFlowTestCase, parameterized.TestCase):

  @parameterized.parameters([{
      "dtype": np.int32,
  }, {
      "dtype": np.int64,
  }])
  def test_bincount_all_count(self, dtype):
    np.random.seed(42)
    size = 1000
    inp = np.random.randint(0, size, (4096,), dtype=dtype)
    np_out = np.bincount(inp, minlength=size)
    with test_util.use_gpu():
      self.assertAllEqual(
          np_out,
          self.evaluate(
              gen_math_ops.dense_bincount(input=inp, weights=[], size=size)))

  @parameterized.parameters([{
      "dtype": np.int32,
  }, {
      "dtype": np.int64,
  }])
  def test_bincount_all_count_with_weights(self, dtype):
    np.random.seed(42)
    size = 1000
    inp = np.random.randint(0, size, (4096,), dtype=dtype)
    np_weight = np.random.random((4096,))
    np_out = np.bincount(inp, minlength=size, weights=np_weight)
    with test_util.use_gpu():
      self.assertAllEqual(
          np_out,
          self.evaluate(
              gen_math_ops.dense_bincount(
                  input=inp, weights=np_weight, size=size)))

  @parameterized.parameters([{
      "dtype": np.int32,
  }, {
      "dtype": np.int64,
  }])
  def test_bincount_all_binary(self, dtype):
    np.random.seed(42)
    size = 10
    inp = np.random.randint(0, size, (4096,), dtype=dtype)
    np_out = np.ones((size,))
    with test_util.use_gpu():
      self.assertAllEqual(
          np_out,
          self.evaluate(
              gen_math_ops.dense_bincount(
                  input=inp, weights=[], size=size, binary_count=True)))

  @parameterized.parameters([{
      "dtype": np.int32,
  }, {
      "dtype": np.int64,
  }])
  def test_bincount_all_binary_with_weights(self, dtype):
    np.random.seed(42)
    size = 10
    inp = np.random.randint(0, size, (4096,), dtype=dtype)
    np_weight = np.random.random((4096,))
    np_out = np.ones((size,))
    with test_util.use_gpu():
      self.assertAllEqual(
          np_out,
          self.evaluate(
              gen_math_ops.dense_bincount(
                  input=inp, weights=np_weight, size=size, binary_count=True)))

  def _test_bincount_col_count(self, num_rows, num_cols, size, dtype):
    np.random.seed(42)
    inp = np.random.randint(0, size, (num_rows, num_cols), dtype=dtype)
    np_out = np.reshape(
        np.concatenate(
            [np.bincount(inp[j, :], minlength=size) for j in range(num_rows)],
            axis=0), (num_rows, size))
    with test_util.use_gpu():
      self.assertAllEqual(
          np_out,
          self.evaluate(
              gen_math_ops.dense_bincount(input=inp, weights=[], size=size)))

  def _test_bincount_col_binary(self, num_rows, num_cols, size, dtype):
    np.random.seed(42)
    inp = np.random.randint(0, size, (num_rows, num_cols), dtype=dtype)
    np_out = np.reshape(
        np.concatenate([
            np.where(np.bincount(inp[j, :], minlength=size) > 0, 1, 0)
            for j in range(num_rows)
        ],
                       axis=0), (num_rows, size))
    with test_util.use_gpu():
      self.assertAllEqual(
          np_out,
          self.evaluate(
              gen_math_ops.dense_bincount(
                  input=inp, weights=[], size=size, binary_count=True)))

  def _test_bincount_col_count_with_weights(self, num_rows, num_cols, size,
                                            dtype):
    np.random.seed(42)
    inp = np.random.randint(0, size, (num_rows, num_cols), dtype=dtype)
    np_weight = np.random.random((num_rows, num_cols))
    np_out = np.reshape(
        np.concatenate([
            np.bincount(inp[j, :], weights=np_weight[j, :], minlength=size)
            for j in range(num_rows)
        ],
                       axis=0), (num_rows, size))
    with test_util.use_gpu():
      self.assertAllEqual(
          np_out,
          self.evaluate(
              gen_math_ops.dense_bincount(
                  input=inp, weights=np_weight, size=size)))

  def test_col_reduce_basic(self):
    with test_util.use_gpu():
      v = self.evaluate(
          gen_math_ops.dense_bincount(
              input=[[1, 2, 3], [0, 3, 2]], weights=[], size=4))
      expected_out = [[0., 1., 1., 1.], [1., 0., 1., 1.]]
      self.assertAllEqual(expected_out, v)

  @parameterized.parameters([{
      "dtype": np.int32,
  }, {
      "dtype": np.int64,
  }])
  def test_col_reduce_shared_memory(self, dtype):
    # num_rows * num_bins less than half of max shared memory.
    num_rows = 128
    num_cols = 27
    size = 10
    self._test_bincount_col_count(num_rows, num_cols, size, dtype)

  @parameterized.parameters([{
      "dtype": np.int32,
  }, {
      "dtype": np.int64,
  }])
  def test_col_reduce_global_memory(self, dtype):
    # num_rows * num_bins more than half of max shared memory.
    num_rows = 128
    num_cols = 27
    size = 1024
    self._test_bincount_col_count(num_rows, num_cols, size, dtype)

  @parameterized.parameters([{
      "dtype": np.int32,
  }, {
      "dtype": np.int64,
  }])
  def test_col_reduce_shared_memory_with_weights(self, dtype):
    # num_rows * num_bins less than half of max shared memory.
    num_rows = 128
    num_cols = 27
    size = 100
    self._test_bincount_col_count_with_weights(num_rows, num_cols, size, dtype)

  @parameterized.parameters([{
      "dtype": np.int32,
  }, {
      "dtype": np.int64,
  }])
  def test_col_reduce_global_memory_with_weights(self, dtype):
    # num_rows * num_bins more than half of max shared memory.
    num_rows = 128
    num_cols = 27
    size = 1024
    self._test_bincount_col_count_with_weights(num_rows, num_cols, size, dtype)

  @parameterized.parameters([{
      "dtype": np.int32,
  }, {
      "dtype": np.int64,
  }])
  def test_col_reduce_binary(self, dtype):
    num_rows = 128
    num_cols = 7
    size = 10
    self._test_bincount_col_binary(num_rows, num_cols, size, dtype)

  @test_util.run_deprecated_v1
  def test_invalid_rank(self):
    with self.assertRaisesRegexp(ValueError, "at most rank 2"):
      with test_util.use_gpu():
        self.evaluate(
            gen_math_ops.dense_bincount(
                input=[[[1, 2, 3], [0, 3, 2]]], weights=[], size=10))

class SparseBincountOpTest(test_util.TensorFlowTestCase,
                           parameterized.TestCase):

  @parameterized.parameters([{
      "dtype": np.int32,
  }, {
      "dtype": np.int64,
  }])
  def test_sparse_bincount_all_count(self, dtype):
    np.random.seed(42)
    num_rows = 128
    size = 1000
    n_elems = 4096
    inp_indices = np.random.randint(0, num_rows, (n_elems,))
    inp_vals = np.random.randint(0, size, (n_elems,), dtype=dtype)

    np_out = np.bincount(inp_vals, minlength=size)
    self.assertAllEqual(
        np_out,
        self.evaluate(
            gen_math_ops.sparse_bincount(
                indices=inp_indices,
                values=inp_vals,
                dense_shape=[num_rows],
                size=size,
                weights=[])))

  @parameterized.parameters([{
      "dtype": np.int32,
  }, {
      "dtype": np.int64,
  }])
  def test_sparse_bincount_all_count_with_weights(self, dtype):
    np.random.seed(42)
    num_rows = 128
    size = 1000
    n_elems = 4096
    inp_indices = np.random.randint(0, num_rows, (n_elems,))
    inp_vals = np.random.randint(0, size, (n_elems,), dtype=dtype)
    inp_weight = np.random.random((n_elems,))

    np_out = np.bincount(inp_vals, minlength=size, weights=inp_weight)
    self.assertAllEqual(
        np_out,
        self.evaluate(
            gen_math_ops.sparse_bincount(
                indices=inp_indices,
                values=inp_vals,
                dense_shape=[num_rows],
                size=size,
                weights=inp_weight)))

  @parameterized.parameters([{
      "dtype": np.int32,
  }, {
      "dtype": np.int64,
  }])
  def test_sparse_bincount_all_binary(self, dtype):
    np.random.seed(42)
    num_rows = 128
    size = 10
    n_elems = 4096
    inp_indices = np.random.randint(0, num_rows, (n_elems,))
    inp_vals = np.random.randint(0, size, (n_elems,), dtype=dtype)

    np_out = np.ones((size,))
    self.assertAllEqual(
        np_out,
        self.evaluate(
            gen_math_ops.sparse_bincount(
                indices=inp_indices,
                values=inp_vals,
                dense_shape=[num_rows],
                size=size,
                weights=[],
                binary_count=True)))

  @parameterized.parameters([{
      "dtype": np.int32,
  }, {
      "dtype": np.int64,
  }])
  def test_sparse_bincount_all_binary_weights(self, dtype):
    np.random.seed(42)
    num_rows = 128
    size = 10
    n_elems = 4096
    inp_indices = np.random.randint(0, num_rows, (n_elems,))
    inp_vals = np.random.randint(0, size, (n_elems,), dtype=dtype)
    inp_weight = np.random.random((n_elems,))

    np_out = np.ones((size,))
    self.assertAllEqual(
        np_out,
        self.evaluate(
            gen_math_ops.sparse_bincount(
                indices=inp_indices,
                values=inp_vals,
                dense_shape=[num_rows],
                size=size,
                weights=inp_weight,
                binary_count=True)))

  @parameterized.parameters([{
      "dtype": np.int32,
  }, {
      "dtype": np.int64,
  }])
  def test_sparse_bincount_col_reduce_count(self, dtype):
    num_rows = 128
    num_cols = 27
    size = 100
    np.random.seed(42)
    inp = np.random.randint(0, size, (num_rows, num_cols), dtype=dtype)
    np_out = np.reshape(
        np.concatenate(
            [np.bincount(inp[j, :], minlength=size) for j in range(num_rows)],
            axis=0), (num_rows, size))
    # from_dense will filter out 0s.
    inp = inp + 1
    # from_dense would cause OOM on the GPU.
    with ops.device("/CPU:0"):
      inp_sparse = sparse_ops.from_dense(inp)
      self.assertAllEqual(
          np_out,
          self.evaluate(
              gen_math_ops.sparse_bincount(
                  indices=inp_sparse.indices,
                  values=inp_sparse.values - 1,
                  dense_shape=inp_sparse.dense_shape,
                  size=size,
                  weights=[])))

  @parameterized.parameters([{
      "dtype": np.int32,
  }, {
      "dtype": np.int64,
  }])
  def test_sparse_bincount_col_reduce_binary(self, dtype):
    num_rows = 128
    num_cols = 27
    size = 100
    np.random.seed(42)
    inp = np.random.randint(0, size, (num_rows, num_cols), dtype=dtype)
    np_out = np.reshape(
        np.concatenate([
            np.where(np.bincount(inp[j, :], minlength=size) > 0, 1, 0)
            for j in range(num_rows)
        ],
                       axis=0), (num_rows, size))
    # from_dense will filter out 0s.
    inp = inp + 1
    # from_dense would cause OOM on the GPU.
    with ops.device("/CPU:0"):
      inp_sparse = sparse_ops.from_dense(inp)
      self.assertAllEqual(
          np_out,
          self.evaluate(
              gen_math_ops.sparse_bincount(
                  indices=inp_sparse.indices,
                  values=inp_sparse.values - 1,
                  dense_shape=inp_sparse.dense_shape,
                  size=size,
                  weights=[],
                  binary_count=True)))

class RaggedBincountOpTest(test_util.TensorFlowTestCase,
                           parameterized.TestCase):

  @parameterized.parameters([{
      "dtype": np.int32,
  }, {
      "dtype": np.int64,
  }])
  def test_ragged_bincount_count(self, dtype):
    x = ragged_factory_ops.constant([[], [], [3, 0, 1], [], [5, 0, 4, 4]])
    expected_output = [[0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0],
                       [1, 1, 0, 1, 0, 0], [0, 0, 0, 0, 0, 0],
                       [1, 0, 0, 0, 2, 1]]
    self.assertAllEqual(
        expected_output,
        self.evaluate(
            gen_math_ops.ragged_bincount(
                splits=x.row_splits, values=x.values, weights=[], size=6)))

  @parameterized.parameters([{
      "dtype": np.int32,
  }, {
      "dtype": np.int64,
  }])
  def test_ragged_bincount_binary(self, dtype):
    x = ragged_factory_ops.constant([[], [], [3, 0, 1], [], [5, 0, 4, 4]])
    expected_output = [[0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0],
                       [1, 1, 0, 1, 0, 0], [0, 0, 0, 0, 0, 0],
                       [1, 0, 0, 0, 1, 1]]
    self.assertAllEqual(
        expected_output,
        self.evaluate(
            gen_math_ops.ragged_bincount(
                splits=x.row_splits,
                values=x.values,
                weights=[],
                size=6,
                binary_count=True)))

  @parameterized.parameters([{
      "dtype": np.int32,
  }, {
      "dtype": np.int64,
  }])
  def test_ragged_bincount_count_with_weights(self, dtype):
    x = ragged_factory_ops.constant([[], [], [3, 0, 1], [], [5, 0, 4, 4]])
    weights = ragged_factory_ops.constant([[], [], [.1, .2, .3], [],
                                           [.2, .5, .6, .3]])
    expected_output = [[0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0],
                       [.2, .3, 0, .1, 0, 0], [0, 0, 0, 0, 0, 0],
                       [.5, 0, 0, 0, .9, .2]]
    self.assertAllClose(
        expected_output,
        self.evaluate(
            gen_math_ops.ragged_bincount(
                splits=x.row_splits,
                values=x.values,
                weights=weights.values,
                size=6)))

  @parameterized.parameters([{
      "dtype": np.int32,
  }, {
      "dtype": np.int64,
  }])
  def test_ragged_bincount_count_np(self, dtype):
    np.random.seed(42)
    num_rows = 128
    num_cols = 27
    size = 1000
    inp = np.random.randint(0, size, (num_rows, num_cols), dtype=dtype)
    np_out = np.reshape(
        np.concatenate(
            [np.bincount(inp[j, :], minlength=size) for j in range(num_rows)],
            axis=0), (num_rows, size))
    x = ragged_tensor.RaggedTensor.from_tensor(inp)
    self.assertAllEqual(
        np_out,
        self.evaluate(
            gen_math_ops.ragged_bincount(
                splits=x.row_splits, values=x.values, weights=[], size=size)))

  @parameterized.parameters([{
      "dtype": np.int32,
  }, {
      "dtype": np.int64,
  }])
  def test_ragged_bincount_count_np_with_weights(self, dtype):
    np.random.seed(42)
    num_rows = 128
    num_cols = 27
    size = 1000
    inp = np.random.randint(0, size, (num_rows, num_cols), dtype=dtype)
    np_weight = np.random.random((num_rows, num_cols))
    np_out = np.reshape(
        np.concatenate([
            np.bincount(inp[j, :], weights=np_weight[j, :], minlength=size)
            for j in range(num_rows)
        ],
                       axis=0), (num_rows, size))
    x = ragged_tensor.RaggedTensor.from_tensor(inp)
    self.assertAllEqual(
        np_out,
        self.evaluate(
            gen_math_ops.ragged_bincount(
                splits=x.row_splits,
                values=x.values,
                weights=np_weight,
                size=size)))

  @parameterized.parameters([{
      "dtype": np.int32,
  }, {
      "dtype": np.int64,
  }])
  def test_ragged_bincount_binary_np_with_weights(self, dtype):
    np.random.seed(42)
    num_rows = 128
    num_cols = 27
    size = 1000
    inp = np.random.randint(0, size, (num_rows, num_cols), dtype=dtype)
    np_out = np.reshape(
        np.concatenate([
            np.where(np.bincount(inp[j, :], minlength=size) > 0, 1, 0)
            for j in range(num_rows)
        ],
                       axis=0), (num_rows, size))
    x = ragged_tensor.RaggedTensor.from_tensor(inp)
    self.assertAllEqual(
        np_out,
        self.evaluate(
            gen_math_ops.ragged_bincount(
                splits=x.row_splits,
                values=x.values,
                weights=[],
                size=size,
                binary_count=True)))


if __name__ == "__main__":
  googletest.main()
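Why the +1/-1 shuffle in the sparse tests above (illustration only):
sparse_ops.from_dense drops zero entries, so densifying `inp` directly would
lose every zero-valued element.

    import numpy as np
    from tensorflow.python.ops import sparse_ops

    inp = np.array([[0, 2], [3, 0]])
    sp = sparse_ops.from_dense(inp + 1)  # no zeros, so nothing is dropped
    # sp.values - 1 then recovers the original values, zeros included.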
@@ -1072,6 +1072,10 @@ tf_module {
     name: "DeleteSessionTensor"
     argspec: "args=[\'handle\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
+  member_method {
+    name: "DenseBincount"
+    argspec: "args=[\'input\', \'size\', \'weights\', \'binary_count\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
+  }
   member_method {
     name: "DenseCountSparseOutput"
     argspec: "args=[\'values\', \'weights\', \'binary_count\', \'output_type\', \'minlength\', \'maxlength\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'-1\', \'None\'], "
@@ -3064,6 +3068,10 @@ tf_module {
     name: "RGBToHSV"
     argspec: "args=[\'images\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
+  member_method {
+    name: "RaggedBincount"
+    argspec: "args=[\'splits\', \'values\', \'size\', \'weights\', \'binary_count\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
+  }
   member_method {
     name: "RaggedCountSparseOutput"
     argspec: "args=[\'splits\', \'values\', \'weights\', \'binary_count\', \'output_type\', \'minlength\', \'maxlength\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'-1\', \'None\'], "
@@ -4072,6 +4080,10 @@ tf_module {
     name: "SparseApplyRMSProp"
     argspec: "args=[\'var\', \'ms\', \'mom\', \'lr\', \'rho\', \'momentum\', \'epsilon\', \'grad\', \'indices\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
   }
+  member_method {
+    name: "SparseBincount"
+    argspec: "args=[\'indices\', \'values\', \'dense_shape\', \'size\', \'weights\', \'binary_count\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
+  }
   member_method {
     name: "SparseConcat"
     argspec: "args=[\'indices\', \'values\', \'shapes\', \'concat_dim\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
@@ -1072,6 +1072,10 @@ tf_module {
     name: "DeleteSessionTensor"
     argspec: "args=[\'handle\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
+  member_method {
+    name: "DenseBincount"
+    argspec: "args=[\'input\', \'size\', \'weights\', \'binary_count\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
+  }
   member_method {
     name: "DenseCountSparseOutput"
     argspec: "args=[\'values\', \'weights\', \'binary_count\', \'output_type\', \'minlength\', \'maxlength\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'-1\', \'None\'], "
@@ -3064,6 +3068,10 @@ tf_module {
     name: "RGBToHSV"
     argspec: "args=[\'images\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
+  member_method {
+    name: "RaggedBincount"
+    argspec: "args=[\'splits\', \'values\', \'size\', \'weights\', \'binary_count\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
+  }
   member_method {
     name: "RaggedCountSparseOutput"
     argspec: "args=[\'splits\', \'values\', \'weights\', \'binary_count\', \'output_type\', \'minlength\', \'maxlength\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'-1\', \'None\'], "
@@ -4072,6 +4080,10 @@ tf_module {
     name: "SparseApplyRMSProp"
     argspec: "args=[\'var\', \'ms\', \'mom\', \'lr\', \'rho\', \'momentum\', \'epsilon\', \'grad\', \'indices\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
   }
+  member_method {
+    name: "SparseBincount"
+    argspec: "args=[\'indices\', \'values\', \'dense_shape\', \'size\', \'weights\', \'binary_count\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
+  }
   member_method {
     name: "SparseConcat"
     argspec: "args=[\'indices\', \'values\', \'shapes\', \'concat_dim\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "