Add tf.sparse.bincount op.
PiperOrigin-RevId: 309334052 Change-Id: Ic00b2b66deb467dd901e0e7eaa7152e6d1dc18b0
This commit is contained in:
parent
387429dd3b
commit
b7c40bc3cc
@ -619,6 +619,7 @@ tf_gen_op_libs(
|
|||||||
"clustering_ops",
|
"clustering_ops",
|
||||||
"collective_ops",
|
"collective_ops",
|
||||||
"control_flow_ops",
|
"control_flow_ops",
|
||||||
|
"count_ops",
|
||||||
"ctc_ops",
|
"ctc_ops",
|
||||||
"data_flow_ops",
|
"data_flow_ops",
|
||||||
"dataset_ops",
|
"dataset_ops",
|
||||||
@ -847,6 +848,7 @@ cc_library(
|
|||||||
":clustering_ops_op_lib",
|
":clustering_ops_op_lib",
|
||||||
":collective_ops_op_lib",
|
":collective_ops_op_lib",
|
||||||
":control_flow_ops_op_lib",
|
":control_flow_ops_op_lib",
|
||||||
|
":count_ops_op_lib",
|
||||||
":ctc_ops_op_lib",
|
":ctc_ops_op_lib",
|
||||||
":cudnn_rnn_ops_op_lib",
|
":cudnn_rnn_ops_op_lib",
|
||||||
":data_flow_ops_op_lib",
|
":data_flow_ops_op_lib",
|
||||||
@ -1006,6 +1008,7 @@ cc_library(
|
|||||||
"//tensorflow/core/kernels:collective_ops",
|
"//tensorflow/core/kernels:collective_ops",
|
||||||
"//tensorflow/core/kernels:constant_op",
|
"//tensorflow/core/kernels:constant_op",
|
||||||
"//tensorflow/core/kernels:control_flow_ops",
|
"//tensorflow/core/kernels:control_flow_ops",
|
||||||
|
"//tensorflow/core/kernels:count_ops",
|
||||||
"//tensorflow/core/kernels:ctc_ops",
|
"//tensorflow/core/kernels:ctc_ops",
|
||||||
"//tensorflow/core/kernels:data_flow",
|
"//tensorflow/core/kernels:data_flow",
|
||||||
"//tensorflow/core/kernels:decode_proto_op",
|
"//tensorflow/core/kernels:decode_proto_op",
|
||||||
|
@ -0,0 +1,68 @@
|
|||||||
|
op {
|
||||||
|
graph_op_name: "DenseCountSparseOutput"
|
||||||
|
visibility: HIDDEN
|
||||||
|
in_arg {
|
||||||
|
name: "values"
|
||||||
|
description: <<END
|
||||||
|
int32 or int64; Tensor containing data to count.
|
||||||
|
END
|
||||||
|
}
|
||||||
|
in_arg {
|
||||||
|
name: "weights"
|
||||||
|
description: <<END
|
||||||
|
float32; Optional rank 1 Tensor (shape=[max_values]) with weights for each count value.
|
||||||
|
END
|
||||||
|
}
|
||||||
|
out_arg {
|
||||||
|
name: "output_indices"
|
||||||
|
description: <<END
|
||||||
|
int64; indices tensor for the resulting sparse tensor object.
|
||||||
|
END
|
||||||
|
}
|
||||||
|
out_arg {
|
||||||
|
name: "output_values"
|
||||||
|
description: <<END
|
||||||
|
int64 or float32; values tensor for the resulting sparse tensor object.
|
||||||
|
END
|
||||||
|
}
|
||||||
|
out_arg {
|
||||||
|
name: "output_dense_shape"
|
||||||
|
description: <<END
|
||||||
|
int64; shape tensor for the resulting sparse tensor object.
|
||||||
|
END
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
name: "T"
|
||||||
|
description: <<END
|
||||||
|
dtype; dtype of the input values tensor.
|
||||||
|
END
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
name: "minlength"
|
||||||
|
description: <<END
|
||||||
|
int32; minimum size of the last dimension of the output (only used when maxlength is not positive). Can be set to -1 for no minimum.
|
||||||
|
END
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
name: "maxlength"
|
||||||
|
description: <<END
|
||||||
|
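int32; exclusive upper bound on the values to count (values >= maxlength are ignored); when positive it is also the size of the last dimension of the output. Can be set to -1 for no maximum.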
|
||||||
|
END
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
name: "binary_count"
|
||||||
|
description: <<END
|
||||||
|
bool; whether to output the number of occurrences of each value or 1.
|
||||||
|
END
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
name: "output_type"
|
||||||
|
description: <<END
|
||||||
|
dtype; dtype of the output values tensor.
|
||||||
|
END
|
||||||
|
}
|
||||||
|
summary: "Performs sparse-output bin counting for a tf.tensor input."
|
||||||
|
description: <<END
|
||||||
|
Counts the number of times each value occurs in the input.
|
||||||
|
END
|
||||||
|
}
|
@ -0,0 +1,74 @@
|
|||||||
|
op {
|
||||||
|
graph_op_name: "RaggedCountSparseOutput"
|
||||||
|
visibility: HIDDEN
|
||||||
|
in_arg {
|
||||||
|
name: "splits"
|
||||||
|
description: <<END
|
||||||
|
int64; Tensor containing the row splits of the ragged tensor to count.
|
||||||
|
END
|
||||||
|
}
|
||||||
|
in_arg {
|
||||||
|
name: "values"
|
||||||
|
description: <<END
|
||||||
|
int32 or int64; Tensor containing values of the ragged tensor to count.
|
||||||
|
END
|
||||||
|
}
|
||||||
|
in_arg {
|
||||||
|
name: "weights"
|
||||||
|
description: <<END
|
||||||
|
float32; Optional rank 1 Tensor (shape=[max_values]) with weights for each count value.
|
||||||
|
END
|
||||||
|
}
|
||||||
|
out_arg {
|
||||||
|
name: "output_indices"
|
||||||
|
description: <<END
|
||||||
|
int64; indices tensor for the resulting sparse tensor object.
|
||||||
|
END
|
||||||
|
}
|
||||||
|
out_arg {
|
||||||
|
name: "output_values"
|
||||||
|
description: <<END
|
||||||
|
int64 or float32; values tensor for the resulting sparse tensor object.
|
||||||
|
END
|
||||||
|
}
|
||||||
|
out_arg {
|
||||||
|
name: "output_dense_shape"
|
||||||
|
description: <<END
|
||||||
|
int64; shape tensor for the resulting sparse tensor object.
|
||||||
|
END
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
name: "T"
|
||||||
|
description: <<END
|
||||||
|
dtype; dtype of the input values tensor.
|
||||||
|
END
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
name: "minlength"
|
||||||
|
description: <<END
|
||||||
|
int32; minimum size of the last dimension of the output (only used when maxlength is not positive). Can be set to -1 for no minimum.
|
||||||
|
END
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
name: "maxlength"
|
||||||
|
description: <<END
|
||||||
|
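int32; exclusive upper bound on the values to count (values >= maxlength are ignored); when positive it is also the size of the last dimension of the output. Can be set to -1 for no maximum.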
|
||||||
|
END
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
name: "binary_count"
|
||||||
|
description: <<END
|
||||||
|
bool; whether to output the number of occurrences of each value or 1.
|
||||||
|
END
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
name: "output_type"
|
||||||
|
description: <<END
|
||||||
|
dtype; dtype of the output values tensor.
|
||||||
|
END
|
||||||
|
}
|
||||||
|
summary: "Performs sparse-output bin counting for a ragged tensor input."
|
||||||
|
description: <<END
|
||||||
|
Counts the number of times each value occurs in the input.
|
||||||
|
END
|
||||||
|
}
|
@ -0,0 +1,80 @@
|
|||||||
|
op {
|
||||||
|
graph_op_name: "SparseCountSparseOutput"
|
||||||
|
visibility: HIDDEN
|
||||||
|
in_arg {
|
||||||
|
name: "indices"
|
||||||
|
description: <<END
|
||||||
|
int64; Tensor containing the indices of the sparse tensor to count.
|
||||||
|
END
|
||||||
|
}
|
||||||
|
in_arg {
|
||||||
|
name: "values"
|
||||||
|
description: <<END
|
||||||
|
int32 or int64; Tensor containing values of the sparse tensor to count.
|
||||||
|
END
|
||||||
|
}
|
||||||
|
in_arg {
|
||||||
|
name: "dense_shape"
|
||||||
|
description: <<END
|
||||||
|
int64; Tensor containing the dense shape of the sparse tensor to count.
|
||||||
|
END
|
||||||
|
}
|
||||||
|
in_arg {
|
||||||
|
name: "weights"
|
||||||
|
description: <<END
|
||||||
|
float32; Optional rank 1 Tensor (shape=[max_values]) with weights for each count value.
|
||||||
|
END
|
||||||
|
}
|
||||||
|
out_arg {
|
||||||
|
name: "output_indices"
|
||||||
|
description: <<END
|
||||||
|
int64; indices tensor for the resulting sparse tensor object.
|
||||||
|
END
|
||||||
|
}
|
||||||
|
out_arg {
|
||||||
|
name: "output_values"
|
||||||
|
description: <<END
|
||||||
|
int64 or float32; values tensor for the resulting sparse tensor object.
|
||||||
|
END
|
||||||
|
}
|
||||||
|
out_arg {
|
||||||
|
name: "output_dense_shape"
|
||||||
|
description: <<END
|
||||||
|
int64; shape tensor for the resulting sparse tensor object.
|
||||||
|
END
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
name: "T"
|
||||||
|
description: <<END
|
||||||
|
dtype; dtype of the input values tensor.
|
||||||
|
END
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
name: "minlength"
|
||||||
|
description: <<END
|
||||||
|
int32; minimum size of the last dimension of the output (only used when maxlength is not positive). Can be set to -1 for no minimum.
|
||||||
|
END
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
name: "maxlength"
|
||||||
|
description: <<END
|
||||||
|
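int32; exclusive upper bound on the values to count (values >= maxlength are ignored); when positive it is also the size of the last dimension of the output. Can be set to -1 for no maximum.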
|
||||||
|
END
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
name: "binary_count"
|
||||||
|
description: <<END
|
||||||
|
bool; whether to output the number of occurrences of each value or 1.
|
||||||
|
END
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
name: "output_type"
|
||||||
|
description: <<END
|
||||||
|
dtype; dtype of the output values tensor.
|
||||||
|
END
|
||||||
|
}
|
||||||
|
summary: "Performs sparse-output bin counting for a sparse tensor input."
|
||||||
|
description: <<END
|
||||||
|
Counts the number of times each value occurs in the input.
|
||||||
|
END
|
||||||
|
}
|
@ -5668,6 +5668,32 @@ tf_kernel_library(
|
|||||||
deps = STATE_DEPS,
|
deps = STATE_DEPS,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Kernel library for the count_ops kernels.
tf_kernel_library(
    name = "count_ops",
    prefix = "count_ops",
    deps = STATE_DEPS + [
        "//tensorflow/core/framework:op_requires",
        "@com_google_absl//absl/container:flat_hash_map",
    ],
)
|
||||||
|
|
||||||
|
# C++ test target built from count_ops_test.cc.
tf_cc_test(
    name = "count_ops_test",
    size = "small",
    srcs = ["count_ops_test.cc"],
    deps = [
        # Targets from this package.
        ":count_ops",
        ":ops_testutil",
        ":ops_util",
        # Core framework and test support.
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
        "//tensorflow/core:testlib",
    ],
)
|
||||||
|
|
||||||
tf_kernel_library(
|
tf_kernel_library(
|
||||||
name = "scatter_nd_op",
|
name = "scatter_nd_op",
|
||||||
srcs = [
|
srcs = [
|
||||||
|
358
tensorflow/core/kernels/count_ops.cc
Normal file
358
tensorflow/core/kernels/count_ops.cc
Normal file
@ -0,0 +1,358 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "absl/container/flat_hash_map.h"
|
||||||
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
|
#include "tensorflow/core/framework/op_requires.h"
|
||||||
|
#include "tensorflow/core/framework/tensor.h"
|
||||||
|
#include "tensorflow/core/platform/errors.h"
|
||||||
|
#include "tensorflow/core/platform/types.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
using BatchedIntMap = std::vector<absl::flat_hash_map<int64, int64>>;
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
// TODO(momernick): Extend this function to work with outputs of rank > 2.
|
||||||
|
Status OutputSparse(const BatchedIntMap& per_batch_counts, int num_values,
|
||||||
|
bool is_1d, OpKernelContext* context) {
|
||||||
|
int total_values = 0;
|
||||||
|
int num_batches = per_batch_counts.size();
|
||||||
|
for (const auto& per_batch_count : per_batch_counts) {
|
||||||
|
total_values += per_batch_count.size();
|
||||||
|
}
|
||||||
|
|
||||||
|
Tensor* indices;
|
||||||
|
int inner_dim = is_1d ? 1 : 2;
|
||||||
|
TF_RETURN_IF_ERROR(context->allocate_output(
|
||||||
|
0, TensorShape({total_values, inner_dim}), &indices));
|
||||||
|
|
||||||
|
Tensor* values;
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
context->allocate_output(1, TensorShape({total_values}), &values));
|
||||||
|
|
||||||
|
auto output_indices = indices->matrix<int64>();
|
||||||
|
auto output_values = values->flat<int64>();
|
||||||
|
int64 value_loc = 0;
|
||||||
|
for (int b = 0; b < num_batches; ++b) {
|
||||||
|
const auto& per_batch_count = per_batch_counts[b];
|
||||||
|
std::vector<std::pair<int, int>> pairs(per_batch_count.begin(),
|
||||||
|
per_batch_count.end());
|
||||||
|
std::sort(pairs.begin(), pairs.end());
|
||||||
|
for (const auto& x : pairs) {
|
||||||
|
if (is_1d) {
|
||||||
|
output_indices(value_loc, 0) = x.first;
|
||||||
|
} else {
|
||||||
|
output_indices(value_loc, 0) = b;
|
||||||
|
output_indices(value_loc, 1) = x.first;
|
||||||
|
}
|
||||||
|
output_values(value_loc) = x.second;
|
||||||
|
++value_loc;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Tensor* dense_shape;
|
||||||
|
if (is_1d) {
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
context->allocate_output(2, TensorShape({1}), &dense_shape));
|
||||||
|
dense_shape->flat<int64>().data()[0] = num_values;
|
||||||
|
} else {
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
context->allocate_output(2, TensorShape({2}), &dense_shape));
|
||||||
|
dense_shape->flat<int64>().data()[0] = num_batches;
|
||||||
|
dense_shape->flat<int64>().data()[1] = num_values;
|
||||||
|
}
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Emits a weighted count result as the three components of a sparse tensor:
// output 0 (indices), output 1 (float values), output 2 (dense_shape).
// Each emitted value is count * weights[counted value].
//
// Args:
//   per_batch_counts: one map per batch entry, keyed by the counted value and
//     holding the count for it.
//   num_values: size written for the last dimension of dense_shape; weights
//     must have at least this many elements.
//   weights: rank-1 float tensor indexed by counted value.
//   is_1d: if true, indices has a single column and dense_shape is
//     [num_values]; otherwise each indices row is (batch, value) and
//     dense_shape is [num_batches, num_values].
//   context: kernel context whose outputs are allocated and filled.
//
// Returns InvalidArgument if weights is not a vector or is shorter than
// num_values, the first output-allocation error, or OK.
Status OutputWeightedSparse(const BatchedIntMap& per_batch_counts,
                            int64 num_values, const Tensor& weights,
                            bool is_1d, OpKernelContext* context) {
  if (!TensorShapeUtils::IsVector(weights.shape())) {
    return errors::InvalidArgument(
        "Weights must be a 1-dimensional tensor. Got: ",
        weights.shape().DebugString());
  }

  if (num_values > weights.dim_size(0)) {
    return errors::InvalidArgument("The maximum array value was ", num_values,
                                   ", but the weight array has size ",
                                   weights.shape().DebugString());
  }
  auto weight_values = weights.flat<float>();

  // All sizes are int64: BatchedIntMap stores int64 keys and counts, and the
  // total number of entries is not bounded by the range of int.
  int64 total_values = 0;
  const int64 num_batches = per_batch_counts.size();
  for (const auto& per_batch_count : per_batch_counts) {
    total_values += per_batch_count.size();
  }

  Tensor* indices;
  const int64 inner_dim = is_1d ? 1 : 2;
  TF_RETURN_IF_ERROR(context->allocate_output(
      0, TensorShape({total_values, inner_dim}), &indices));

  Tensor* values;
  TF_RETURN_IF_ERROR(
      context->allocate_output(1, TensorShape({total_values}), &values));

  auto output_indices = indices->matrix<int64>();
  auto output_values = values->flat<float>();
  int64 value_loc = 0;
  for (int64 b = 0; b < num_batches; ++b) {
    const auto& per_batch_count = per_batch_counts[b];
    // Copy as int64 pairs so that the key used to index the weights is not
    // truncated, then sort by key for ascending output order per batch.
    std::vector<std::pair<int64, int64>> pairs(per_batch_count.begin(),
                                               per_batch_count.end());
    std::sort(pairs.begin(), pairs.end());
    for (const auto& x : pairs) {
      if (is_1d) {
        output_indices(value_loc, 0) = x.first;
      } else {
        output_indices(value_loc, 0) = b;
        output_indices(value_loc, 1) = x.first;
      }
      output_values(value_loc) = x.second * weight_values(x.first);
      ++value_loc;
    }
  }

  Tensor* dense_shape;
  if (is_1d) {
    TF_RETURN_IF_ERROR(
        context->allocate_output(2, TensorShape({1}), &dense_shape));
    dense_shape->flat<int64>()(0) = num_values;
  } else {
    TF_RETURN_IF_ERROR(
        context->allocate_output(2, TensorShape({2}), &dense_shape));
    dense_shape->flat<int64>()(0) = num_batches;
    dense_shape->flat<int64>()(1) = num_values;
  }
  return Status::OK();
}
|
||||||
|
|
||||||
|
// Picks the size of the last output dimension.
// A positive max_length wins outright; otherwise the size is the larger of
// (max_seen + 1) and min_length.
template <class T>
T GetOutputSize(T max_seen, T max_length, T min_length) {
  if (max_length > 0) {
    return max_length;
  }
  const T needed = max_seen + 1;
  return needed < min_length ? min_length : needed;
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
template <class T>
|
||||||
|
class DenseCount : public OpKernel {
|
||||||
|
public:
|
||||||
|
explicit DenseCount(OpKernelConstruction* context) : OpKernel(context) {
|
||||||
|
OP_REQUIRES_OK(context, context->GetAttr("minlength", &minlength_));
|
||||||
|
OP_REQUIRES_OK(context, context->GetAttr("maxlength", &maxlength_));
|
||||||
|
OP_REQUIRES_OK(context, context->GetAttr("binary_count", &binary_count_));
|
||||||
|
}
|
||||||
|
|
||||||
|
void Compute(OpKernelContext* context) override {
|
||||||
|
const Tensor& data = context->input(0);
|
||||||
|
const Tensor& weights = context->input(1);
|
||||||
|
bool use_weights = weights.NumElements() > 0;
|
||||||
|
|
||||||
|
OP_REQUIRES(context,
|
||||||
|
TensorShapeUtils::IsVector(data.shape()) ||
|
||||||
|
TensorShapeUtils::IsMatrix(data.shape()),
|
||||||
|
errors::InvalidArgument(
|
||||||
|
"Input must be a 1 or 2-dimensional tensor. Got: ",
|
||||||
|
data.shape().DebugString()));
|
||||||
|
|
||||||
|
bool is_1d = TensorShapeUtils::IsVector(data.shape());
|
||||||
|
int negative_valued_axis = -1;
|
||||||
|
int num_batch_dimensions = (data.shape().dims() + negative_valued_axis);
|
||||||
|
|
||||||
|
int num_batch_elements = 1;
|
||||||
|
for (int i = 0; i < num_batch_dimensions; ++i) {
|
||||||
|
num_batch_elements *= data.shape().dim_size(i);
|
||||||
|
}
|
||||||
|
int num_value_elements = data.shape().num_elements() / num_batch_elements;
|
||||||
|
auto per_batch_counts = BatchedIntMap(num_batch_elements);
|
||||||
|
T max_value = 0;
|
||||||
|
|
||||||
|
const auto data_values = data.flat<T>();
|
||||||
|
int i = 0;
|
||||||
|
for (int b = 0; b < num_batch_elements; ++b) {
|
||||||
|
for (int v = 0; v < num_value_elements; ++v) {
|
||||||
|
const auto& value = data_values(i);
|
||||||
|
if (value >= 0 && (maxlength_ <= 0 || value < maxlength_)) {
|
||||||
|
if (binary_count_) {
|
||||||
|
(per_batch_counts[b])[value] = 1;
|
||||||
|
} else {
|
||||||
|
(per_batch_counts[b])[value]++;
|
||||||
|
}
|
||||||
|
if (value > max_value) {
|
||||||
|
max_value = value;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
++i;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
T num_output_values = GetOutputSize<T>(max_value, maxlength_, minlength_);
|
||||||
|
if (use_weights) {
|
||||||
|
OP_REQUIRES_OK(context,
|
||||||
|
OutputWeightedSparse(per_batch_counts, num_output_values,
|
||||||
|
weights, is_1d, context));
|
||||||
|
} else {
|
||||||
|
OP_REQUIRES_OK(context, OutputSparse(per_batch_counts, num_output_values,
|
||||||
|
is_1d, context));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
T minlength_;
|
||||||
|
T maxlength_;
|
||||||
|
bool binary_count_;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <class T>
|
||||||
|
class SparseCount : public OpKernel {
|
||||||
|
public:
|
||||||
|
explicit SparseCount(OpKernelConstruction* context) : OpKernel(context) {
|
||||||
|
OP_REQUIRES_OK(context, context->GetAttr("minlength", &minlength_));
|
||||||
|
OP_REQUIRES_OK(context, context->GetAttr("maxlength", &maxlength_));
|
||||||
|
OP_REQUIRES_OK(context, context->GetAttr("binary_count", &binary_count_));
|
||||||
|
}
|
||||||
|
|
||||||
|
void Compute(OpKernelContext* context) override {
|
||||||
|
const Tensor& indices = context->input(0);
|
||||||
|
const Tensor& values = context->input(1);
|
||||||
|
const Tensor& shape = context->input(2);
|
||||||
|
const Tensor& weights = context->input(3);
|
||||||
|
bool use_weights = weights.NumElements() > 0;
|
||||||
|
|
||||||
|
bool is_1d = shape.NumElements() == 1;
|
||||||
|
const auto indices_values = indices.matrix<int64>();
|
||||||
|
const auto values_values = values.flat<T>();
|
||||||
|
|
||||||
|
int num_batches = is_1d ? 1 : shape.flat<int64>()(0);
|
||||||
|
int num_values = values.NumElements();
|
||||||
|
|
||||||
|
auto per_batch_counts = BatchedIntMap(num_batches);
|
||||||
|
T max_value = 0;
|
||||||
|
|
||||||
|
for (int idx = 0; idx < num_values; ++idx) {
|
||||||
|
int batch = is_1d ? 0 : indices_values(idx, 0);
|
||||||
|
const auto& value = values_values(idx);
|
||||||
|
if (value >= 0 && (maxlength_ <= 0 || value < maxlength_)) {
|
||||||
|
if (binary_count_) {
|
||||||
|
(per_batch_counts[batch])[value] = 1;
|
||||||
|
} else {
|
||||||
|
(per_batch_counts[batch])[value]++;
|
||||||
|
}
|
||||||
|
if (value > max_value) {
|
||||||
|
max_value = value;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
T num_output_values = GetOutputSize<T>(max_value, maxlength_, minlength_);
|
||||||
|
if (use_weights) {
|
||||||
|
OP_REQUIRES_OK(context,
|
||||||
|
OutputWeightedSparse(per_batch_counts, num_output_values,
|
||||||
|
weights, is_1d, context));
|
||||||
|
} else {
|
||||||
|
OP_REQUIRES_OK(context, OutputSparse(per_batch_counts, num_output_values,
|
||||||
|
is_1d, context));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
T minlength_;
|
||||||
|
T maxlength_;
|
||||||
|
bool binary_count_;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <class T>
|
||||||
|
class RaggedCount : public OpKernel {
|
||||||
|
public:
|
||||||
|
explicit RaggedCount(OpKernelConstruction* context) : OpKernel(context) {
|
||||||
|
OP_REQUIRES_OK(context, context->GetAttr("minlength", &minlength_));
|
||||||
|
OP_REQUIRES_OK(context, context->GetAttr("maxlength", &maxlength_));
|
||||||
|
OP_REQUIRES_OK(context, context->GetAttr("binary_count", &binary_count_));
|
||||||
|
}
|
||||||
|
|
||||||
|
void Compute(OpKernelContext* context) override {
|
||||||
|
const Tensor& splits = context->input(0);
|
||||||
|
const Tensor& values = context->input(1);
|
||||||
|
const Tensor& weights = context->input(2);
|
||||||
|
bool use_weights = weights.NumElements() > 0;
|
||||||
|
|
||||||
|
const auto splits_values = splits.flat<int64>();
|
||||||
|
const auto values_values = values.flat<T>();
|
||||||
|
int num_batches = splits.NumElements() - 1;
|
||||||
|
int num_values = values.NumElements();
|
||||||
|
|
||||||
|
auto per_batch_counts = BatchedIntMap(num_batches);
|
||||||
|
T max_value = 0;
|
||||||
|
int batch_idx = 0;
|
||||||
|
|
||||||
|
for (int idx = 0; idx < num_values; ++idx) {
|
||||||
|
while (idx >= splits_values(batch_idx)) {
|
||||||
|
batch_idx++;
|
||||||
|
}
|
||||||
|
const auto& value = values_values(idx);
|
||||||
|
if (value >= 0 && (maxlength_ <= 0 || value < maxlength_)) {
|
||||||
|
if (binary_count_) {
|
||||||
|
(per_batch_counts[batch_idx - 1])[value] = 1;
|
||||||
|
} else {
|
||||||
|
(per_batch_counts[batch_idx - 1])[value]++;
|
||||||
|
}
|
||||||
|
if (value > max_value) {
|
||||||
|
max_value = value;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
T num_output_values = GetOutputSize<T>(max_value, maxlength_, minlength_);
|
||||||
|
if (use_weights) {
|
||||||
|
OP_REQUIRES_OK(context,
|
||||||
|
OutputWeightedSparse(per_batch_counts, num_output_values,
|
||||||
|
weights, false, context));
|
||||||
|
} else {
|
||||||
|
OP_REQUIRES_OK(context, OutputSparse(per_batch_counts, num_output_values,
|
||||||
|
false, context));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
T minlength_;
|
||||||
|
T maxlength_;
|
||||||
|
bool binary_count_;
|
||||||
|
};
|
||||||
|
|
||||||
|
#define REGISTER(TYPE) \
|
||||||
|
\
|
||||||
|
REGISTER_KERNEL_BUILDER(Name("DenseCountSparseOutput") \
|
||||||
|
.TypeConstraint<TYPE>("T") \
|
||||||
|
.Device(DEVICE_CPU), \
|
||||||
|
DenseCount<TYPE>) \
|
||||||
|
\
|
||||||
|
REGISTER_KERNEL_BUILDER(Name("SparseCountSparseOutput") \
|
||||||
|
.TypeConstraint<TYPE>("T") \
|
||||||
|
.Device(DEVICE_CPU), \
|
||||||
|
SparseCount<TYPE>) \
|
||||||
|
\
|
||||||
|
REGISTER_KERNEL_BUILDER(Name("RaggedCountSparseOutput") \
|
||||||
|
.TypeConstraint<TYPE>("T") \
|
||||||
|
.Device(DEVICE_CPU), \
|
||||||
|
RaggedCount<TYPE>)
|
||||||
|
|
||||||
|
REGISTER(int32);
|
||||||
|
REGISTER(int64);
|
||||||
|
#undef REGISTER
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
47
tensorflow/core/kernels/count_ops_test.cc
Normal file
47
tensorflow/core/kernels/count_ops_test.cc
Normal file
@ -0,0 +1,47 @@
|
|||||||
|
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/core/framework/fake_input.h"
|
||||||
|
#include "tensorflow/core/framework/node_def_builder.h"
|
||||||
|
#include "tensorflow/core/framework/shape_inference.h"
|
||||||
|
#include "tensorflow/core/framework/shape_inference_testutil.h"
|
||||||
|
#include "tensorflow/core/framework/tensor.h"
|
||||||
|
#include "tensorflow/core/framework/tensor_shape.h"
|
||||||
|
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||||
|
#include "tensorflow/core/kernels/ops_testutil.h"
|
||||||
|
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||||
|
#include "tensorflow/core/platform/test.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
// Shape inference for the dense op: the output rank follows the rank of the
// values input (rank 1 and rank 2 cases).
TEST_F(OpsTestBase, DenseCountSparseOutputShapeFn) {
  ShapeInferenceTestOp dense_op("DenseCountSparseOutput");
  INFER_OK(dense_op, "[?];?", "[?,1];[?];[1]");
  INFER_OK(dense_op, "[?,?];?", "[?,2];[?];[2]");
}
|
||||||
|
|
||||||
|
// Shape inference for the sparse op: the output rank is taken from the second
// dimension of the indices input.
TEST_F(OpsTestBase, SparseCountSparseOutputShapeFn) {
  ShapeInferenceTestOp sparse_op("SparseCountSparseOutput");
  INFER_OK(sparse_op, "[?,1];?;?;?", "[?,d0_1];[?];[d0_1]");
  INFER_OK(sparse_op, "[?,2];?;?;?", "[?,d0_1];[?];[d0_1]");
}
|
||||||
|
|
||||||
|
// Shape inference for the ragged op: rank-1 values yield a rank-2 sparse
// output.
TEST_F(OpsTestBase, RaggedCountSparseOutputShapeFn) {
  ShapeInferenceTestOp ragged_op("RaggedCountSparseOutput");
  INFER_OK(ragged_op, "?;[?];?", "[?,2];[?];[2]");
}
|
||||||
|
} // namespace
|
||||||
|
} // namespace tensorflow
|
97
tensorflow/core/ops/count_ops.cc
Normal file
97
tensorflow/core/ops/count_ops.cc
Normal file
@ -0,0 +1,97 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/core/framework/common_shape_fns.h"
|
||||||
|
#include "tensorflow/core/framework/op.h"
|
||||||
|
#include "tensorflow/core/framework/shape_inference.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
using shape_inference::DimensionHandle;
|
||||||
|
using shape_inference::InferenceContext;
|
||||||
|
|
||||||
|
// Shape function for DenseCountSparseOutput. The number of output entries is
// unknown; the output rank equals the rank of the values input (input 0).
Status DenseCountSparseOutputShapeFn(InferenceContext *c) {
  const int32 values_rank = c->Rank(c->input(0));
  const DimensionHandle num_entries = c->UnknownDim();

  // out.indices: [num_entries, rank]
  c->set_output(0, c->Matrix(num_entries, values_rank));
  // out.values: [num_entries]
  c->set_output(1, c->Vector(num_entries));
  // out.dense_shape: [rank]
  c->set_output(2, c->Vector(values_rank));
  return Status::OK();
}
|
||||||
|
|
||||||
|
// Shape function for SparseCountSparseOutput. The number of output entries is
// unknown; the output rank is the second dimension of the indices input
// (input 0).
//
// Returns an error if the indices input cannot be a rank-2 tensor.
Status SparseCountSparseOutputShapeFn(InferenceContext *c) {
  // Require rank 2 before reading dimension 1, so that a malformed indices
  // input is reported as a shape-inference error instead of reading a
  // dimension that does not exist.
  shape_inference::ShapeHandle indices_shape;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &indices_shape));

  DimensionHandle rank = c->Dim(indices_shape, 1);
  DimensionHandle nvals = c->UnknownDim();
  c->set_output(0, c->Matrix(nvals, rank));  // out.indices
  c->set_output(1, c->Vector(nvals));        // out.values
  c->set_output(2, c->Vector(rank));         // out.dense_shape
  return Status::OK();
}
|
||||||
|
|
||||||
|
// Shape function for RaggedCountSparseOutput. The number of output entries is
// unknown; the output rank is the rank of the values input (input 1) plus one
// for the ragged dimension, or unknown if the values rank is unknown.
Status RaggedCountSparseOutputShapeFn(InferenceContext *c) {
  const int32 values_rank = c->Rank(c->input(1));
  int32 output_rank = values_rank;
  if (values_rank != c->kUnknownRank) {
    output_rank = values_rank + 1;  // Account for the ragged dimension.
  }
  const DimensionHandle num_entries = c->UnknownDim();

  // out.indices: [num_entries, rank]
  c->set_output(0, c->Matrix(num_entries, output_rank));
  // out.values: [num_entries]
  c->set_output(1, c->Vector(num_entries));
  // out.dense_shape: [rank]
  c->set_output(2, c->Vector(output_rank));
  return Status::OK();
}
|
||||||
|
|
||||||
|
// Op definition: dense values (+ optional weights) in, sparse components out.
REGISTER_OP("DenseCountSparseOutput")
    // Inputs.
    .Input("values: T")
    .Input("weights: float")
    // Outputs (components of the resulting sparse tensor).
    .Output("output_indices: int64")
    .Output("output_values: output_type")
    .Output("output_dense_shape: int64")
    // Attributes.
    .Attr("T: {int32, int64}")
    .Attr("minlength: int >= -1 = -1")
    .Attr("maxlength: int >= -1 = -1")
    .Attr("binary_count: bool")
    .Attr("output_type: {int64, float}")
    .SetShapeFn(DenseCountSparseOutputShapeFn);
|
||||||
|
|
||||||
|
// Op definition: sparse components (+ optional weights) in, sparse components
// out.
REGISTER_OP("SparseCountSparseOutput")
    // Inputs.
    .Input("indices: int64")
    .Input("values: T")
    .Input("dense_shape: int64")
    .Input("weights: float")
    // Outputs (components of the resulting sparse tensor).
    .Output("output_indices: int64")
    .Output("output_values: output_type")
    .Output("output_dense_shape: int64")
    // Attributes.
    .Attr("T: {int32, int64}")
    .Attr("minlength: int >= -1 = -1")
    .Attr("maxlength: int >= -1 = -1")
    .Attr("binary_count: bool")
    .Attr("output_type: {int64, float}")
    .SetShapeFn(SparseCountSparseOutputShapeFn);
|
||||||
|
|
||||||
|
// Op definition: ragged row splits and values (+ optional weights) in, sparse
// components out.
REGISTER_OP("RaggedCountSparseOutput")
    // Inputs.
    .Input("splits: int64")
    .Input("values: T")
    .Input("weights: float")
    // Outputs (components of the resulting sparse tensor).
    .Output("output_indices: int64")
    .Output("output_values: output_type")
    .Output("output_dense_shape: int64")
    // Attributes.
    .Attr("T: {int32, int64}")
    .Attr("minlength: int >= -1 = -1")
    .Attr("maxlength: int >= -1 = -1")
    .Attr("binary_count: bool")
    .Attr("output_type: {int64, float}")
    .SetShapeFn(RaggedCountSparseOutputShapeFn);
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
@ -1,4 +1,3 @@
|
|||||||
# Python support for TensorFlow.
|
|
||||||
#
|
#
|
||||||
# Public targets:
|
# Public targets:
|
||||||
# ":platform" - Low-level and platform-specific Python code.
|
# ":platform" - Low-level and platform-specific Python code.
|
||||||
@ -135,6 +134,7 @@ py_library(
|
|||||||
":_pywrap_utils",
|
":_pywrap_utils",
|
||||||
":array_ops",
|
":array_ops",
|
||||||
":audio_ops_gen",
|
":audio_ops_gen",
|
||||||
|
":bincount",
|
||||||
":bitwise_ops",
|
":bitwise_ops",
|
||||||
":boosted_trees_ops",
|
":boosted_trees_ops",
|
||||||
":check_ops",
|
":check_ops",
|
||||||
@ -2898,6 +2898,11 @@ tf_gen_op_wrapper_private_py(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Generated Python op wrappers for count_ops.
tf_gen_op_wrapper_private_py(
    name = "count_ops_gen",
    visibility = [
        "//learning/brain/python/ops:__pkg__",
    ],
)
|
||||||
|
|
||||||
tf_gen_op_wrapper_private_py(
|
tf_gen_op_wrapper_private_py(
|
||||||
name = "parsing_ops_gen",
|
name = "parsing_ops_gen",
|
||||||
visibility = ["//learning/brain/python/ops:__pkg__"],
|
visibility = ["//learning/brain/python/ops:__pkg__"],
|
||||||
@ -3463,6 +3468,28 @@ py_library(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "bincount",
|
||||||
|
srcs = ["ops/bincount.py"],
|
||||||
|
srcs_version = "PY2AND3",
|
||||||
|
deps = [
|
||||||
|
":count_ops_gen",
|
||||||
|
":framework",
|
||||||
|
":framework_for_generated_wrappers",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
tf_py_test(
|
||||||
|
name = "bincount_test",
|
||||||
|
size = "small",
|
||||||
|
srcs = ["ops/bincount_test.py"],
|
||||||
|
python_version = "PY3",
|
||||||
|
deps = [
|
||||||
|
":bincount",
|
||||||
|
":platform_test",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
py_library(
|
py_library(
|
||||||
name = "ctc_ops",
|
name = "ctc_ops",
|
||||||
srcs = ["ops/ctc_ops.py"],
|
srcs = ["ops/ctc_ops.py"],
|
||||||
|
@ -85,6 +85,7 @@ from tensorflow.python import keras
|
|||||||
from tensorflow.python.feature_column import feature_column_lib as feature_column
|
from tensorflow.python.feature_column import feature_column_lib as feature_column
|
||||||
from tensorflow.python.layers import layers
|
from tensorflow.python.layers import layers
|
||||||
from tensorflow.python.module import module
|
from tensorflow.python.module import module
|
||||||
|
from tensorflow.python.ops import bincount
|
||||||
from tensorflow.python.ops import bitwise_ops as bitwise
|
from tensorflow.python.ops import bitwise_ops as bitwise
|
||||||
from tensorflow.python.ops import gradient_checker_v2
|
from tensorflow.python.ops import gradient_checker_v2
|
||||||
from tensorflow.python.ops import image_ops as image
|
from tensorflow.python.ops import image_ops as image
|
||||||
|
199
tensorflow/python/ops/bincount.py
Normal file
199
tensorflow/python/ops/bincount.py
Normal file
@ -0,0 +1,199 @@
|
|||||||
|
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""tf.sparse.bincount ops."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from tensorflow.python.framework import dtypes
|
||||||
|
from tensorflow.python.framework import ops
|
||||||
|
from tensorflow.python.framework import sparse_tensor
|
||||||
|
from tensorflow.python.ops import array_ops
|
||||||
|
from tensorflow.python.ops import gen_count_ops
|
||||||
|
from tensorflow.python.ops.ragged import ragged_tensor
|
||||||
|
from tensorflow.python.util.tf_export import tf_export
|
||||||
|
|
||||||
|
|
||||||
|
@tf_export("sparse.bincount")
|
||||||
|
def sparse_bincount(values,
|
||||||
|
weights=None,
|
||||||
|
axis=0,
|
||||||
|
minlength=None,
|
||||||
|
maxlength=None,
|
||||||
|
binary_count=False,
|
||||||
|
name=None):
|
||||||
|
"""Count the number of times an integer value appears in a tensor.
|
||||||
|
|
||||||
|
This op takes an N-dimensional `Tensor`, `RaggedTensor`, or `SparseTensor`,
|
||||||
|
and returns an N-dimensional SparseTensor (int64, or float32 with `weights`) where element
|
||||||
|
`[i0...i[axis], j]` contains the number of times the value `j` appears in
|
||||||
|
slice `[i0...i[axis], :]` of the input tensor. Currently, only `axis=0` and
|
||||||
|
`axis=-1` are supported.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
values: A Tensor, RaggedTensor, or SparseTensor whose values should be
|
||||||
|
counted. These tensors must have a rank of 1 or 2.
|
||||||
|
weights: A 1-dimensional Tensor of weights. If specified, the input array is
|
||||||
|
weighted by the weight array, i.e. each time a value `n` is found,
|
||||||
|
`out[n]` will be increased by `weights[n]` instead of 1.
|
||||||
|
axis: The axis to slice over. Axes at and below `axis` will be flattened
|
||||||
|
before bin counting. Currently, only `0`, and `-1` are supported. If None,
|
||||||
|
all axes will be flattened (identical to passing `0`).
|
||||||
|
minlength: If given, skips `values` that are less than `minlength`, and
|
||||||
|
ensures that the output has a `dense_shape` of at least `minlength` in the
|
||||||
|
inner dimension.
|
||||||
|
maxlength: If given, skips `values` that are greater than or equal to
|
||||||
|
`maxlength`, and ensures that the output has a `dense_shape` of at most
|
||||||
|
`maxlength` in the inner dimension.
|
||||||
|
binary_count: Whether to do a binary count. When True, this op will return 1
|
||||||
|
for any value that exists instead of counting the number of occurrences.
|
||||||
|
name: A name for this op.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A SparseTensor with `output.shape = values.shape[:axis] + [N]`, where `N` is
|
||||||
|
* `maxlength` (if set);
|
||||||
|
* `minlength` (if set, and `minlength > reduce_max(values)`);
|
||||||
|
* `0` (if `values` is empty);
|
||||||
|
* `reduce_max(values) + 1` otherwise.
|
||||||
|
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
|
||||||
|
**Bin-counting every item in individual batches**
|
||||||
|
|
||||||
|
This example takes an input (which could be a Tensor, RaggedTensor, or
|
||||||
|
SparseTensor) and returns a SparseTensor where the value of (i,j) is the
|
||||||
|
number of times value j appears in batch i.
|
||||||
|
|
||||||
|
>>> data = [[10, 20, 30, 20], [11, 101, 11, 10001]]
|
||||||
|
>>> output = tf.sparse.bincount(data, axis=-1)
|
||||||
|
>>> print(output)
|
||||||
|
SparseTensor(indices=tf.Tensor(
|
||||||
|
[[ 0 10]
|
||||||
|
[ 0 20]
|
||||||
|
[ 0 30]
|
||||||
|
[ 1 11]
|
||||||
|
[ 1 101]
|
||||||
|
[ 1 10001]], shape=(6, 2), dtype=int64),
|
||||||
|
values=tf.Tensor([1 2 1 2 1 1], shape=(6,), dtype=int64),
|
||||||
|
dense_shape=tf.Tensor([ 2 10002], shape=(2,), dtype=int64))
|
||||||
|
|
||||||
|
**Bin-counting with defined output shape**
|
||||||
|
|
||||||
|
This example takes an input (which could be a Tensor, RaggedTensor, or
|
||||||
|
SparseTensor) and returns a SparseTensor where the value of (i,j) is the
|
||||||
|
number of times value j appears in batch i. However, all values of j
|
||||||
|
at or above 'maxlength' are ignored. The dense_shape of the output sparse tensor
|
||||||
|
is set to 'minlength'. Note that, while the input is identical to the
|
||||||
|
example above, the value '10001' in batch item 2 is dropped, and the
|
||||||
|
dense shape is [2, 500] instead of [2,10002] or [2, 102].
|
||||||
|
|
||||||
|
>>> minlength = maxlength = 500
|
||||||
|
>>> data = [[10, 20, 30, 20], [11, 101, 11, 10001]]
|
||||||
|
>>> output = tf.sparse.bincount(
|
||||||
|
... data, axis=-1, minlength=minlength, maxlength=maxlength)
|
||||||
|
>>> print(output)
|
||||||
|
SparseTensor(indices=tf.Tensor(
|
||||||
|
[[ 0 10]
|
||||||
|
[ 0 20]
|
||||||
|
[ 0 30]
|
||||||
|
[ 1 11]
|
||||||
|
[ 1 101]], shape=(5, 2), dtype=int64),
|
||||||
|
values=tf.Tensor([1 2 1 2 1], shape=(5,), dtype=int64),
|
||||||
|
dense_shape=tf.Tensor([ 2 500], shape=(2,), dtype=int64))
|
||||||
|
|
||||||
|
**Binary bin-counting**
|
||||||
|
|
||||||
|
This example takes an input (which could be a Tensor, RaggedTensor, or
|
||||||
|
SparseTensor) and returns a SparseTensor where (i,j) is 1 if the value j
|
||||||
|
appears in batch i at least once and is 0 otherwise. Note that, even though
|
||||||
|
some values (like 20 in batch 1 and 11 in batch 2) appear more than once,
|
||||||
|
the 'values' tensor is all 1s.
|
||||||
|
|
||||||
|
>>> dense = [[10, 20, 30, 20], [11, 101, 11, 10001]]
|
||||||
|
>>> output = tf.sparse.bincount(dense, binary_count=True, axis=-1)
|
||||||
|
>>> print(output)
|
||||||
|
SparseTensor(indices=tf.Tensor(
|
||||||
|
[[ 0 10]
|
||||||
|
[ 0 20]
|
||||||
|
[ 0 30]
|
||||||
|
[ 1 11]
|
||||||
|
[ 1 101]
|
||||||
|
[ 1 10001]], shape=(6, 2), dtype=int64),
|
||||||
|
values=tf.Tensor([1 1 1 1 1 1], shape=(6,), dtype=int64),
|
||||||
|
dense_shape=tf.Tensor([ 2 10002], shape=(2,), dtype=int64))
|
||||||
|
|
||||||
|
"""
|
||||||
|
with ops.name_scope(name, "count", [values, weights]):
|
||||||
|
if not isinstance(values, sparse_tensor.SparseTensor):
|
||||||
|
values = ragged_tensor.convert_to_tensor_or_ragged_tensor(
|
||||||
|
values, name="values")
|
||||||
|
|
||||||
|
if weights is not None and binary_count:
|
||||||
|
raise ValueError("binary_count and weights are mutually exclusive.")
|
||||||
|
|
||||||
|
if weights is None:
|
||||||
|
weights = []
|
||||||
|
output_type = dtypes.int64
|
||||||
|
else:
|
||||||
|
output_type = dtypes.float32
|
||||||
|
|
||||||
|
if axis is None:
|
||||||
|
axis = 0
|
||||||
|
|
||||||
|
if axis not in [0, -1]:
|
||||||
|
raise ValueError("Unsupported axis value %s. Only 0 and -1 are currently "
|
||||||
|
"supported." % axis)
|
||||||
|
|
||||||
|
minlength_value = minlength if minlength is not None else -1
|
||||||
|
maxlength_value = maxlength if maxlength is not None else -1
|
||||||
|
|
||||||
|
if axis == 0:
|
||||||
|
if isinstance(values,
|
||||||
|
(sparse_tensor.SparseTensor, ragged_tensor.RaggedTensor)):
|
||||||
|
values = values.values
|
||||||
|
else:
|
||||||
|
values = array_ops.reshape(values, [-1])
|
||||||
|
|
||||||
|
if isinstance(values, sparse_tensor.SparseTensor):
|
||||||
|
c_ind, c_val, c_shape = gen_count_ops.sparse_count_sparse_output(
|
||||||
|
values.indices,
|
||||||
|
values.values,
|
||||||
|
values.dense_shape,
|
||||||
|
weights=weights,
|
||||||
|
minlength=minlength_value,
|
||||||
|
maxlength=maxlength_value,
|
||||||
|
binary_count=binary_count,
|
||||||
|
output_type=output_type)
|
||||||
|
elif isinstance(values, ragged_tensor.RaggedTensor):
|
||||||
|
c_ind, c_val, c_shape = gen_count_ops.ragged_count_sparse_output(
|
||||||
|
values.row_splits,
|
||||||
|
values.values,
|
||||||
|
weights=weights,
|
||||||
|
minlength=minlength_value,
|
||||||
|
maxlength=maxlength_value,
|
||||||
|
binary_count=binary_count,
|
||||||
|
output_type=output_type)
|
||||||
|
else:
|
||||||
|
c_ind, c_val, c_shape = gen_count_ops.dense_count_sparse_output(
|
||||||
|
values,
|
||||||
|
weights=weights,
|
||||||
|
minlength=minlength_value,
|
||||||
|
maxlength=maxlength_value,
|
||||||
|
binary_count=binary_count,
|
||||||
|
output_type=output_type)
|
||||||
|
|
||||||
|
return sparse_tensor.SparseTensor(c_ind, c_val, c_shape)
|
504
tensorflow/python/ops/bincount_test.py
Normal file
504
tensorflow/python/ops/bincount_test.py
Normal file
@ -0,0 +1,504 @@
|
|||||||
|
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Tests for bincount ops."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from absl.testing import parameterized
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from tensorflow.python.ops import bincount
|
||||||
|
from tensorflow.python.ops import sparse_ops
|
||||||
|
from tensorflow.python.ops.ragged import ragged_factory_ops
|
||||||
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
|
class TestSparseCount(test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
|
@parameterized.named_parameters(
|
||||||
|
{
|
||||||
|
"testcase_name": "_no_maxlength",
|
||||||
|
"x": np.array([[3, 2, 1], [5, 4, 4]], dtype=np.int32),
|
||||||
|
"expected_indices": [[0, 1], [0, 2], [0, 3], [1, 4], [1, 5]],
|
||||||
|
"expected_values": [1, 1, 1, 2, 1],
|
||||||
|
"expected_shape": [2, 6]
|
||||||
|
}, {
|
||||||
|
"testcase_name": "_maxlength",
|
||||||
|
"x": np.array([[3, 2, 1, 7], [7, 0, 4, 4]], dtype=np.int32),
|
||||||
|
"maxlength": 7,
|
||||||
|
"expected_indices": [[0, 1], [0, 2], [0, 3], [1, 0], [1, 4]],
|
||||||
|
"expected_values": [1, 1, 1, 1, 2],
|
||||||
|
"expected_shape": [2, 7]
|
||||||
|
}, {
|
||||||
|
"testcase_name": "_minlength",
|
||||||
|
"x": np.array([[3, 2, 1, 7], [7, 0, 4, 4]], dtype=np.int32),
|
||||||
|
"minlength": 9,
|
||||||
|
"expected_indices": [[0, 1], [0, 2], [0, 3], [0, 7], [1, 0], [1, 4],
|
||||||
|
[1, 7]],
|
||||||
|
"expected_values": [1, 1, 1, 1, 1, 2, 1],
|
||||||
|
"expected_shape": [2, 9]
|
||||||
|
}, {
|
||||||
|
"testcase_name": "_minlength_larger_values",
|
||||||
|
"x": np.array([[3, 2, 1, 7], [7, 0, 4, 4]], dtype=np.int32),
|
||||||
|
"minlength": 3,
|
||||||
|
"expected_indices": [[0, 1], [0, 2], [0, 3], [0, 7], [1, 0], [1, 4],
|
||||||
|
[1, 7]],
|
||||||
|
"expected_values": [1, 1, 1, 1, 1, 2, 1],
|
||||||
|
"expected_shape": [2, 8]
|
||||||
|
}, {
|
||||||
|
"testcase_name": "_no_maxlength_binary",
|
||||||
|
"x": np.array([[3, 2, 1], [5, 4, 4]], dtype=np.int32),
|
||||||
|
"expected_indices": [[0, 1], [0, 2], [0, 3], [1, 4], [1, 5]],
|
||||||
|
"expected_values": [1, 1, 1, 1, 1],
|
||||||
|
"expected_shape": [2, 6],
|
||||||
|
"binary_count": True,
|
||||||
|
}, {
|
||||||
|
"testcase_name": "_maxlength_binary",
|
||||||
|
"x": np.array([[3, 2, 1, 7], [7, 0, 4, 4]], dtype=np.int32),
|
||||||
|
"maxlength": 7,
|
||||||
|
"expected_indices": [[0, 1], [0, 2], [0, 3], [1, 0], [1, 4]],
|
||||||
|
"expected_values": [1, 1, 1, 1, 1],
|
||||||
|
"expected_shape": [2, 7],
|
||||||
|
"binary_count": True,
|
||||||
|
}, {
|
||||||
|
"testcase_name": "_minlength_binary",
|
||||||
|
"x": np.array([[3, 2, 1, 7], [7, 0, 4, 4]], dtype=np.int32),
|
||||||
|
"minlength": 9,
|
||||||
|
"expected_indices": [[0, 1], [0, 2], [0, 3], [0, 7], [1, 0], [1, 4],
|
||||||
|
[1, 7]],
|
||||||
|
"expected_values": [1, 1, 1, 1, 1, 1, 1],
|
||||||
|
"expected_shape": [2, 9],
|
||||||
|
"binary_count": True,
|
||||||
|
}, {
|
||||||
|
"testcase_name": "_minlength_larger_values_binary",
|
||||||
|
"x": np.array([[3, 2, 1, 7], [7, 0, 4, 4]], dtype=np.int32),
|
||||||
|
"minlength": 3,
|
||||||
|
"expected_indices": [[0, 1], [0, 2], [0, 3], [0, 7], [1, 0], [1, 4],
|
||||||
|
[1, 7]],
|
||||||
|
"expected_values": [1, 1, 1, 1, 1, 1, 1],
|
||||||
|
"expected_shape": [2, 8],
|
||||||
|
"binary_count": True,
|
||||||
|
}, {
|
||||||
|
"testcase_name": "_no_maxlength_weights",
|
||||||
|
"x": np.array([[3, 2, 1], [5, 4, 4]], dtype=np.int32),
|
||||||
|
"expected_indices": [[0, 1], [0, 2], [0, 3], [1, 4], [1, 5]],
|
||||||
|
"expected_values": [1, 2, 3, 8, 5],
|
||||||
|
"expected_shape": [2, 6],
|
||||||
|
"weights": [0.5, 1, 2, 3, 4, 5]
|
||||||
|
}, {
|
||||||
|
"testcase_name": "_maxlength_weights",
|
||||||
|
"x": np.array([[3, 2, 1, 7], [7, 0, 4, 4]], dtype=np.int32),
|
||||||
|
"maxlength": 7,
|
||||||
|
"expected_indices": [[0, 1], [0, 2], [0, 3], [1, 0], [1, 4]],
|
||||||
|
"expected_values": [1, 2, 3, 0.5, 8],
|
||||||
|
"expected_shape": [2, 7],
|
||||||
|
"weights": [0.5, 1, 2, 3, 4, 5, 6]
|
||||||
|
}, {
|
||||||
|
"testcase_name": "_minlength_weights",
|
||||||
|
"x": np.array([[3, 2, 1, 7], [7, 0, 4, 4]], dtype=np.int32),
|
||||||
|
"minlength": 9,
|
||||||
|
"expected_indices": [[0, 1], [0, 2], [0, 3], [0, 7], [1, 0], [1, 4],
|
||||||
|
[1, 7]],
|
||||||
|
"expected_values": [1, 2, 3, 7, 0.5, 8, 7],
|
||||||
|
"expected_shape": [2, 9],
|
||||||
|
"weights": [0.5, 1, 2, 3, 4, 5, 6, 7, 8]
|
||||||
|
}, {
|
||||||
|
"testcase_name": "_minlength_larger_values_weights",
|
||||||
|
"x": np.array([[3, 2, 1, 7], [7, 0, 4, 4]], dtype=np.int32),
|
||||||
|
"minlength": 3,
|
||||||
|
"expected_indices": [[0, 1], [0, 2], [0, 3], [0, 7], [1, 0], [1, 4],
|
||||||
|
[1, 7]],
|
||||||
|
"expected_values": [1, 2, 3, 7, 0.5, 8, 7],
|
||||||
|
"expected_shape": [2, 8],
|
||||||
|
"weights": [0.5, 1, 2, 3, 4, 5, 6, 7, 8]
|
||||||
|
}, {
|
||||||
|
"testcase_name": "_1d",
|
||||||
|
"x": np.array([3, 2, 1, 1], dtype=np.int32),
|
||||||
|
"expected_indices": [[1], [2], [3]],
|
||||||
|
"expected_values": [2, 1, 1],
|
||||||
|
"expected_shape": [4]
|
||||||
|
}, {
|
||||||
|
"testcase_name": "_all_axes",
|
||||||
|
"x": np.array([[3, 2, 1], [5, 4, 4]], dtype=np.int32),
|
||||||
|
"expected_indices": [[1], [2], [3], [4], [5]],
|
||||||
|
"expected_values": [1, 1, 1, 2, 1],
|
||||||
|
"expected_shape": [6],
|
||||||
|
"axis": None
|
||||||
|
})
|
||||||
|
def test_dense_input(self,
|
||||||
|
x,
|
||||||
|
expected_indices,
|
||||||
|
expected_values,
|
||||||
|
expected_shape,
|
||||||
|
minlength=None,
|
||||||
|
maxlength=None,
|
||||||
|
binary_count=False,
|
||||||
|
weights=None,
|
||||||
|
axis=-1):
|
||||||
|
y = bincount.sparse_bincount(
|
||||||
|
x,
|
||||||
|
weights=weights,
|
||||||
|
minlength=minlength,
|
||||||
|
maxlength=maxlength,
|
||||||
|
binary_count=binary_count,
|
||||||
|
axis=axis)
|
||||||
|
self.assertAllEqual(expected_indices, y.indices)
|
||||||
|
self.assertAllEqual(expected_values, y.values)
|
||||||
|
self.assertAllEqual(expected_shape, y.dense_shape)
|
||||||
|
|
||||||
|
@parameterized.named_parameters(
|
||||||
|
{
|
||||||
|
"testcase_name":
|
||||||
|
"_no_maxlength",
|
||||||
|
"x":
|
||||||
|
np.array([[3, 0, 1, 0], [0, 0, 0, 0], [5, 0, 4, 4]],
|
||||||
|
dtype=np.int32),
|
||||||
|
"expected_indices": [[0, 1], [0, 3], [2, 4], [2, 5]],
|
||||||
|
"expected_values": [1, 1, 2, 1],
|
||||||
|
"expected_shape": [3, 6],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"testcase_name":
|
||||||
|
"_maxlength",
|
||||||
|
"x":
|
||||||
|
np.array([[3, 0, 1, 0], [7, 0, 0, 0], [5, 0, 4, 4]],
|
||||||
|
dtype=np.int32),
|
||||||
|
"expected_indices": [[0, 1], [0, 3], [2, 4], [2, 5]],
|
||||||
|
"expected_values": [1, 1, 2, 1],
|
||||||
|
"expected_shape": [3, 7],
|
||||||
|
"maxlength":
|
||||||
|
7,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"testcase_name":
|
||||||
|
"_minlength",
|
||||||
|
"x":
|
||||||
|
np.array([[3, 0, 1, 0], [7, 0, 0, 0], [5, 0, 4, 4]],
|
||||||
|
dtype=np.int32),
|
||||||
|
"expected_indices": [[0, 1], [0, 3], [1, 7], [2, 4], [2, 5]],
|
||||||
|
"expected_values": [1, 1, 1, 2, 1],
|
||||||
|
"expected_shape": [3, 9],
|
||||||
|
"minlength":
|
||||||
|
9,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"testcase_name":
|
||||||
|
"_minlength_larger_values",
|
||||||
|
"x":
|
||||||
|
np.array([[3, 0, 1, 0], [7, 0, 0, 0], [5, 0, 4, 4]],
|
||||||
|
dtype=np.int32),
|
||||||
|
"expected_indices": [[0, 1], [0, 3], [1, 7], [2, 4], [2, 5]],
|
||||||
|
"expected_values": [1, 1, 1, 2, 1],
|
||||||
|
"expected_shape": [3, 8],
|
||||||
|
"minlength":
|
||||||
|
3,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"testcase_name":
|
||||||
|
"_no_maxlength_binary",
|
||||||
|
"x":
|
||||||
|
np.array([[3, 0, 1, 0], [0, 0, 0, 0], [5, 0, 4, 4]],
|
||||||
|
dtype=np.int32),
|
||||||
|
"expected_indices": [[0, 1], [0, 3], [2, 4], [2, 5]],
|
||||||
|
"expected_values": [1, 1, 1, 1],
|
||||||
|
"expected_shape": [3, 6],
|
||||||
|
"binary_count":
|
||||||
|
True,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"testcase_name":
|
||||||
|
"_maxlength_binary",
|
||||||
|
"x":
|
||||||
|
np.array([[3, 0, 1, 0], [0, 0, 7, 0], [5, 0, 4, 4]],
|
||||||
|
dtype=np.int32),
|
||||||
|
"expected_indices": [[0, 1], [0, 3], [2, 4], [2, 5]],
|
||||||
|
"expected_values": [1, 1, 1, 1],
|
||||||
|
"expected_shape": [3, 7],
|
||||||
|
"maxlength":
|
||||||
|
7,
|
||||||
|
"binary_count":
|
||||||
|
True,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"testcase_name":
|
||||||
|
"_minlength_binary",
|
||||||
|
"x":
|
||||||
|
np.array([[3, 0, 1, 0], [7, 0, 0, 0], [5, 0, 4, 4]],
|
||||||
|
dtype=np.int32),
|
||||||
|
"expected_indices": [[0, 1], [0, 3], [1, 7], [2, 4], [2, 5]],
|
||||||
|
"expected_values": [1, 1, 1, 1, 1],
|
||||||
|
"expected_shape": [3, 9],
|
||||||
|
"minlength":
|
||||||
|
9,
|
||||||
|
"binary_count":
|
||||||
|
True,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"testcase_name":
|
||||||
|
"_minlength_larger_values_binary",
|
||||||
|
"x":
|
||||||
|
np.array([[3, 0, 1, 0], [7, 0, 0, 0], [5, 0, 4, 4]],
|
||||||
|
dtype=np.int32),
|
||||||
|
"expected_indices": [[0, 1], [0, 3], [1, 7], [2, 4], [2, 5]],
|
||||||
|
"expected_values": [1, 1, 1, 1, 1],
|
||||||
|
"expected_shape": [3, 8],
|
||||||
|
"minlength":
|
||||||
|
3,
|
||||||
|
"binary_count":
|
||||||
|
True,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"testcase_name":
|
||||||
|
"_no_maxlength_weights",
|
||||||
|
"x":
|
||||||
|
np.array([[3, 0, 1, 0], [0, 0, 0, 0], [5, 0, 4, 4]],
|
||||||
|
dtype=np.int32),
|
||||||
|
"expected_indices": [[0, 1], [0, 3], [2, 4], [2, 5]],
|
||||||
|
"expected_values": [1, 3, 8, 5],
|
||||||
|
"expected_shape": [3, 6],
|
||||||
|
"weights": [0.5, 1, 2, 3, 4, 5]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"testcase_name":
|
||||||
|
"_maxlength_weights",
|
||||||
|
"x":
|
||||||
|
np.array([[3, 0, 1, 0], [0, 0, 7, 0], [5, 0, 4, 4]],
|
||||||
|
dtype=np.int32),
|
||||||
|
"expected_indices": [[0, 1], [0, 3], [2, 4], [2, 5]],
|
||||||
|
"expected_values": [1, 3, 8, 5],
|
||||||
|
"expected_shape": [3, 7],
|
||||||
|
"maxlength":
|
||||||
|
7,
|
||||||
|
"weights": [0.5, 1, 2, 3, 4, 5, 6]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"testcase_name":
|
||||||
|
"_minlength_weights",
|
||||||
|
"x":
|
||||||
|
np.array([[3, 0, 1, 0], [7, 0, 0, 0], [5, 0, 4, 4]],
|
||||||
|
dtype=np.int32),
|
||||||
|
"expected_indices": [[0, 1], [0, 3], [1, 7], [2, 4], [2, 5]],
|
||||||
|
"expected_values": [1, 3, 7, 8, 5],
|
||||||
|
"expected_shape": [3, 9],
|
||||||
|
"minlength":
|
||||||
|
9,
|
||||||
|
"weights": [0.5, 1, 2, 3, 4, 5, 6, 7, 8]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"testcase_name":
|
||||||
|
"_minlength_larger_values_weights",
|
||||||
|
"x":
|
||||||
|
np.array([[3, 0, 1, 0], [7, 0, 0, 0], [5, 0, 4, 4]],
|
||||||
|
dtype=np.int32),
|
||||||
|
"expected_indices": [[0, 1], [0, 3], [1, 7], [2, 4], [2, 5]],
|
||||||
|
"expected_values": [1, 3, 7, 8, 5],
|
||||||
|
"expected_shape": [3, 8],
|
||||||
|
"minlength":
|
||||||
|
3,
|
||||||
|
"weights": [0.5, 1, 2, 3, 4, 5, 6, 7, 8]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"testcase_name": "_1d",
|
||||||
|
"x": np.array([3, 0, 1, 1], dtype=np.int32),
|
||||||
|
"expected_indices": [[1], [3]],
|
||||||
|
"expected_values": [2, 1],
|
||||||
|
"expected_shape": [4],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"testcase_name":
|
||||||
|
"_all_axes",
|
||||||
|
"x":
|
||||||
|
np.array([[3, 0, 1, 0], [0, 0, 0, 0], [5, 0, 4, 4]],
|
||||||
|
dtype=np.int32),
|
||||||
|
"expected_indices": [[1], [3], [4], [5]],
|
||||||
|
"expected_values": [1, 1, 2, 1],
|
||||||
|
"expected_shape": [6],
|
||||||
|
"axis":
|
||||||
|
None,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
def test_sparse_input(self,
|
||||||
|
x,
|
||||||
|
expected_indices,
|
||||||
|
expected_values,
|
||||||
|
expected_shape,
|
||||||
|
maxlength=None,
|
||||||
|
minlength=None,
|
||||||
|
binary_count=False,
|
||||||
|
weights=None,
|
||||||
|
axis=-1):
|
||||||
|
x_sparse = sparse_ops.from_dense(x)
|
||||||
|
y = bincount.sparse_bincount(
|
||||||
|
x_sparse,
|
||||||
|
weights=weights,
|
||||||
|
minlength=minlength,
|
||||||
|
maxlength=maxlength,
|
||||||
|
binary_count=binary_count,
|
||||||
|
axis=axis)
|
||||||
|
self.assertAllEqual(expected_indices, y.indices)
|
||||||
|
self.assertAllEqual(expected_values, y.values)
|
||||||
|
self.assertAllEqual(expected_shape, y.dense_shape)
|
||||||
|
|
||||||
|
@parameterized.named_parameters(
|
||||||
|
{
|
||||||
|
"testcase_name": "_no_maxlength",
|
||||||
|
"x": [[], [], [3, 0, 1], [], [5, 0, 4, 4]],
|
||||||
|
"expected_indices": [[2, 0], [2, 1], [2, 3], [4, 0], [4, 4], [4, 5]],
|
||||||
|
"expected_values": [1, 1, 1, 1, 2, 1],
|
||||||
|
"expected_shape": [5, 6],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"testcase_name": "_maxlength",
|
||||||
|
"x": [[], [], [3, 0, 1], [7], [5, 0, 4, 4]],
|
||||||
|
"maxlength": 7,
|
||||||
|
"expected_indices": [[2, 0], [2, 1], [2, 3], [4, 0], [4, 4], [4, 5]],
|
||||||
|
"expected_values": [1, 1, 1, 1, 2, 1],
|
||||||
|
"expected_shape": [5, 7],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"testcase_name": "_minlength",
|
||||||
|
"x": [[], [], [3, 0, 1], [7], [5, 0, 4, 4]],
|
||||||
|
"minlength": 9,
|
||||||
|
"expected_indices": [[2, 0], [2, 1], [2, 3], [3, 7], [4, 0], [4, 4],
|
||||||
|
[4, 5]],
|
||||||
|
"expected_values": [1, 1, 1, 1, 1, 2, 1],
|
||||||
|
"expected_shape": [5, 9],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"testcase_name": "_minlength_larger_values",
|
||||||
|
"x": [[], [], [3, 0, 1], [7], [5, 0, 4, 4]],
|
||||||
|
"minlength": 3,
|
||||||
|
"expected_indices": [[2, 0], [2, 1], [2, 3], [3, 7], [4, 0], [4, 4],
|
||||||
|
[4, 5]],
|
||||||
|
"expected_values": [1, 1, 1, 1, 1, 2, 1],
|
||||||
|
"expected_shape": [5, 8],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"testcase_name": "_no_maxlength_binary",
|
||||||
|
"x": [[], [], [3, 0, 1], [], [5, 0, 4, 4]],
|
||||||
|
"expected_indices": [[2, 0], [2, 1], [2, 3], [4, 0], [4, 4], [4, 5]],
|
||||||
|
"expected_values": [1, 1, 1, 1, 1, 1],
|
||||||
|
"expected_shape": [5, 6],
|
||||||
|
"binary_count": True,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"testcase_name": "_maxlength_binary",
|
||||||
|
"x": [[], [], [3, 0, 1], [7], [5, 0, 4, 4]],
|
||||||
|
"maxlength": 7,
|
||||||
|
"expected_indices": [[2, 0], [2, 1], [2, 3], [4, 0], [4, 4], [4, 5]],
|
||||||
|
"expected_values": [1, 1, 1, 1, 1, 1],
|
||||||
|
"expected_shape": [5, 7],
|
||||||
|
"binary_count": True,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"testcase_name": "_minlength_binary",
|
||||||
|
"x": [[], [], [3, 0, 1], [7], [5, 0, 4, 4]],
|
||||||
|
"minlength": 9,
|
||||||
|
"expected_indices": [[2, 0], [2, 1], [2, 3], [3, 7], [4, 0], [4, 4],
|
||||||
|
[4, 5]],
|
||||||
|
"expected_values": [1, 1, 1, 1, 1, 1, 1],
|
||||||
|
"expected_shape": [5, 9],
|
||||||
|
"binary_count": True,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"testcase_name": "_minlength_larger_values_binary",
|
||||||
|
"x": [[], [], [3, 0, 1], [7], [5, 0, 4, 4]],
|
||||||
|
"minlength": 3,
|
||||||
|
"binary_count": True,
|
||||||
|
"expected_indices": [[2, 0], [2, 1], [2, 3], [3, 7], [4, 0], [4, 4],
|
||||||
|
[4, 5]],
|
||||||
|
"expected_values": [1, 1, 1, 1, 1, 1, 1],
|
||||||
|
"expected_shape": [5, 8],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"testcase_name": "_no_maxlength_weights",
|
||||||
|
"x": [[], [], [3, 0, 1], [], [5, 0, 4, 4]],
|
||||||
|
"expected_indices": [[2, 0], [2, 1], [2, 3], [4, 0], [4, 4], [4, 5]],
|
||||||
|
"expected_values": [0.5, 1, 3, 0.5, 8, 5],
|
||||||
|
"expected_shape": [5, 6],
|
||||||
|
"weights": [0.5, 1, 2, 3, 4, 5]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"testcase_name": "_maxlength_weights",
|
||||||
|
"x": [[], [], [3, 0, 1], [7], [5, 0, 4, 4]],
|
||||||
|
"maxlength": 7,
|
||||||
|
"expected_indices": [[2, 0], [2, 1], [2, 3], [4, 0], [4, 4], [4, 5]],
|
||||||
|
"expected_values": [0.5, 1, 3, 0.5, 8, 5],
|
||||||
|
"expected_shape": [5, 7],
|
||||||
|
"weights": [0.5, 1, 2, 3, 4, 5, 6]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"testcase_name": "_minlength_weights",
|
||||||
|
"x": [[], [], [3, 0, 1], [7], [5, 0, 4, 4]],
|
||||||
|
"minlength": 9,
|
||||||
|
"expected_indices": [[2, 0], [2, 1], [2, 3], [3, 7], [4, 0], [4, 4],
|
||||||
|
[4, 5]],
|
||||||
|
"expected_values": [0.5, 1, 3, 7, 0.5, 8, 5],
|
||||||
|
"expected_shape": [5, 9],
|
||||||
|
"weights": [0.5, 1, 2, 3, 4, 5, 6, 7, 8]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"testcase_name": "_minlength_larger_values_weights",
|
||||||
|
"x": [[], [], [3, 0, 1], [7], [5, 0, 4, 4]],
|
||||||
|
"minlength": 3,
|
||||||
|
"expected_indices": [[2, 0], [2, 1], [2, 3], [3, 7], [4, 0], [4, 4],
|
||||||
|
[4, 5]],
|
||||||
|
"expected_values": [0.5, 1, 3, 7, 0.5, 8, 5],
|
||||||
|
"expected_shape": [5, 8],
|
||||||
|
"weights": [0.5, 1, 2, 3, 4, 5, 6, 7, 8]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"testcase_name": "_1d",
|
||||||
|
"x": [3, 0, 1, 1],
|
||||||
|
"expected_indices": [[0], [1], [3]],
|
||||||
|
"expected_values": [1, 2, 1],
|
||||||
|
"expected_shape": [4],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"testcase_name": "_all_axes",
|
||||||
|
"x": [[], [], [3, 0, 1], [], [5, 0, 4, 4]],
|
||||||
|
"expected_indices": [[0], [1], [3], [4], [5]],
|
||||||
|
"expected_values": [2, 1, 1, 2, 1],
|
||||||
|
"expected_shape": [6],
|
||||||
|
"axis": None,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
def test_ragged_input(self,
|
||||||
|
x,
|
||||||
|
expected_indices,
|
||||||
|
expected_values,
|
||||||
|
expected_shape,
|
||||||
|
maxlength=None,
|
||||||
|
minlength=None,
|
||||||
|
binary_count=False,
|
||||||
|
weights=None,
|
||||||
|
axis=-1):
|
||||||
|
x_ragged = ragged_factory_ops.constant(x)
|
||||||
|
y = bincount.sparse_bincount(
|
||||||
|
x_ragged,
|
||||||
|
weights=weights,
|
||||||
|
minlength=minlength,
|
||||||
|
maxlength=maxlength,
|
||||||
|
binary_count=binary_count,
|
||||||
|
axis=axis)
|
||||||
|
self.assertAllEqual(expected_indices, y.indices)
|
||||||
|
self.assertAllEqual(expected_values, y.values)
|
||||||
|
self.assertAllEqual(expected_shape, y.dense_shape)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test.main()
|
@ -1076,6 +1076,10 @@ tf_module {
|
|||||||
name: "DeleteSessionTensor"
|
name: "DeleteSessionTensor"
|
||||||
argspec: "args=[\'handle\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
argspec: "args=[\'handle\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
}
|
}
|
||||||
|
member_method {
|
||||||
|
name: "DenseCountSparseOutput"
|
||||||
|
argspec: "args=[\'values\', \'weights\', \'binary_count\', \'output_type\', \'minlength\', \'maxlength\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'-1\', \'None\'], "
|
||||||
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "DenseToCSRSparseMatrix"
|
name: "DenseToCSRSparseMatrix"
|
||||||
argspec: "args=[\'dense_input\', \'indices\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
argspec: "args=[\'dense_input\', \'indices\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
@ -3060,6 +3064,10 @@ tf_module {
|
|||||||
name: "RGBToHSV"
|
name: "RGBToHSV"
|
||||||
argspec: "args=[\'images\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
argspec: "args=[\'images\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
}
|
}
|
||||||
|
member_method {
|
||||||
|
name: "RaggedCountSparseOutput"
|
||||||
|
argspec: "args=[\'splits\', \'values\', \'weights\', \'binary_count\', \'output_type\', \'minlength\', \'maxlength\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'-1\', \'None\'], "
|
||||||
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "RaggedCross"
|
name: "RaggedCross"
|
||||||
argspec: "args=[\'ragged_values\', \'ragged_row_splits\', \'sparse_indices\', \'sparse_values\', \'sparse_shape\', \'dense_inputs\', \'input_order\', \'hashed_output\', \'num_buckets\', \'hash_key\', \'out_values_type\', \'out_row_splits_type\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
argspec: "args=[\'ragged_values\', \'ragged_row_splits\', \'sparse_indices\', \'sparse_values\', \'sparse_shape\', \'dense_inputs\', \'input_order\', \'hashed_output\', \'num_buckets\', \'hash_key\', \'out_values_type\', \'out_row_splits_type\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
@ -4072,6 +4080,10 @@ tf_module {
|
|||||||
name: "SparseConditionalAccumulator"
|
name: "SparseConditionalAccumulator"
|
||||||
argspec: "args=[\'dtype\', \'shape\', \'container\', \'shared_name\', \'reduction_type\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'\', \'MEAN\', \'None\'], "
|
argspec: "args=[\'dtype\', \'shape\', \'container\', \'shared_name\', \'reduction_type\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'\', \'MEAN\', \'None\'], "
|
||||||
}
|
}
|
||||||
|
member_method {
|
||||||
|
name: "SparseCountSparseOutput"
|
||||||
|
argspec: "args=[\'indices\', \'values\', \'dense_shape\', \'weights\', \'binary_count\', \'output_type\', \'minlength\', \'maxlength\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'-1\', \'None\'], "
|
||||||
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "SparseCross"
|
name: "SparseCross"
|
||||||
argspec: "args=[\'indices\', \'values\', \'shapes\', \'dense_inputs\', \'hashed_output\', \'num_buckets\', \'hash_key\', \'out_type\', \'internal_type\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
argspec: "args=[\'indices\', \'values\', \'shapes\', \'dense_inputs\', \'hashed_output\', \'num_buckets\', \'hash_key\', \'out_type\', \'internal_type\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
|
@ -12,6 +12,10 @@ tf_module {
|
|||||||
name: "add"
|
name: "add"
|
||||||
argspec: "args=[\'a\', \'b\', \'threshold\', \'thresh\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
|
argspec: "args=[\'a\', \'b\', \'threshold\', \'thresh\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
|
||||||
}
|
}
|
||||||
|
member_method {
|
||||||
|
name: "bincount"
|
||||||
|
argspec: "args=[\'values\', \'weights\', \'axis\', \'minlength\', \'maxlength\', \'binary_count\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\', \'None\', \'False\', \'None\'], "
|
||||||
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "concat"
|
name: "concat"
|
||||||
argspec: "args=[\'axis\', \'sp_inputs\', \'name\', \'expand_nonconcat_dim\', \'concat_dim\', \'expand_nonconcat_dims\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'None\'], "
|
argspec: "args=[\'axis\', \'sp_inputs\', \'name\', \'expand_nonconcat_dim\', \'concat_dim\', \'expand_nonconcat_dims\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'None\'], "
|
||||||
@ -38,7 +42,7 @@ tf_module {
|
|||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "from_dense"
|
name: "from_dense"
|
||||||
argspec: "args=[\'tensor\', 'name'], varargs=None, keywords=None, defaults=[\'None\'], "
|
argspec: "args=[\'tensor\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "mask"
|
name: "mask"
|
||||||
|
@ -1076,6 +1076,10 @@ tf_module {
|
|||||||
name: "DeleteSessionTensor"
|
name: "DeleteSessionTensor"
|
||||||
argspec: "args=[\'handle\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
argspec: "args=[\'handle\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
}
|
}
|
||||||
|
member_method {
|
||||||
|
name: "DenseCountSparseOutput"
|
||||||
|
argspec: "args=[\'values\', \'weights\', \'binary_count\', \'output_type\', \'minlength\', \'maxlength\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'-1\', \'None\'], "
|
||||||
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "DenseToCSRSparseMatrix"
|
name: "DenseToCSRSparseMatrix"
|
||||||
argspec: "args=[\'dense_input\', \'indices\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
argspec: "args=[\'dense_input\', \'indices\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
@ -3060,6 +3064,10 @@ tf_module {
|
|||||||
name: "RGBToHSV"
|
name: "RGBToHSV"
|
||||||
argspec: "args=[\'images\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
argspec: "args=[\'images\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
}
|
}
|
||||||
|
member_method {
|
||||||
|
name: "RaggedCountSparseOutput"
|
||||||
|
argspec: "args=[\'splits\', \'values\', \'weights\', \'binary_count\', \'output_type\', \'minlength\', \'maxlength\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'-1\', \'None\'], "
|
||||||
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "RaggedCross"
|
name: "RaggedCross"
|
||||||
argspec: "args=[\'ragged_values\', \'ragged_row_splits\', \'sparse_indices\', \'sparse_values\', \'sparse_shape\', \'dense_inputs\', \'input_order\', \'hashed_output\', \'num_buckets\', \'hash_key\', \'out_values_type\', \'out_row_splits_type\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
argspec: "args=[\'ragged_values\', \'ragged_row_splits\', \'sparse_indices\', \'sparse_values\', \'sparse_shape\', \'dense_inputs\', \'input_order\', \'hashed_output\', \'num_buckets\', \'hash_key\', \'out_values_type\', \'out_row_splits_type\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
@ -4072,6 +4080,10 @@ tf_module {
|
|||||||
name: "SparseConditionalAccumulator"
|
name: "SparseConditionalAccumulator"
|
||||||
argspec: "args=[\'dtype\', \'shape\', \'container\', \'shared_name\', \'reduction_type\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'\', \'MEAN\', \'None\'], "
|
argspec: "args=[\'dtype\', \'shape\', \'container\', \'shared_name\', \'reduction_type\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'\', \'MEAN\', \'None\'], "
|
||||||
}
|
}
|
||||||
|
member_method {
|
||||||
|
name: "SparseCountSparseOutput"
|
||||||
|
argspec: "args=[\'indices\', \'values\', \'dense_shape\', \'weights\', \'binary_count\', \'output_type\', \'minlength\', \'maxlength\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'-1\', \'None\'], "
|
||||||
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "SparseCross"
|
name: "SparseCross"
|
||||||
argspec: "args=[\'indices\', \'values\', \'shapes\', \'dense_inputs\', \'hashed_output\', \'num_buckets\', \'hash_key\', \'out_type\', \'internal_type\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
argspec: "args=[\'indices\', \'values\', \'shapes\', \'dense_inputs\', \'hashed_output\', \'num_buckets\', \'hash_key\', \'out_type\', \'internal_type\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
|
@ -8,6 +8,10 @@ tf_module {
|
|||||||
name: "add"
|
name: "add"
|
||||||
argspec: "args=[\'a\', \'b\', \'threshold\'], varargs=None, keywords=None, defaults=[\'0\'], "
|
argspec: "args=[\'a\', \'b\', \'threshold\'], varargs=None, keywords=None, defaults=[\'0\'], "
|
||||||
}
|
}
|
||||||
|
member_method {
|
||||||
|
name: "bincount"
|
||||||
|
argspec: "args=[\'values\', \'weights\', \'axis\', \'minlength\', \'maxlength\', \'binary_count\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\', \'None\', \'False\', \'None\'], "
|
||||||
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "concat"
|
name: "concat"
|
||||||
argspec: "args=[\'axis\', \'sp_inputs\', \'expand_nonconcat_dims\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
|
argspec: "args=[\'axis\', \'sp_inputs\', \'expand_nonconcat_dims\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
|
||||||
@ -34,7 +38,7 @@ tf_module {
|
|||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "from_dense"
|
name: "from_dense"
|
||||||
argspec: "args=[\'tensor\', 'name'], varargs=None, keywords=None, defaults=[\'None\'], "
|
argspec: "args=[\'tensor\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "mask"
|
name: "mask"
|
||||||
|
Loading…
x
Reference in New Issue
Block a user