Add tf.sparse.bincount op.

PiperOrigin-RevId: 309334052
Change-Id: Ic00b2b66deb467dd901e0e7eaa7152e6d1dc18b0
This commit is contained in:
A. Unique TensorFlower 2020-04-30 17:47:35 -07:00 committed by TensorFlower Gardener
parent 387429dd3b
commit b7c40bc3cc
16 changed files with 1519 additions and 3 deletions

View File

@ -619,6 +619,7 @@ tf_gen_op_libs(
"clustering_ops",
"collective_ops",
"control_flow_ops",
"count_ops",
"ctc_ops",
"data_flow_ops",
"dataset_ops",
@ -847,6 +848,7 @@ cc_library(
":clustering_ops_op_lib",
":collective_ops_op_lib",
":control_flow_ops_op_lib",
":count_ops_op_lib",
":ctc_ops_op_lib",
":cudnn_rnn_ops_op_lib",
":data_flow_ops_op_lib",
@ -1006,6 +1008,7 @@ cc_library(
"//tensorflow/core/kernels:collective_ops",
"//tensorflow/core/kernels:constant_op",
"//tensorflow/core/kernels:control_flow_ops",
"//tensorflow/core/kernels:count_ops",
"//tensorflow/core/kernels:ctc_ops",
"//tensorflow/core/kernels:data_flow",
"//tensorflow/core/kernels:decode_proto_op",

View File

@ -0,0 +1,68 @@
op {
graph_op_name: "DenseCountSparseOutput"
visibility: HIDDEN
in_arg {
name: "values"
description: <<END
int32 or int64; Tensor containing data to count.
END
}
in_arg {
name: "weights"
description: <<END
float32; Optional rank 1 Tensor (shape=[max_values]) with weights for each count value.
END
}
out_arg {
name: "output_indices"
description: <<END
int64; indices tensor for the resulting sparse tensor object.
END
}
out_arg {
name: "output_values"
description: <<END
int64 or float32; values tensor for the resulting sparse tensor object.
END
}
out_arg {
name: "output_dense_shape"
description: <<END
int64; shape tensor for the resulting sparse tensor object.
END
}
attr {
name: "T"
description: <<END
dtype; dtype of the input values tensor.
END
}
attr {
name: "minlength"
description: <<END
int32; minimum value to count. Can be set to -1 for no minimum.
END
}
attr {
name: "maxlength"
description: <<END
int32; maximum value to count. Can be set to -1 for no maximum.
END
}
attr {
name: "binary_count"
description: <<END
bool; whether to output the number of occurrences of each value or 1.
END
}
attr {
name: "output_type"
description: <<END
dtype; dtype of the output values tensor.
END
}
summary: "Performs sparse-output bin counting for a tf.tensor input."
description: <<END
Counts the number of times each value occurs in the input.
END
}

View File

@ -0,0 +1,74 @@
op {
graph_op_name: "RaggedCountSparseOutput"
visibility: HIDDEN
in_arg {
name: "splits"
description: <<END
int64; Tensor containing the row splits of the ragged tensor to count.
END
}
in_arg {
name: "values"
description: <<END
int32 or int64; Tensor containing values of the sparse tensor to count.
END
}
in_arg {
name: "weights"
description: <<END
float32; Optional rank 1 Tensor (shape=[max_values]) with weights for each count value.
END
}
out_arg {
name: "output_indices"
description: <<END
int64; indices tensor for the resulting sparse tensor object.
END
}
out_arg {
name: "output_values"
description: <<END
int64 or float32; values tensor for the resulting sparse tensor object.
END
}
out_arg {
name: "output_dense_shape"
description: <<END
int64; shape tensor for the resulting sparse tensor object.
END
}
attr {
name: "T"
description: <<END
dtype; dtype of the input values tensor.
END
}
attr {
name: "minlength"
description: <<END
int32; minimum value to count. Can be set to -1 for no minimum.
END
}
attr {
name: "maxlength"
description: <<END
int32; maximum value to count. Can be set to -1 for no maximum.
END
}
attr {
name: "binary_count"
description: <<END
bool; whether to output the number of occurrences of each value or 1.
END
}
attr {
name: "output_type"
description: <<END
dtype; dtype of the output values tensor.
END
}
summary: "Performs sparse-output bin counting for a ragged tensor input."
description: <<END
Counts the number of times each value occurs in the input.
END
}

View File

@ -0,0 +1,80 @@
op {
graph_op_name: "SparseCountSparseOutput"
visibility: HIDDEN
in_arg {
name: "indices"
description: <<END
int64; Tensor containing the indices of the sparse tensor to count.
END
}
in_arg {
name: "values"
description: <<END
int32 or int64; Tensor containing values of the sparse tensor to count.
END
}
in_arg {
name: "dense_shape"
description: <<END
int64; Tensor containing the dense shape of the sparse tensor to count.
END
}
in_arg {
name: "weights"
description: <<END
float32; Optional rank 1 Tensor (shape=[max_values]) with weights for each count value.
END
}
out_arg {
name: "output_indices"
description: <<END
int64; indices tensor for the resulting sparse tensor object.
END
}
out_arg {
name: "output_values"
description: <<END
int64 or float32; values tensor for the resulting sparse tensor object.
END
}
out_arg {
name: "output_dense_shape"
description: <<END
int64; shape tensor for the resulting sparse tensor object.
END
}
attr {
name: "T"
description: <<END
dtype; dtype of the input values tensor.
END
}
attr {
name: "minlength"
description: <<END
int32; minimum value to count. Can be set to -1 for no minimum.
END
}
attr {
name: "maxlength"
description: <<END
int32; maximum value to count. Can be set to -1 for no maximum.
END
}
attr {
name: "binary_count"
description: <<END
bool; whether to output the number of occurrences of each value or 1.
END
}
attr {
name: "output_type"
description: <<END
dtype; dtype of the output values tensor.
END
}
summary: "Performs sparse-output bin counting for a sparse tensor input."
description: <<END
Counts the number of times each value occurs in the input.
END
}

View File

@ -5668,6 +5668,32 @@ tf_kernel_library(
deps = STATE_DEPS,
)
tf_kernel_library(
name = "count_ops",
prefix = "count_ops",
deps = STATE_DEPS + [
"@com_google_absl//absl/container:flat_hash_map",
"//tensorflow/core/framework:op_requires",
],
)
tf_cc_test(
name = "count_ops_test",
size = "small",
srcs = ["count_ops_test.cc"],
deps = [
":count_ops",
":ops_testutil",
":ops_util",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
],
)
tf_kernel_library(
name = "scatter_nd_op",
srcs = [

View File

@ -0,0 +1,358 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "absl/container/flat_hash_map.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/op_requires.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
using BatchedIntMap = std::vector<absl::flat_hash_map<int64, int64>>;
namespace {
// TODO(momernick): Extend this function to work with outputs of rank > 2.
Status OutputSparse(const BatchedIntMap& per_batch_counts, int num_values,
bool is_1d, OpKernelContext* context) {
int total_values = 0;
int num_batches = per_batch_counts.size();
for (const auto& per_batch_count : per_batch_counts) {
total_values += per_batch_count.size();
}
Tensor* indices;
int inner_dim = is_1d ? 1 : 2;
TF_RETURN_IF_ERROR(context->allocate_output(
0, TensorShape({total_values, inner_dim}), &indices));
Tensor* values;
TF_RETURN_IF_ERROR(
context->allocate_output(1, TensorShape({total_values}), &values));
auto output_indices = indices->matrix<int64>();
auto output_values = values->flat<int64>();
int64 value_loc = 0;
for (int b = 0; b < num_batches; ++b) {
const auto& per_batch_count = per_batch_counts[b];
std::vector<std::pair<int, int>> pairs(per_batch_count.begin(),
per_batch_count.end());
std::sort(pairs.begin(), pairs.end());
for (const auto& x : pairs) {
if (is_1d) {
output_indices(value_loc, 0) = x.first;
} else {
output_indices(value_loc, 0) = b;
output_indices(value_loc, 1) = x.first;
}
output_values(value_loc) = x.second;
++value_loc;
}
}
Tensor* dense_shape;
if (is_1d) {
TF_RETURN_IF_ERROR(
context->allocate_output(2, TensorShape({1}), &dense_shape));
dense_shape->flat<int64>().data()[0] = num_values;
} else {
TF_RETURN_IF_ERROR(
context->allocate_output(2, TensorShape({2}), &dense_shape));
dense_shape->flat<int64>().data()[0] = num_batches;
dense_shape->flat<int64>().data()[1] = num_values;
}
return Status::OK();
}
Status OutputWeightedSparse(const BatchedIntMap& per_batch_counts,
int num_values, const Tensor& weights, bool is_1d,
OpKernelContext* context) {
if (!TensorShapeUtils::IsVector(weights.shape())) {
return errors::InvalidArgument(
"Weights must be a 1-dimensional tensor. Got: ",
weights.shape().DebugString());
}
if (num_values > weights.dim_size(0)) {
return errors::InvalidArgument("The maximum array value was ", num_values,
", but the weight array has size ",
weights.shape().DebugString());
}
auto weight_values = weights.flat<float>();
int total_values = 0;
int num_batches = per_batch_counts.size();
for (const auto& per_batch_count : per_batch_counts) {
total_values += per_batch_count.size();
}
Tensor* indices;
int inner_dim = is_1d ? 1 : 2;
TF_RETURN_IF_ERROR(context->allocate_output(
0, TensorShape({total_values, inner_dim}), &indices));
Tensor* values;
TF_RETURN_IF_ERROR(
context->allocate_output(1, TensorShape({total_values}), &values));
auto output_indices = indices->matrix<int64>();
auto output_values = values->flat<float>();
int64 value_loc = 0;
for (int b = 0; b < num_batches; ++b) {
const auto& per_batch_count = per_batch_counts[b];
std::vector<std::pair<int, int>> pairs(per_batch_count.begin(),
per_batch_count.end());
std::sort(pairs.begin(), pairs.end());
for (const auto& x : pairs) {
if (is_1d) {
output_indices(value_loc, 0) = x.first;
} else {
output_indices(value_loc, 0) = b;
output_indices(value_loc, 1) = x.first;
}
output_values(value_loc) = x.second * weight_values(x.first);
++value_loc;
}
}
Tensor* dense_shape;
if (is_1d) {
TF_RETURN_IF_ERROR(
context->allocate_output(2, TensorShape({1}), &dense_shape));
dense_shape->flat<int64>().data()[0] = num_values;
} else {
TF_RETURN_IF_ERROR(
context->allocate_output(2, TensorShape({2}), &dense_shape));
dense_shape->flat<int64>().data()[0] = num_batches;
dense_shape->flat<int64>().data()[1] = num_values;
}
return Status::OK();
}
template <class T>
T GetOutputSize(T max_seen, T max_length, T min_length) {
return max_length > 0 ? max_length : std::max((max_seen + 1), min_length);
}
} // namespace
template <class T>
class DenseCount : public OpKernel {
public:
explicit DenseCount(OpKernelConstruction* context) : OpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("minlength", &minlength_));
OP_REQUIRES_OK(context, context->GetAttr("maxlength", &maxlength_));
OP_REQUIRES_OK(context, context->GetAttr("binary_count", &binary_count_));
}
void Compute(OpKernelContext* context) override {
const Tensor& data = context->input(0);
const Tensor& weights = context->input(1);
bool use_weights = weights.NumElements() > 0;
OP_REQUIRES(context,
TensorShapeUtils::IsVector(data.shape()) ||
TensorShapeUtils::IsMatrix(data.shape()),
errors::InvalidArgument(
"Input must be a 1 or 2-dimensional tensor. Got: ",
data.shape().DebugString()));
bool is_1d = TensorShapeUtils::IsVector(data.shape());
int negative_valued_axis = -1;
int num_batch_dimensions = (data.shape().dims() + negative_valued_axis);
int num_batch_elements = 1;
for (int i = 0; i < num_batch_dimensions; ++i) {
num_batch_elements *= data.shape().dim_size(i);
}
int num_value_elements = data.shape().num_elements() / num_batch_elements;
auto per_batch_counts = BatchedIntMap(num_batch_elements);
T max_value = 0;
const auto data_values = data.flat<T>();
int i = 0;
for (int b = 0; b < num_batch_elements; ++b) {
for (int v = 0; v < num_value_elements; ++v) {
const auto& value = data_values(i);
if (value >= 0 && (maxlength_ <= 0 || value < maxlength_)) {
if (binary_count_) {
(per_batch_counts[b])[value] = 1;
} else {
(per_batch_counts[b])[value]++;
}
if (value > max_value) {
max_value = value;
}
}
++i;
}
}
T num_output_values = GetOutputSize<T>(max_value, maxlength_, minlength_);
if (use_weights) {
OP_REQUIRES_OK(context,
OutputWeightedSparse(per_batch_counts, num_output_values,
weights, is_1d, context));
} else {
OP_REQUIRES_OK(context, OutputSparse(per_batch_counts, num_output_values,
is_1d, context));
}
}
private:
T minlength_;
T maxlength_;
bool binary_count_;
};
template <class T>
class SparseCount : public OpKernel {
public:
explicit SparseCount(OpKernelConstruction* context) : OpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("minlength", &minlength_));
OP_REQUIRES_OK(context, context->GetAttr("maxlength", &maxlength_));
OP_REQUIRES_OK(context, context->GetAttr("binary_count", &binary_count_));
}
void Compute(OpKernelContext* context) override {
const Tensor& indices = context->input(0);
const Tensor& values = context->input(1);
const Tensor& shape = context->input(2);
const Tensor& weights = context->input(3);
bool use_weights = weights.NumElements() > 0;
bool is_1d = shape.NumElements() == 1;
const auto indices_values = indices.matrix<int64>();
const auto values_values = values.flat<T>();
int num_batches = is_1d ? 1 : shape.flat<int64>()(0);
int num_values = values.NumElements();
auto per_batch_counts = BatchedIntMap(num_batches);
T max_value = 0;
for (int idx = 0; idx < num_values; ++idx) {
int batch = is_1d ? 0 : indices_values(idx, 0);
const auto& value = values_values(idx);
if (value >= 0 && (maxlength_ <= 0 || value < maxlength_)) {
if (binary_count_) {
(per_batch_counts[batch])[value] = 1;
} else {
(per_batch_counts[batch])[value]++;
}
if (value > max_value) {
max_value = value;
}
}
}
T num_output_values = GetOutputSize<T>(max_value, maxlength_, minlength_);
if (use_weights) {
OP_REQUIRES_OK(context,
OutputWeightedSparse(per_batch_counts, num_output_values,
weights, is_1d, context));
} else {
OP_REQUIRES_OK(context, OutputSparse(per_batch_counts, num_output_values,
is_1d, context));
}
}
private:
T minlength_;
T maxlength_;
bool binary_count_;
};
template <class T>
class RaggedCount : public OpKernel {
public:
explicit RaggedCount(OpKernelConstruction* context) : OpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("minlength", &minlength_));
OP_REQUIRES_OK(context, context->GetAttr("maxlength", &maxlength_));
OP_REQUIRES_OK(context, context->GetAttr("binary_count", &binary_count_));
}
void Compute(OpKernelContext* context) override {
const Tensor& splits = context->input(0);
const Tensor& values = context->input(1);
const Tensor& weights = context->input(2);
bool use_weights = weights.NumElements() > 0;
const auto splits_values = splits.flat<int64>();
const auto values_values = values.flat<T>();
int num_batches = splits.NumElements() - 1;
int num_values = values.NumElements();
auto per_batch_counts = BatchedIntMap(num_batches);
T max_value = 0;
int batch_idx = 0;
for (int idx = 0; idx < num_values; ++idx) {
while (idx >= splits_values(batch_idx)) {
batch_idx++;
}
const auto& value = values_values(idx);
if (value >= 0 && (maxlength_ <= 0 || value < maxlength_)) {
if (binary_count_) {
(per_batch_counts[batch_idx - 1])[value] = 1;
} else {
(per_batch_counts[batch_idx - 1])[value]++;
}
if (value > max_value) {
max_value = value;
}
}
}
T num_output_values = GetOutputSize<T>(max_value, maxlength_, minlength_);
if (use_weights) {
OP_REQUIRES_OK(context,
OutputWeightedSparse(per_batch_counts, num_output_values,
weights, false, context));
} else {
OP_REQUIRES_OK(context, OutputSparse(per_batch_counts, num_output_values,
false, context));
}
}
private:
T minlength_;
T maxlength_;
bool binary_count_;
};
#define REGISTER(TYPE) \
\
REGISTER_KERNEL_BUILDER(Name("DenseCountSparseOutput") \
.TypeConstraint<TYPE>("T") \
.Device(DEVICE_CPU), \
DenseCount<TYPE>) \
\
REGISTER_KERNEL_BUILDER(Name("SparseCountSparseOutput") \
.TypeConstraint<TYPE>("T") \
.Device(DEVICE_CPU), \
SparseCount<TYPE>) \
\
REGISTER_KERNEL_BUILDER(Name("RaggedCountSparseOutput") \
.TypeConstraint<TYPE>("T") \
.Device(DEVICE_CPU), \
RaggedCount<TYPE>)
REGISTER(int32);
REGISTER(int64);
#undef REGISTER
} // namespace tensorflow

View File

@ -0,0 +1,47 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/fake_input.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/shape_inference_testutil.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/kernels/ops_testutil.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace {
TEST_F(OpsTestBase, DenseCountSparseOutputShapeFn) {
ShapeInferenceTestOp op("DenseCountSparseOutput");
INFER_OK(op, "[?];?", "[?,1];[?];[1]");
INFER_OK(op, "[?,?];?", "[?,2];[?];[2]");
}
TEST_F(OpsTestBase, SparseCountSparseOutputShapeFn) {
ShapeInferenceTestOp op("SparseCountSparseOutput");
INFER_OK(op, "[?,1];?;?;?", "[?,d0_1];[?];[d0_1]");
INFER_OK(op, "[?,2];?;?;?", "[?,d0_1];[?];[d0_1]");
}
TEST_F(OpsTestBase, RaggedCountSparseOutputShapeFn) {
ShapeInferenceTestOp op("RaggedCountSparseOutput");
INFER_OK(op, "?;[?];?", "[?,2];[?];[2]");
}
} // namespace
} // namespace tensorflow

View File

@ -0,0 +1,97 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
namespace tensorflow {
using shape_inference::DimensionHandle;
using shape_inference::InferenceContext;
Status DenseCountSparseOutputShapeFn(InferenceContext *c) {
int32 rank = c->Rank(c->input(0));
DimensionHandle nvals = c->UnknownDim();
c->set_output(0, c->Matrix(nvals, rank)); // out.indices
c->set_output(1, c->Vector(nvals)); // out.values
c->set_output(2, c->Vector(rank)); // out.dense_shape
return Status::OK();
}
Status SparseCountSparseOutputShapeFn(InferenceContext *c) {
DimensionHandle rank = c->Dim(c->input(0), 1);
DimensionHandle nvals = c->UnknownDim();
c->set_output(0, c->Matrix(nvals, rank)); // out.indices
c->set_output(1, c->Vector(nvals)); // out.values
c->set_output(2, c->Vector(rank)); // out.dense_shape
return Status::OK();
}
Status RaggedCountSparseOutputShapeFn(InferenceContext *c) {
int32 rank = c->Rank(c->input(1));
if (rank != c->kUnknownRank) {
++rank; // Add the ragged dimension
}
DimensionHandle nvals = c->UnknownDim();
c->set_output(0, c->Matrix(nvals, rank)); // out.indices
c->set_output(1, c->Vector(nvals)); // out.values
c->set_output(2, c->Vector(rank)); // out.dense_shape
return Status::OK();
}
REGISTER_OP("DenseCountSparseOutput")
.Input("values: T")
.Input("weights: float")
.Attr("T: {int32, int64}")
.Attr("minlength: int >= -1 = -1")
.Attr("maxlength: int >= -1 = -1")
.Attr("binary_count: bool")
.Attr("output_type: {int64, float}")
.SetShapeFn(DenseCountSparseOutputShapeFn)
.Output("output_indices: int64")
.Output("output_values: output_type")
.Output("output_dense_shape: int64");
REGISTER_OP("SparseCountSparseOutput")
.Input("indices: int64")
.Input("values: T")
.Input("dense_shape: int64")
.Input("weights: float")
.Attr("T: {int32, int64}")
.Attr("minlength: int >= -1 = -1")
.Attr("maxlength: int >= -1 = -1")
.Attr("binary_count: bool")
.Attr("output_type: {int64, float}")
.SetShapeFn(SparseCountSparseOutputShapeFn)
.Output("output_indices: int64")
.Output("output_values: output_type")
.Output("output_dense_shape: int64");
REGISTER_OP("RaggedCountSparseOutput")
.Input("splits: int64")
.Input("values: T")
.Input("weights: float")
.Attr("T: {int32, int64}")
.Attr("minlength: int >= -1 = -1")
.Attr("maxlength: int >= -1 = -1")
.Attr("binary_count: bool")
.Attr("output_type: {int64, float}")
.SetShapeFn(RaggedCountSparseOutputShapeFn)
.Output("output_indices: int64")
.Output("output_values: output_type")
.Output("output_dense_shape: int64");
} // namespace tensorflow

View File

@ -1,4 +1,3 @@
# Python support for TensorFlow.
#
# Public targets:
# ":platform" - Low-level and platform-specific Python code.
@ -135,6 +134,7 @@ py_library(
":_pywrap_utils",
":array_ops",
":audio_ops_gen",
":bincount",
":bitwise_ops",
":boosted_trees_ops",
":check_ops",
@ -2898,6 +2898,11 @@ tf_gen_op_wrapper_private_py(
],
)
tf_gen_op_wrapper_private_py(
name = "count_ops_gen",
visibility = ["//learning/brain/python/ops:__pkg__"],
)
tf_gen_op_wrapper_private_py(
name = "parsing_ops_gen",
visibility = ["//learning/brain/python/ops:__pkg__"],
@ -3463,6 +3468,28 @@ py_library(
],
)
py_library(
name = "bincount",
srcs = ["ops/bincount.py"],
srcs_version = "PY2AND3",
deps = [
":count_ops_gen",
":framework",
":framework_for_generated_wrappers",
],
)
tf_py_test(
name = "bincount_test",
size = "small",
srcs = ["ops/bincount_test.py"],
python_version = "PY3",
deps = [
":bincount",
":platform_test",
],
)
py_library(
name = "ctc_ops",
srcs = ["ops/ctc_ops.py"],

View File

@ -85,6 +85,7 @@ from tensorflow.python import keras
from tensorflow.python.feature_column import feature_column_lib as feature_column
from tensorflow.python.layers import layers
from tensorflow.python.module import module
from tensorflow.python.ops import bincount
from tensorflow.python.ops import bitwise_ops as bitwise
from tensorflow.python.ops import gradient_checker_v2
from tensorflow.python.ops import image_ops as image

View File

@ -0,0 +1,199 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# maxlengthations under the License.
# ==============================================================================
"""tf.sparse.bincount ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_count_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.util.tf_export import tf_export
@tf_export("sparse.bincount")
def sparse_bincount(values,
weights=None,
axis=0,
minlength=None,
maxlength=None,
binary_count=False,
name=None):
"""Count the number of times an integer value appears in a tensor.
This op takes an N-dimensional `Tensor`, `RaggedTensor`, or `SparseTensor`,
and returns an N-dimensional int64 SparseTensor where element
`[i0...i[axis], j]` contains the number of times the value `j` appears in
slice `[i0...i[axis], :]` of the input tensor. Currently, only N=0 and
N=-1 are supported.
Args:
values: A Tensor, RaggedTensor, or SparseTensor whose values should be
counted. These tensors must have a rank of 1 or 2.
weights: A 1-dimensional Tensor of weights. If specified, the input array is
weighted by the weight array, i.e. if a value `n` is found at position
`i`, `out[n]` will be increased by `weight[i]` instead of 1.
axis: The axis to slice over. Axes at and below `axis` will be flattened
before bin counting. Currently, only `0`, and `-1` are supported. If None,
all axes will be flattened (identical to passing `0`).
minlength: If given, skips `values` that are less than `minlength`, and
ensures that the output has a `dense_shape` of at least `minlength` in the
inner dimension.
maxlength: If given, skips `values` that are greater than or equal to
`maxlength`, and ensures that the output has a `dense_shape` of at most
`maxlength` in the inner dimension.
binary_count: Whether to do a binary count. When True, this op will return 1
for any value that exists instead of counting the number of occurrences.
name: A name for this op.
Returns:
A SparseTensor with `output.shape = values.shape[:axis] + [N]`, where `N` is
* `maxlength` (if set);
* `minlength` (if set, and `minlength > reduce_max(values)`);
* `0` (if `values` is empty);
* `reduce_max(values) + 1` otherwise.
Examples:
**Bin-counting every item in individual batches**
This example takes an input (which could be a Tensor, RaggedTensor, or
SparseTensor) and returns a SparseTensor where the value of (i,j) is the
number of times value j appears in batch i.
>>> data = [[10, 20, 30, 20], [11, 101, 11, 10001]]
>>> output = tf.sparse.bincount(data, axis=-1)
>>> print(output)
SparseTensor(indices=tf.Tensor(
[[ 0 10]
[ 0 20]
[ 0 30]
[ 1 11]
[ 1 101]
[ 1 10001]], shape=(6, 2), dtype=int64),
values=tf.Tensor([1 2 1 2 1 1], shape=(6,), dtype=int64),
dense_shape=tf.Tensor([ 2 10002], shape=(2,), dtype=int64))
**Bin-counting with defined output shape**
This example takes an input (which could be a Tensor, RaggedTensor, or
SparseTensor) and returns a SparseTensor where the value of (i,j) is the
number of times value j appears in batch i. However, all values of j
above 'maxlength' are ignored. The dense_shape of the output sparse tensor
is set to 'minlength'. Note that, while the input is identical to the
example above, the value '10001' in batch item 2 is dropped, and the
dense shape is [2, 500] instead of [2,10002] or [2, 102].
>>> minlength = maxlength = 500
>>> data = [[10, 20, 30, 20], [11, 101, 11, 10001]]
>>> output = tf.sparse.bincount(
... data, axis=-1, minlength=minlength, maxlength=maxlength)
>>> print(output)
SparseTensor(indices=tf.Tensor(
[[ 0 10]
[ 0 20]
[ 0 30]
[ 1 11]
[ 1 101]], shape=(5, 2), dtype=int64),
values=tf.Tensor([1 2 1 2 1], shape=(5,), dtype=int64),
dense_shape=tf.Tensor([ 2 500], shape=(2,), dtype=int64))
**Binary bin-counting**
This example takes an input (which could be a Tensor, RaggedTensor, or
SparseTensor) and returns a SparseTensor where (i,j) is 1 if the value j
appears in batch i at least once and is 0 otherwise. Note that, even though
some values (like 20 in batch 1 and 11 in batch 2) appear more than once,
the 'values' tensor is all 1s.
>>> dense = [[10, 20, 30, 20], [11, 101, 11, 10001]]
>>> output = tf.sparse.bincount(dense, binary_count=True, axis=-1)
>>> print(output)
SparseTensor(indices=tf.Tensor(
[[ 0 10]
[ 0 20]
[ 0 30]
[ 1 11]
[ 1 101]
[ 1 10001]], shape=(6, 2), dtype=int64),
values=tf.Tensor([1 1 1 1 1 1], shape=(6,), dtype=int64),
dense_shape=tf.Tensor([ 2 10002], shape=(2,), dtype=int64))
"""
with ops.name_scope(name, "count", [values, weights]):
if not isinstance(values, sparse_tensor.SparseTensor):
values = ragged_tensor.convert_to_tensor_or_ragged_tensor(
values, name="values")
if weights is not None and binary_count:
raise ValueError("binary_count and weights are mutually exclusive.")
if weights is None:
weights = []
output_type = dtypes.int64
else:
output_type = dtypes.float32
if axis is None:
axis = 0
if axis not in [0, -1]:
raise ValueError("Unsupported axis value %s. Only 0 and -1 are currently "
"supported." % axis)
minlength_value = minlength if minlength is not None else -1
maxlength_value = maxlength if maxlength is not None else -1
if axis == 0:
if isinstance(values,
(sparse_tensor.SparseTensor, ragged_tensor.RaggedTensor)):
values = values.values
else:
values = array_ops.reshape(values, [-1])
if isinstance(values, sparse_tensor.SparseTensor):
c_ind, c_val, c_shape = gen_count_ops.sparse_count_sparse_output(
values.indices,
values.values,
values.dense_shape,
weights=weights,
minlength=minlength_value,
maxlength=maxlength_value,
binary_count=binary_count,
output_type=output_type)
elif isinstance(values, ragged_tensor.RaggedTensor):
c_ind, c_val, c_shape = gen_count_ops.ragged_count_sparse_output(
values.row_splits,
values.values,
weights=weights,
minlength=minlength_value,
maxlength=maxlength_value,
binary_count=binary_count,
output_type=output_type)
else:
c_ind, c_val, c_shape = gen_count_ops.dense_count_sparse_output(
values,
weights=weights,
minlength=minlength_value,
maxlength=maxlength_value,
binary_count=binary_count,
output_type=output_type)
return sparse_tensor.SparseTensor(c_ind, c_val, c_shape)

View File

@ -0,0 +1,504 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# maxlengthations under the License.
# ==============================================================================
"""Tests for bincount ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
import numpy as np
from tensorflow.python.ops import bincount
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.platform import test
class TestSparseCount(test.TestCase, parameterized.TestCase):
@parameterized.named_parameters(
{
"testcase_name": "_no_maxlength",
"x": np.array([[3, 2, 1], [5, 4, 4]], dtype=np.int32),
"expected_indices": [[0, 1], [0, 2], [0, 3], [1, 4], [1, 5]],
"expected_values": [1, 1, 1, 2, 1],
"expected_shape": [2, 6]
}, {
"testcase_name": "_maxlength",
"x": np.array([[3, 2, 1, 7], [7, 0, 4, 4]], dtype=np.int32),
"maxlength": 7,
"expected_indices": [[0, 1], [0, 2], [0, 3], [1, 0], [1, 4]],
"expected_values": [1, 1, 1, 1, 2],
"expected_shape": [2, 7]
}, {
"testcase_name": "_minlength",
"x": np.array([[3, 2, 1, 7], [7, 0, 4, 4]], dtype=np.int32),
"minlength": 9,
"expected_indices": [[0, 1], [0, 2], [0, 3], [0, 7], [1, 0], [1, 4],
[1, 7]],
"expected_values": [1, 1, 1, 1, 1, 2, 1],
"expected_shape": [2, 9]
}, {
"testcase_name": "_minlength_larger_values",
"x": np.array([[3, 2, 1, 7], [7, 0, 4, 4]], dtype=np.int32),
"minlength": 3,
"expected_indices": [[0, 1], [0, 2], [0, 3], [0, 7], [1, 0], [1, 4],
[1, 7]],
"expected_values": [1, 1, 1, 1, 1, 2, 1],
"expected_shape": [2, 8]
}, {
"testcase_name": "_no_maxlength_binary",
"x": np.array([[3, 2, 1], [5, 4, 4]], dtype=np.int32),
"expected_indices": [[0, 1], [0, 2], [0, 3], [1, 4], [1, 5]],
"expected_values": [1, 1, 1, 1, 1],
"expected_shape": [2, 6],
"binary_count": True,
}, {
"testcase_name": "_maxlength_binary",
"x": np.array([[3, 2, 1, 7], [7, 0, 4, 4]], dtype=np.int32),
"maxlength": 7,
"expected_indices": [[0, 1], [0, 2], [0, 3], [1, 0], [1, 4]],
"expected_values": [1, 1, 1, 1, 1],
"expected_shape": [2, 7],
"binary_count": True,
}, {
"testcase_name": "_minlength_binary",
"x": np.array([[3, 2, 1, 7], [7, 0, 4, 4]], dtype=np.int32),
"minlength": 9,
"expected_indices": [[0, 1], [0, 2], [0, 3], [0, 7], [1, 0], [1, 4],
[1, 7]],
"expected_values": [1, 1, 1, 1, 1, 1, 1],
"expected_shape": [2, 9],
"binary_count": True,
}, {
"testcase_name": "_minlength_larger_values_binary",
"x": np.array([[3, 2, 1, 7], [7, 0, 4, 4]], dtype=np.int32),
"minlength": 3,
"expected_indices": [[0, 1], [0, 2], [0, 3], [0, 7], [1, 0], [1, 4],
[1, 7]],
"expected_values": [1, 1, 1, 1, 1, 1, 1],
"expected_shape": [2, 8],
"binary_count": True,
}, {
"testcase_name": "_no_maxlength_weights",
"x": np.array([[3, 2, 1], [5, 4, 4]], dtype=np.int32),
"expected_indices": [[0, 1], [0, 2], [0, 3], [1, 4], [1, 5]],
"expected_values": [1, 2, 3, 8, 5],
"expected_shape": [2, 6],
"weights": [0.5, 1, 2, 3, 4, 5]
}, {
"testcase_name": "_maxlength_weights",
"x": np.array([[3, 2, 1, 7], [7, 0, 4, 4]], dtype=np.int32),
"maxlength": 7,
"expected_indices": [[0, 1], [0, 2], [0, 3], [1, 0], [1, 4]],
"expected_values": [1, 2, 3, 0.5, 8],
"expected_shape": [2, 7],
"weights": [0.5, 1, 2, 3, 4, 5, 6]
}, {
"testcase_name": "_minlength_weights",
"x": np.array([[3, 2, 1, 7], [7, 0, 4, 4]], dtype=np.int32),
"minlength": 9,
"expected_indices": [[0, 1], [0, 2], [0, 3], [0, 7], [1, 0], [1, 4],
[1, 7]],
"expected_values": [1, 2, 3, 7, 0.5, 8, 7],
"expected_shape": [2, 9],
"weights": [0.5, 1, 2, 3, 4, 5, 6, 7, 8]
}, {
"testcase_name": "_minlength_larger_values_weights",
"x": np.array([[3, 2, 1, 7], [7, 0, 4, 4]], dtype=np.int32),
"minlength": 3,
"expected_indices": [[0, 1], [0, 2], [0, 3], [0, 7], [1, 0], [1, 4],
[1, 7]],
"expected_values": [1, 2, 3, 7, 0.5, 8, 7],
"expected_shape": [2, 8],
"weights": [0.5, 1, 2, 3, 4, 5, 6, 7, 8]
}, {
"testcase_name": "_1d",
"x": np.array([3, 2, 1, 1], dtype=np.int32),
"expected_indices": [[1], [2], [3]],
"expected_values": [2, 1, 1],
"expected_shape": [4]
}, {
"testcase_name": "_all_axes",
"x": np.array([[3, 2, 1], [5, 4, 4]], dtype=np.int32),
"expected_indices": [[1], [2], [3], [4], [5]],
"expected_values": [1, 1, 1, 2, 1],
"expected_shape": [6],
"axis": None
})
def test_dense_input(self,
x,
expected_indices,
expected_values,
expected_shape,
minlength=None,
maxlength=None,
binary_count=False,
weights=None,
axis=-1):
y = bincount.sparse_bincount(
x,
weights=weights,
minlength=minlength,
maxlength=maxlength,
binary_count=binary_count,
axis=axis)
self.assertAllEqual(expected_indices, y.indices)
self.assertAllEqual(expected_values, y.values)
self.assertAllEqual(expected_shape, y.dense_shape)
@parameterized.named_parameters(
{
"testcase_name":
"_no_maxlength",
"x":
np.array([[3, 0, 1, 0], [0, 0, 0, 0], [5, 0, 4, 4]],
dtype=np.int32),
"expected_indices": [[0, 1], [0, 3], [2, 4], [2, 5]],
"expected_values": [1, 1, 2, 1],
"expected_shape": [3, 6],
},
{
"testcase_name":
"_maxlength",
"x":
np.array([[3, 0, 1, 0], [7, 0, 0, 0], [5, 0, 4, 4]],
dtype=np.int32),
"expected_indices": [[0, 1], [0, 3], [2, 4], [2, 5]],
"expected_values": [1, 1, 2, 1],
"expected_shape": [3, 7],
"maxlength":
7,
},
{
"testcase_name":
"_minlength",
"x":
np.array([[3, 0, 1, 0], [7, 0, 0, 0], [5, 0, 4, 4]],
dtype=np.int32),
"expected_indices": [[0, 1], [0, 3], [1, 7], [2, 4], [2, 5]],
"expected_values": [1, 1, 1, 2, 1],
"expected_shape": [3, 9],
"minlength":
9,
},
{
"testcase_name":
"_minlength_larger_values",
"x":
np.array([[3, 0, 1, 0], [7, 0, 0, 0], [5, 0, 4, 4]],
dtype=np.int32),
"expected_indices": [[0, 1], [0, 3], [1, 7], [2, 4], [2, 5]],
"expected_values": [1, 1, 1, 2, 1],
"expected_shape": [3, 8],
"minlength":
3,
},
{
"testcase_name":
"_no_maxlength_binary",
"x":
np.array([[3, 0, 1, 0], [0, 0, 0, 0], [5, 0, 4, 4]],
dtype=np.int32),
"expected_indices": [[0, 1], [0, 3], [2, 4], [2, 5]],
"expected_values": [1, 1, 1, 1],
"expected_shape": [3, 6],
"binary_count":
True,
},
{
"testcase_name":
"_maxlength_binary",
"x":
np.array([[3, 0, 1, 0], [0, 0, 7, 0], [5, 0, 4, 4]],
dtype=np.int32),
"expected_indices": [[0, 1], [0, 3], [2, 4], [2, 5]],
"expected_values": [1, 1, 1, 1],
"expected_shape": [3, 7],
"maxlength":
7,
"binary_count":
True,
},
{
"testcase_name":
"_minlength_binary",
"x":
np.array([[3, 0, 1, 0], [7, 0, 0, 0], [5, 0, 4, 4]],
dtype=np.int32),
"expected_indices": [[0, 1], [0, 3], [1, 7], [2, 4], [2, 5]],
"expected_values": [1, 1, 1, 1, 1],
"expected_shape": [3, 9],
"minlength":
9,
"binary_count":
True,
},
{
"testcase_name":
"_minlength_larger_values_binary",
"x":
np.array([[3, 0, 1, 0], [7, 0, 0, 0], [5, 0, 4, 4]],
dtype=np.int32),
"expected_indices": [[0, 1], [0, 3], [1, 7], [2, 4], [2, 5]],
"expected_values": [1, 1, 1, 1, 1],
"expected_shape": [3, 8],
"minlength":
3,
"binary_count":
True,
},
{
"testcase_name":
"_no_maxlength_weights",
"x":
np.array([[3, 0, 1, 0], [0, 0, 0, 0], [5, 0, 4, 4]],
dtype=np.int32),
"expected_indices": [[0, 1], [0, 3], [2, 4], [2, 5]],
"expected_values": [1, 3, 8, 5],
"expected_shape": [3, 6],
"weights": [0.5, 1, 2, 3, 4, 5]
},
{
"testcase_name":
"_maxlength_weights",
"x":
np.array([[3, 0, 1, 0], [0, 0, 7, 0], [5, 0, 4, 4]],
dtype=np.int32),
"expected_indices": [[0, 1], [0, 3], [2, 4], [2, 5]],
"expected_values": [1, 3, 8, 5],
"expected_shape": [3, 7],
"maxlength":
7,
"weights": [0.5, 1, 2, 3, 4, 5, 6]
},
{
"testcase_name":
"_minlength_weights",
"x":
np.array([[3, 0, 1, 0], [7, 0, 0, 0], [5, 0, 4, 4]],
dtype=np.int32),
"expected_indices": [[0, 1], [0, 3], [1, 7], [2, 4], [2, 5]],
"expected_values": [1, 3, 7, 8, 5],
"expected_shape": [3, 9],
"minlength":
9,
"weights": [0.5, 1, 2, 3, 4, 5, 6, 7, 8]
},
{
"testcase_name":
"_minlength_larger_values_weights",
"x":
np.array([[3, 0, 1, 0], [7, 0, 0, 0], [5, 0, 4, 4]],
dtype=np.int32),
"expected_indices": [[0, 1], [0, 3], [1, 7], [2, 4], [2, 5]],
"expected_values": [1, 3, 7, 8, 5],
"expected_shape": [3, 8],
"minlength":
3,
"weights": [0.5, 1, 2, 3, 4, 5, 6, 7, 8]
},
{
"testcase_name": "_1d",
"x": np.array([3, 0, 1, 1], dtype=np.int32),
"expected_indices": [[1], [3]],
"expected_values": [2, 1],
"expected_shape": [4],
},
{
"testcase_name":
"_all_axes",
"x":
np.array([[3, 0, 1, 0], [0, 0, 0, 0], [5, 0, 4, 4]],
dtype=np.int32),
"expected_indices": [[1], [3], [4], [5]],
"expected_values": [1, 1, 2, 1],
"expected_shape": [6],
"axis":
None,
},
)
def test_sparse_input(self,
x,
expected_indices,
expected_values,
expected_shape,
maxlength=None,
minlength=None,
binary_count=False,
weights=None,
axis=-1):
x_sparse = sparse_ops.from_dense(x)
y = bincount.sparse_bincount(
x_sparse,
weights=weights,
minlength=minlength,
maxlength=maxlength,
binary_count=binary_count,
axis=axis)
self.assertAllEqual(expected_indices, y.indices)
self.assertAllEqual(expected_values, y.values)
self.assertAllEqual(expected_shape, y.dense_shape)
@parameterized.named_parameters(
{
"testcase_name": "_no_maxlength",
"x": [[], [], [3, 0, 1], [], [5, 0, 4, 4]],
"expected_indices": [[2, 0], [2, 1], [2, 3], [4, 0], [4, 4], [4, 5]],
"expected_values": [1, 1, 1, 1, 2, 1],
"expected_shape": [5, 6],
},
{
"testcase_name": "_maxlength",
"x": [[], [], [3, 0, 1], [7], [5, 0, 4, 4]],
"maxlength": 7,
"expected_indices": [[2, 0], [2, 1], [2, 3], [4, 0], [4, 4], [4, 5]],
"expected_values": [1, 1, 1, 1, 2, 1],
"expected_shape": [5, 7],
},
{
"testcase_name": "_minlength",
"x": [[], [], [3, 0, 1], [7], [5, 0, 4, 4]],
"minlength": 9,
"expected_indices": [[2, 0], [2, 1], [2, 3], [3, 7], [4, 0], [4, 4],
[4, 5]],
"expected_values": [1, 1, 1, 1, 1, 2, 1],
"expected_shape": [5, 9],
},
{
"testcase_name": "_minlength_larger_values",
"x": [[], [], [3, 0, 1], [7], [5, 0, 4, 4]],
"minlength": 3,
"expected_indices": [[2, 0], [2, 1], [2, 3], [3, 7], [4, 0], [4, 4],
[4, 5]],
"expected_values": [1, 1, 1, 1, 1, 2, 1],
"expected_shape": [5, 8],
},
{
"testcase_name": "_no_maxlength_binary",
"x": [[], [], [3, 0, 1], [], [5, 0, 4, 4]],
"expected_indices": [[2, 0], [2, 1], [2, 3], [4, 0], [4, 4], [4, 5]],
"expected_values": [1, 1, 1, 1, 1, 1],
"expected_shape": [5, 6],
"binary_count": True,
},
{
"testcase_name": "_maxlength_binary",
"x": [[], [], [3, 0, 1], [7], [5, 0, 4, 4]],
"maxlength": 7,
"expected_indices": [[2, 0], [2, 1], [2, 3], [4, 0], [4, 4], [4, 5]],
"expected_values": [1, 1, 1, 1, 1, 1],
"expected_shape": [5, 7],
"binary_count": True,
},
{
"testcase_name": "_minlength_binary",
"x": [[], [], [3, 0, 1], [7], [5, 0, 4, 4]],
"minlength": 9,
"expected_indices": [[2, 0], [2, 1], [2, 3], [3, 7], [4, 0], [4, 4],
[4, 5]],
"expected_values": [1, 1, 1, 1, 1, 1, 1],
"expected_shape": [5, 9],
"binary_count": True,
},
{
"testcase_name": "_minlength_larger_values_binary",
"x": [[], [], [3, 0, 1], [7], [5, 0, 4, 4]],
"minlength": 3,
"binary_count": True,
"expected_indices": [[2, 0], [2, 1], [2, 3], [3, 7], [4, 0], [4, 4],
[4, 5]],
"expected_values": [1, 1, 1, 1, 1, 1, 1],
"expected_shape": [5, 8],
},
{
"testcase_name": "_no_maxlength_weights",
"x": [[], [], [3, 0, 1], [], [5, 0, 4, 4]],
"expected_indices": [[2, 0], [2, 1], [2, 3], [4, 0], [4, 4], [4, 5]],
"expected_values": [0.5, 1, 3, 0.5, 8, 5],
"expected_shape": [5, 6],
"weights": [0.5, 1, 2, 3, 4, 5]
},
{
"testcase_name": "_maxlength_weights",
"x": [[], [], [3, 0, 1], [7], [5, 0, 4, 4]],
"maxlength": 7,
"expected_indices": [[2, 0], [2, 1], [2, 3], [4, 0], [4, 4], [4, 5]],
"expected_values": [0.5, 1, 3, 0.5, 8, 5],
"expected_shape": [5, 7],
"weights": [0.5, 1, 2, 3, 4, 5, 6]
},
{
"testcase_name": "_minlength_weights",
"x": [[], [], [3, 0, 1], [7], [5, 0, 4, 4]],
"minlength": 9,
"expected_indices": [[2, 0], [2, 1], [2, 3], [3, 7], [4, 0], [4, 4],
[4, 5]],
"expected_values": [0.5, 1, 3, 7, 0.5, 8, 5],
"expected_shape": [5, 9],
"weights": [0.5, 1, 2, 3, 4, 5, 6, 7, 8]
},
{
"testcase_name": "_minlength_larger_values_weights",
"x": [[], [], [3, 0, 1], [7], [5, 0, 4, 4]],
"minlength": 3,
"expected_indices": [[2, 0], [2, 1], [2, 3], [3, 7], [4, 0], [4, 4],
[4, 5]],
"expected_values": [0.5, 1, 3, 7, 0.5, 8, 5],
"expected_shape": [5, 8],
"weights": [0.5, 1, 2, 3, 4, 5, 6, 7, 8]
},
{
"testcase_name": "_1d",
"x": [3, 0, 1, 1],
"expected_indices": [[0], [1], [3]],
"expected_values": [1, 2, 1],
"expected_shape": [4],
},
{
"testcase_name": "_all_axes",
"x": [[], [], [3, 0, 1], [], [5, 0, 4, 4]],
"expected_indices": [[0], [1], [3], [4], [5]],
"expected_values": [2, 1, 1, 2, 1],
"expected_shape": [6],
"axis": None,
},
)
def test_ragged_input(self,
x,
expected_indices,
expected_values,
expected_shape,
maxlength=None,
minlength=None,
binary_count=False,
weights=None,
axis=-1):
x_ragged = ragged_factory_ops.constant(x)
y = bincount.sparse_bincount(
x_ragged,
weights=weights,
minlength=minlength,
maxlength=maxlength,
binary_count=binary_count,
axis=axis)
self.assertAllEqual(expected_indices, y.indices)
self.assertAllEqual(expected_values, y.values)
self.assertAllEqual(expected_shape, y.dense_shape)
if __name__ == "__main__":
test.main()

View File

@ -1076,6 +1076,10 @@ tf_module {
name: "DeleteSessionTensor"
argspec: "args=[\'handle\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "DenseCountSparseOutput"
argspec: "args=[\'values\', \'weights\', \'binary_count\', \'output_type\', \'minlength\', \'maxlength\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'-1\', \'None\'], "
}
member_method {
name: "DenseToCSRSparseMatrix"
argspec: "args=[\'dense_input\', \'indices\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
@ -3060,6 +3064,10 @@ tf_module {
name: "RGBToHSV"
argspec: "args=[\'images\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "RaggedCountSparseOutput"
argspec: "args=[\'splits\', \'values\', \'weights\', \'binary_count\', \'output_type\', \'minlength\', \'maxlength\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'-1\', \'None\'], "
}
member_method {
name: "RaggedCross"
argspec: "args=[\'ragged_values\', \'ragged_row_splits\', \'sparse_indices\', \'sparse_values\', \'sparse_shape\', \'dense_inputs\', \'input_order\', \'hashed_output\', \'num_buckets\', \'hash_key\', \'out_values_type\', \'out_row_splits_type\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
@ -4072,6 +4080,10 @@ tf_module {
name: "SparseConditionalAccumulator"
argspec: "args=[\'dtype\', \'shape\', \'container\', \'shared_name\', \'reduction_type\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'\', \'MEAN\', \'None\'], "
}
member_method {
name: "SparseCountSparseOutput"
argspec: "args=[\'indices\', \'values\', \'dense_shape\', \'weights\', \'binary_count\', \'output_type\', \'minlength\', \'maxlength\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'-1\', \'None\'], "
}
member_method {
name: "SparseCross"
argspec: "args=[\'indices\', \'values\', \'shapes\', \'dense_inputs\', \'hashed_output\', \'num_buckets\', \'hash_key\', \'out_type\', \'internal_type\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "

View File

@ -12,6 +12,10 @@ tf_module {
name: "add"
argspec: "args=[\'a\', \'b\', \'threshold\', \'thresh\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
}
member_method {
name: "bincount"
argspec: "args=[\'values\', \'weights\', \'axis\', \'minlength\', \'maxlength\', \'binary_count\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\', \'None\', \'False\', \'None\'], "
}
member_method {
name: "concat"
argspec: "args=[\'axis\', \'sp_inputs\', \'name\', \'expand_nonconcat_dim\', \'concat_dim\', \'expand_nonconcat_dims\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'None\'], "
@ -38,7 +42,7 @@ tf_module {
}
member_method {
name: "from_dense"
argspec: "args=[\'tensor\', 'name'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'tensor\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "mask"

View File

@ -1076,6 +1076,10 @@ tf_module {
name: "DeleteSessionTensor"
argspec: "args=[\'handle\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "DenseCountSparseOutput"
argspec: "args=[\'values\', \'weights\', \'binary_count\', \'output_type\', \'minlength\', \'maxlength\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'-1\', \'None\'], "
}
member_method {
name: "DenseToCSRSparseMatrix"
argspec: "args=[\'dense_input\', \'indices\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
@ -3060,6 +3064,10 @@ tf_module {
name: "RGBToHSV"
argspec: "args=[\'images\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "RaggedCountSparseOutput"
argspec: "args=[\'splits\', \'values\', \'weights\', \'binary_count\', \'output_type\', \'minlength\', \'maxlength\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'-1\', \'None\'], "
}
member_method {
name: "RaggedCross"
argspec: "args=[\'ragged_values\', \'ragged_row_splits\', \'sparse_indices\', \'sparse_values\', \'sparse_shape\', \'dense_inputs\', \'input_order\', \'hashed_output\', \'num_buckets\', \'hash_key\', \'out_values_type\', \'out_row_splits_type\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
@ -4072,6 +4080,10 @@ tf_module {
name: "SparseConditionalAccumulator"
argspec: "args=[\'dtype\', \'shape\', \'container\', \'shared_name\', \'reduction_type\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'\', \'MEAN\', \'None\'], "
}
member_method {
name: "SparseCountSparseOutput"
argspec: "args=[\'indices\', \'values\', \'dense_shape\', \'weights\', \'binary_count\', \'output_type\', \'minlength\', \'maxlength\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'-1\', \'None\'], "
}
member_method {
name: "SparseCross"
argspec: "args=[\'indices\', \'values\', \'shapes\', \'dense_inputs\', \'hashed_output\', \'num_buckets\', \'hash_key\', \'out_type\', \'internal_type\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "

View File

@ -8,6 +8,10 @@ tf_module {
name: "add"
argspec: "args=[\'a\', \'b\', \'threshold\'], varargs=None, keywords=None, defaults=[\'0\'], "
}
member_method {
name: "bincount"
argspec: "args=[\'values\', \'weights\', \'axis\', \'minlength\', \'maxlength\', \'binary_count\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\', \'None\', \'False\', \'None\'], "
}
member_method {
name: "concat"
argspec: "args=[\'axis\', \'sp_inputs\', \'expand_nonconcat_dims\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
@ -34,7 +38,7 @@ tf_module {
}
member_method {
name: "from_dense"
argspec: "args=[\'tensor\', 'name'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'tensor\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "mask"