Switch weights from per-value to per-input-item.

PiperOrigin-RevId: 311477582
Change-Id: I749c4edfcfd4dd3acd036a1d14b2c493b8d8bfc8
This commit is contained in:
A. Unique TensorFlower 2020-05-13 23:31:44 -07:00 committed by TensorFlower Gardener
parent e40aeb534e
commit 112288586d
11 changed files with 278 additions and 441 deletions

View File

@ -4,62 +4,61 @@ op {
in_arg {
name: "values"
description: <<END
Tensor containing data to count.
int32 or int64; Tensor containing data to count.
END
}
in_arg {
name: "weights"
description: <<END
A Tensor of the same shape as indices containing per-index weight values. May
also be the empty tensor if no weights are used.
float32; Optional rank 1 Tensor (shape=[max_values]) with weights for each count value.
END
}
out_arg {
name: "output_indices"
description: <<END
Indices tensor for the resulting sparse tensor object.
int64; indices tensor for the resulting sparse tensor object.
END
}
out_arg {
name: "output_values"
description: <<END
Values tensor for the resulting sparse tensor object.
int64 or float32; values tensor for the resulting sparse tensor object.
END
}
out_arg {
name: "output_dense_shape"
description: <<END
Shape tensor for the resulting sparse tensor object.
int64; shape tensor for the resulting sparse tensor object.
END
}
attr {
name: "T"
description: <<END
Dtype of the input values tensor.
dtype; dtype of the input values tensor.
END
}
attr {
name: "minlength"
description: <<END
Minimum value to count. Can be set to -1 for no minimum.
int32; minimum value to count. Can be set to -1 for no minimum.
END
}
attr {
name: "maxlength"
description: <<END
Maximum value to count. Can be set to -1 for no maximum.
int32; maximum value to count. Can be set to -1 for no maximum.
END
}
attr {
name: "binary_output"
name: "binary_count"
description: <<END
Whether to output the number of occurrences of each value or 1.
bool; whether to output the number of occurrences of each value or 1.
END
}
attr {
name: "output_type"
description: <<END
Dtype of the output values tensor.
dtype; dtype of the output values tensor.
END
}
summary: "Performs sparse-output bin counting for a tf.tensor input."

View File

@ -4,68 +4,67 @@ op {
in_arg {
name: "splits"
description: <<END
Tensor containing the row splits of the ragged tensor to count.
int64; Tensor containing the row splits of the ragged tensor to count.
END
}
in_arg {
name: "values"
description: <<END
Tensor containing values of the sparse tensor to count.
int32 or int64; Tensor containing values of the sparse tensor to count.
END
}
in_arg {
name: "weights"
description: <<END
A Tensor of the same shape as indices containing per-index weight values.
May also be the empty tensor if no weights are used.
float32; Optional rank 1 Tensor (shape=[max_values]) with weights for each count value.
END
}
out_arg {
name: "output_indices"
description: <<END
Indices tensor for the resulting sparse tensor object.
int64; indices tensor for the resulting sparse tensor object.
END
}
out_arg {
name: "output_values"
description: <<END
Values tensor for the resulting sparse tensor object.
int64 or float32; values tensor for the resulting sparse tensor object.
END
}
out_arg {
name: "output_dense_shape"
description: <<END
Shape tensor for the resulting sparse tensor object.
int64; shape tensor for the resulting sparse tensor object.
END
}
attr {
name: "T"
description: <<END
Dtype of the input values tensor.
dtype; dtype of the input values tensor.
END
}
attr {
name: "minlength"
description: <<END
Minimum value to count. Can be set to -1 for no minimum.
int32; minimum value to count. Can be set to -1 for no minimum.
END
}
attr {
name: "maxlength"
description: <<END
Maximum value to count. Can be set to -1 for no maximum.
int32; maximum value to count. Can be set to -1 for no maximum.
END
}
attr {
name: "binary_output"
name: "binary_count"
description: <<END
Whether to output the number of occurrences of each value or 1.
bool; whether to output the number of occurrences of each value or 1.
END
}
attr {
name: "output_type"
description: <<END
Dtype of the output values tensor.
dtype; dtype of the output values tensor.
END
}
summary: "Performs sparse-output bin counting for a ragged tensor input."

View File

@ -4,74 +4,73 @@ op {
in_arg {
name: "indices"
description: <<END
Tensor containing the indices of the sparse tensor to count.
int64; Tensor containing the indices of the sparse tensor to count.
END
}
in_arg {
name: "values"
description: <<END
Tensor containing values of the sparse tensor to count.
int32 or int64; Tensor containing values of the sparse tensor to count.
END
}
in_arg {
name: "dense_shape"
description: <<END
Tensor containing the dense shape of the sparse tensor to count.
int64; Tensor containing the dense shape of the sparse tensor to count.
END
}
in_arg {
name: "weights"
description: <<END
A Tensor of the same shape as indices containing per-index weight values.
May also be the empty tensor if no weights are used.
float32; Optional rank 1 Tensor (shape=[max_values]) with weights for each count value.
END
}
out_arg {
name: "output_indices"
description: <<END
Indices tensor for the resulting sparse tensor object.
int64; indices tensor for the resulting sparse tensor object.
END
}
out_arg {
name: "output_values"
description: <<END
Values tensor for the resulting sparse tensor object.
int64 or float32; values tensor for the resulting sparse tensor object.
END
}
out_arg {
name: "output_dense_shape"
description: <<END
Shape tensor for the resulting sparse tensor object.
int64; shape tensor for the resulting sparse tensor object.
END
}
attr {
name: "T"
description: <<END
Dtype of the input values tensor.
dtype; dtype of the input values tensor.
END
}
attr {
name: "minlength"
description: <<END
Minimum value to count. Can be set to -1 for no minimum.
int32; minimum value to count. Can be set to -1 for no minimum.
END
}
attr {
name: "maxlength"
description: <<END
Maximum value to count. Can be set to -1 for no maximum.
int32; maximum value to count. Can be set to -1 for no maximum.
END
}
attr {
name: "binary_output"
name: "binary_count"
description: <<END
Whether to output the number of occurrences of each value or 1.
bool; whether to output the number of occurrences of each value or 1.
END
}
attr {
name: "output_type"
description: <<END
Dtype of the output values tensor.
dtype; dtype of the output values tensor.
END
}
summary: "Performs sparse-output bin counting for a sparse tensor input."

View File

@ -16,20 +16,17 @@ limitations under the License.
#include "absl/container/flat_hash_map.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/op_requires.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
template <class T>
using BatchedMap = std::vector<absl::flat_hash_map<int64, T>>;
using BatchedIntMap = std::vector<absl::flat_hash_map<int64, int64>>;
namespace {
// TODO(momernick): Extend this function to work with outputs of rank > 2.
template <class T>
Status OutputSparse(const BatchedMap<T>& per_batch_counts, int num_values,
Status OutputSparse(const BatchedIntMap& per_batch_counts, int num_values,
bool is_1d, OpKernelContext* context) {
int total_values = 0;
int num_batches = per_batch_counts.size();
@ -47,11 +44,11 @@ Status OutputSparse(const BatchedMap<T>& per_batch_counts, int num_values,
context->allocate_output(1, TensorShape({total_values}), &values));
auto output_indices = indices->matrix<int64>();
auto output_values = values->flat<T>();
auto output_values = values->flat<int64>();
int64 value_loc = 0;
for (int b = 0; b < num_batches; ++b) {
const auto& per_batch_count = per_batch_counts[b];
std::vector<std::pair<int, T>> pairs(per_batch_count.begin(),
std::vector<std::pair<int, int>> pairs(per_batch_count.begin(),
per_batch_count.end());
std::sort(pairs.begin(), pairs.end());
for (const auto& x : pairs) {
@ -80,19 +77,85 @@ Status OutputSparse(const BatchedMap<T>& per_batch_counts, int num_values,
return Status::OK();
}
int GetOutputSize(int max_seen, int max_length, int min_length) {
Status OutputWeightedSparse(const BatchedIntMap& per_batch_counts,
int num_values, const Tensor& weights, bool is_1d,
OpKernelContext* context) {
if (!TensorShapeUtils::IsVector(weights.shape())) {
return errors::InvalidArgument(
"Weights must be a 1-dimensional tensor. Got: ",
weights.shape().DebugString());
}
if (num_values > weights.dim_size(0)) {
return errors::InvalidArgument("The maximum array value was ", num_values,
", but the weight array has size ",
weights.shape().DebugString());
}
auto weight_values = weights.flat<float>();
int total_values = 0;
int num_batches = per_batch_counts.size();
for (const auto& per_batch_count : per_batch_counts) {
total_values += per_batch_count.size();
}
Tensor* indices;
int inner_dim = is_1d ? 1 : 2;
TF_RETURN_IF_ERROR(context->allocate_output(
0, TensorShape({total_values, inner_dim}), &indices));
Tensor* values;
TF_RETURN_IF_ERROR(
context->allocate_output(1, TensorShape({total_values}), &values));
auto output_indices = indices->matrix<int64>();
auto output_values = values->flat<float>();
int64 value_loc = 0;
for (int b = 0; b < num_batches; ++b) {
const auto& per_batch_count = per_batch_counts[b];
std::vector<std::pair<int, int>> pairs(per_batch_count.begin(),
per_batch_count.end());
std::sort(pairs.begin(), pairs.end());
for (const auto& x : pairs) {
if (is_1d) {
output_indices(value_loc, 0) = x.first;
} else {
output_indices(value_loc, 0) = b;
output_indices(value_loc, 1) = x.first;
}
output_values(value_loc) = x.second * weight_values(x.first);
++value_loc;
}
}
Tensor* dense_shape;
if (is_1d) {
TF_RETURN_IF_ERROR(
context->allocate_output(2, TensorShape({1}), &dense_shape));
dense_shape->flat<int64>().data()[0] = num_values;
} else {
TF_RETURN_IF_ERROR(
context->allocate_output(2, TensorShape({2}), &dense_shape));
dense_shape->flat<int64>().data()[0] = num_batches;
dense_shape->flat<int64>().data()[1] = num_values;
}
return Status::OK();
}
template <class T>
T GetOutputSize(T max_seen, T max_length, T min_length) {
return max_length > 0 ? max_length : std::max((max_seen + 1), min_length);
}
} // namespace
template <class T, class W>
template <class T>
class DenseCount : public OpKernel {
public:
explicit DenseCount(OpKernelConstruction* context) : OpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("minlength", &minlength_));
OP_REQUIRES_OK(context, context->GetAttr("maxlength", &maxlength_));
OP_REQUIRES_OK(context, context->GetAttr("binary_output", &binary_output_));
OP_REQUIRES_OK(context, context->GetAttr("binary_count", &binary_count_));
}
void Compute(OpKernelContext* context) override {
@ -107,15 +170,6 @@ class DenseCount : public OpKernel {
"Input must be a 1 or 2-dimensional tensor. Got: ",
data.shape().DebugString()));
if (use_weights) {
OP_REQUIRES(
context, weights.shape() == data.shape(),
errors::InvalidArgument(
"Weights and data must have the same shape. Weight shape: ",
weights.shape().DebugString(),
"; data shape: ", data.shape().DebugString()));
}
bool is_1d = TensorShapeUtils::IsVector(data.shape());
int negative_valued_axis = -1;
int num_batch_dimensions = (data.shape().dims() + negative_valued_axis);
@ -125,23 +179,19 @@ class DenseCount : public OpKernel {
num_batch_elements *= data.shape().dim_size(i);
}
int num_value_elements = data.shape().num_elements() / num_batch_elements;
auto per_batch_counts = BatchedMap<W>(num_batch_elements);
auto per_batch_counts = BatchedIntMap(num_batch_elements);
T max_value = 0;
const auto data_values = data.flat<T>();
const auto weight_values = weights.flat<W>();
int i = 0;
for (int b = 0; b < num_batch_elements; ++b) {
for (int v = 0; v < num_value_elements; ++v) {
const auto& value = data_values(i);
if (value >= 0 && (maxlength_ <= 0 || value < maxlength_)) {
if (binary_output_) {
per_batch_counts[b][value] = 1;
} else if (use_weights) {
per_batch_counts[b][value] += weight_values(i);
if (binary_count_) {
(per_batch_counts[b])[value] = 1;
} else {
per_batch_counts[b][value]++;
(per_batch_counts[b])[value]++;
}
if (value > max_value) {
max_value = value;
@ -151,24 +201,30 @@ class DenseCount : public OpKernel {
}
}
int num_output_values = GetOutputSize(max_value, maxlength_, minlength_);
OP_REQUIRES_OK(context, OutputSparse<W>(per_batch_counts, num_output_values,
T num_output_values = GetOutputSize<T>(max_value, maxlength_, minlength_);
if (use_weights) {
OP_REQUIRES_OK(context,
OutputWeightedSparse(per_batch_counts, num_output_values,
weights, is_1d, context));
} else {
OP_REQUIRES_OK(context, OutputSparse(per_batch_counts, num_output_values,
is_1d, context));
}
}
private:
int maxlength_;
int minlength_;
bool binary_output_;
T minlength_;
T maxlength_;
bool binary_count_;
};
template <class T, class W>
template <class T>
class SparseCount : public OpKernel {
public:
explicit SparseCount(OpKernelConstruction* context) : OpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("minlength", &minlength_));
OP_REQUIRES_OK(context, context->GetAttr("maxlength", &maxlength_));
OP_REQUIRES_OK(context, context->GetAttr("binary_output", &binary_output_));
OP_REQUIRES_OK(context, context->GetAttr("binary_count", &binary_count_));
}
void Compute(OpKernelContext* context) override {
@ -179,27 +235,23 @@ class SparseCount : public OpKernel {
bool use_weights = weights.NumElements() > 0;
bool is_1d = shape.NumElements() == 1;
const auto indices_values = indices.matrix<int64>();
const auto values_values = values.flat<T>();
int num_batches = is_1d ? 1 : shape.flat<int64>()(0);
int num_values = values.NumElements();
const auto indices_values = indices.matrix<int64>();
const auto values_values = values.flat<T>();
const auto weight_values = weights.flat<W>();
auto per_batch_counts = BatchedMap<W>(num_batches);
auto per_batch_counts = BatchedIntMap(num_batches);
T max_value = 0;
for (int idx = 0; idx < num_values; ++idx) {
int batch = is_1d ? 0 : indices_values(idx, 0);
const auto& value = values_values(idx);
if (value >= 0 && (maxlength_ <= 0 || value < maxlength_)) {
if (binary_output_) {
per_batch_counts[batch][value] = 1;
} else if (use_weights) {
per_batch_counts[batch][value] += weight_values(idx);
if (binary_count_) {
(per_batch_counts[batch])[value] = 1;
} else {
per_batch_counts[batch][value]++;
(per_batch_counts[batch])[value]++;
}
if (value > max_value) {
max_value = value;
@ -207,25 +259,30 @@ class SparseCount : public OpKernel {
}
}
int num_output_values = GetOutputSize(max_value, maxlength_, minlength_);
OP_REQUIRES_OK(context, OutputSparse<W>(per_batch_counts, num_output_values,
T num_output_values = GetOutputSize<T>(max_value, maxlength_, minlength_);
if (use_weights) {
OP_REQUIRES_OK(context,
OutputWeightedSparse(per_batch_counts, num_output_values,
weights, is_1d, context));
} else {
OP_REQUIRES_OK(context, OutputSparse(per_batch_counts, num_output_values,
is_1d, context));
}
}
private:
int maxlength_;
int minlength_;
bool binary_output_;
bool validate_;
T minlength_;
T maxlength_;
bool binary_count_;
};
template <class T, class W>
template <class T>
class RaggedCount : public OpKernel {
public:
explicit RaggedCount(OpKernelConstruction* context) : OpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("minlength", &minlength_));
OP_REQUIRES_OK(context, context->GetAttr("maxlength", &maxlength_));
OP_REQUIRES_OK(context, context->GetAttr("binary_output", &binary_output_));
OP_REQUIRES_OK(context, context->GetAttr("binary_count", &binary_count_));
}
void Compute(OpKernelContext* context) override {
@ -233,15 +290,13 @@ class RaggedCount : public OpKernel {
const Tensor& values = context->input(1);
const Tensor& weights = context->input(2);
bool use_weights = weights.NumElements() > 0;
bool is_1d = false;
const auto splits_values = splits.flat<int64>();
const auto values_values = values.flat<T>();
const auto weight_values = weights.flat<W>();
int num_batches = splits.NumElements() - 1;
int num_values = values.NumElements();
auto per_batch_counts = BatchedMap<W>(num_batches);
auto per_batch_counts = BatchedIntMap(num_batches);
T max_value = 0;
int batch_idx = 0;
@ -251,12 +306,10 @@ class RaggedCount : public OpKernel {
}
const auto& value = values_values(idx);
if (value >= 0 && (maxlength_ <= 0 || value < maxlength_)) {
if (binary_output_) {
per_batch_counts[batch_idx - 1][value] = 1;
} else if (use_weights) {
per_batch_counts[batch_idx - 1][value] += weight_values(idx);
if (binary_count_) {
(per_batch_counts[batch_idx - 1])[value] = 1;
} else {
per_batch_counts[batch_idx - 1][value]++;
(per_batch_counts[batch_idx - 1])[value]++;
}
if (value > max_value) {
max_value = value;
@ -264,47 +317,42 @@ class RaggedCount : public OpKernel {
}
}
int num_output_values = GetOutputSize(max_value, maxlength_, minlength_);
OP_REQUIRES_OK(context, OutputSparse<W>(per_batch_counts, num_output_values,
is_1d, context));
T num_output_values = GetOutputSize<T>(max_value, maxlength_, minlength_);
if (use_weights) {
OP_REQUIRES_OK(context,
OutputWeightedSparse(per_batch_counts, num_output_values,
weights, false, context));
} else {
OP_REQUIRES_OK(context, OutputSparse(per_batch_counts, num_output_values,
false, context));
}
}
private:
int maxlength_;
int minlength_;
bool binary_output_;
bool validate_;
T minlength_;
T maxlength_;
bool binary_count_;
};
#define REGISTER_W(W_TYPE) \
REGISTER(int32, W_TYPE) \
REGISTER(int64, W_TYPE)
#define REGISTER(I_TYPE, W_TYPE) \
#define REGISTER(TYPE) \
\
REGISTER_KERNEL_BUILDER(Name("DenseCountSparseOutput") \
.TypeConstraint<I_TYPE>("T") \
.TypeConstraint<W_TYPE>("output_type") \
.TypeConstraint<TYPE>("T") \
.Device(DEVICE_CPU), \
DenseCount<I_TYPE, W_TYPE>) \
DenseCount<TYPE>) \
\
REGISTER_KERNEL_BUILDER(Name("SparseCountSparseOutput") \
.TypeConstraint<I_TYPE>("T") \
.TypeConstraint<W_TYPE>("output_type") \
.TypeConstraint<TYPE>("T") \
.Device(DEVICE_CPU), \
SparseCount<I_TYPE, W_TYPE>) \
SparseCount<TYPE>) \
\
REGISTER_KERNEL_BUILDER(Name("RaggedCountSparseOutput") \
.TypeConstraint<I_TYPE>("T") \
.TypeConstraint<W_TYPE>("output_type") \
.TypeConstraint<TYPE>("T") \
.Device(DEVICE_CPU), \
RaggedCount<I_TYPE, W_TYPE>)
RaggedCount<TYPE>)
TF_CALL_INTEGRAL_TYPES(REGISTER_W);
TF_CALL_float(REGISTER_W);
TF_CALL_double(REGISTER_W);
#undef REGISTER_W
REGISTER(int32);
REGISTER(int64);
#undef REGISTER
} // namespace tensorflow

View File

@ -19,21 +19,12 @@ limitations under the License.
namespace tensorflow {
using shape_inference::DimensionHandle;
using shape_inference::InferenceContext;
using shape_inference::ShapeHandle;
Status DenseCountSparseOutputShapeFn(InferenceContext *c) {
auto values = c->input(0);
auto weights = c->input(1);
ShapeHandle output;
auto num_weights = c->NumElements(weights);
if (c->ValueKnown(num_weights) && c->Value(num_weights) == 0) {
output = values;
} else {
TF_RETURN_IF_ERROR(c->Merge(weights, values, &output));
}
auto rank = c->Rank(output);
auto nvals = c->UnknownDim();
int32 rank = c->Rank(c->input(0));
DimensionHandle nvals = c->UnknownDim();
c->set_output(0, c->Matrix(nvals, rank)); // out.indices
c->set_output(1, c->Vector(nvals)); // out.values
c->set_output(2, c->Vector(rank)); // out.dense_shape
@ -41,8 +32,8 @@ Status DenseCountSparseOutputShapeFn(InferenceContext *c) {
}
Status SparseCountSparseOutputShapeFn(InferenceContext *c) {
auto rank = c->Dim(c->input(0), 1);
auto nvals = c->UnknownDim();
DimensionHandle rank = c->Dim(c->input(0), 1);
DimensionHandle nvals = c->UnknownDim();
c->set_output(0, c->Matrix(nvals, rank)); // out.indices
c->set_output(1, c->Vector(nvals)); // out.values
c->set_output(2, c->Vector(rank)); // out.dense_shape
@ -54,7 +45,7 @@ Status RaggedCountSparseOutputShapeFn(InferenceContext *c) {
if (rank != c->kUnknownRank) {
++rank; // Add the ragged dimension
}
auto nvals = c->UnknownDim();
DimensionHandle nvals = c->UnknownDim();
c->set_output(0, c->Matrix(nvals, rank)); // out.indices
c->set_output(1, c->Vector(nvals)); // out.values
c->set_output(2, c->Vector(rank)); // out.dense_shape
@ -63,12 +54,12 @@ Status RaggedCountSparseOutputShapeFn(InferenceContext *c) {
REGISTER_OP("DenseCountSparseOutput")
.Input("values: T")
.Input("weights: output_type")
.Input("weights: float")
.Attr("T: {int32, int64}")
.Attr("minlength: int >= -1 = -1")
.Attr("maxlength: int >= -1 = -1")
.Attr("binary_output: bool")
.Attr("output_type: {int32, int64, float, double}")
.Attr("binary_count: bool")
.Attr("output_type: {int64, float}")
.SetShapeFn(DenseCountSparseOutputShapeFn)
.Output("output_indices: int64")
.Output("output_values: output_type")
@ -78,12 +69,12 @@ REGISTER_OP("SparseCountSparseOutput")
.Input("indices: int64")
.Input("values: T")
.Input("dense_shape: int64")
.Input("weights: output_type")
.Input("weights: float")
.Attr("T: {int32, int64}")
.Attr("minlength: int >= -1 = -1")
.Attr("maxlength: int >= -1 = -1")
.Attr("binary_output: bool")
.Attr("output_type: {int32, int64, float, double}")
.Attr("binary_count: bool")
.Attr("output_type: {int64, float}")
.SetShapeFn(SparseCountSparseOutputShapeFn)
.Output("output_indices: int64")
.Output("output_values: output_type")
@ -92,12 +83,12 @@ REGISTER_OP("SparseCountSparseOutput")
REGISTER_OP("RaggedCountSparseOutput")
.Input("splits: int64")
.Input("values: T")
.Input("weights: output_type")
.Input("weights: float")
.Attr("T: {int32, int64}")
.Attr("minlength: int >= -1 = -1")
.Attr("maxlength: int >= -1 = -1")
.Attr("binary_output: bool")
.Attr("output_type: {int32, int64, float, double}")
.Attr("binary_count: bool")
.Attr("output_type: {int64, float}")
.SetShapeFn(RaggedCountSparseOutputShapeFn)
.Output("output_indices: int64")
.Output("output_values: output_type")

View File

@ -18,10 +18,10 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import gen_count_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.util.tf_export import tf_export
@ -33,7 +33,7 @@ def sparse_bincount(values,
axis=0,
minlength=None,
maxlength=None,
binary_output=False,
binary_count=False,
name=None):
"""Count the number of times an integer value appears in a tensor.
@ -58,9 +58,8 @@ def sparse_bincount(values,
maxlength: If given, skips `values` that are greater than or equal to
`maxlength`, and ensures that the output has a `dense_shape` of at most
`maxlength` in the inner dimension.
binary_output: If True, this op will output 1 instead of the number of times
a token appears (equivalent to one_hot + reduce_any instead of one_hot +
reduce_add). Defaults to False.
binary_count: Whether to do a binary count. When True, this op will return 1
for any value that exists instead of counting the number of occurrences.
name: A name for this op.
Returns:
@ -79,7 +78,7 @@ def sparse_bincount(values,
SparseTensor) and returns a SparseTensor where the value of (i,j) is the
number of times value j appears in batch i.
>>> data = np.array([[10, 20, 30, 20], [11, 101, 11, 10001]], dtype=np.int64)
>>> data = [[10, 20, 30, 20], [11, 101, 11, 10001]]
>>> output = tf.sparse.bincount(data, axis=-1)
>>> print(output)
SparseTensor(indices=tf.Tensor(
@ -103,7 +102,7 @@ def sparse_bincount(values,
dense shape is [2, 500] instead of [2,10002] or [2, 102].
>>> minlength = maxlength = 500
>>> data = np.array([[10, 20, 30, 20], [11, 101, 11, 10001]], dtype=np.int64)
>>> data = [[10, 20, 30, 20], [11, 101, 11, 10001]]
>>> output = tf.sparse.bincount(
... data, axis=-1, minlength=minlength, maxlength=maxlength)
>>> print(output)
@ -124,8 +123,8 @@ def sparse_bincount(values,
some values (like 20 in batch 1 and 11 in batch 2) appear more than once,
the 'values' tensor is all 1s.
>>> data = np.array([[10, 20, 30, 20], [11, 101, 11, 10001]], dtype=np.int64)
>>> output = tf.sparse.bincount(data, binary_output=True, axis=-1)
>>> dense = [[10, 20, 30, 20], [11, 101, 11, 10001]]
>>> output = tf.sparse.bincount(dense, binary_count=True, axis=-1)
>>> print(output)
SparseTensor(indices=tf.Tensor(
[[ 0 10]
@ -137,42 +136,20 @@ def sparse_bincount(values,
values=tf.Tensor([1 1 1 1 1 1], shape=(6,), dtype=int64),
dense_shape=tf.Tensor([ 2 10002], shape=(2,), dtype=int64))
**Weighted bin-counting**
This example takes two inputs - a values tensor and a weights tensor. These
tensors must be identically shaped, and have the same row splits or indices
in the case of RaggedTensors or SparseTensors. When performing a weighted
count, the op will output a SparseTensor where the value of (i, j) is the
sum of the values in the weight tensor's batch i in the locations where
the values tensor has the value j. In this case, the output dtype is the
same as the dtype of the weights tensor.
>>> data = np.array([[10, 20, 30, 20], [11, 101, 11, 10001]], dtype=np.int64)
>>> weights = [[2, 0.25, 15, 0.5], [2, 17, 3, 0.9]]
>>> output = tf.sparse.bincount(data, weights=weights, axis=-1)
>>> print(output)
SparseTensor(indices=tf.Tensor(
[[ 0 10]
[ 0 20]
[ 0 30]
[ 1 11]
[ 1 101]
[ 1 10001]], shape=(6, 2), dtype=int64),
values=tf.Tensor([2. 0.75 15. 5. 17. 0.9], shape=(6,), dtype=float32),
dense_shape=tf.Tensor([ 2 10002], shape=(2,), dtype=int64))
"""
with ops.name_scope(name, "count", [values, weights]):
if not isinstance(values, sparse_tensor.SparseTensor):
values = ragged_tensor.convert_to_tensor_or_ragged_tensor(
values, name="values")
if weights is not None:
if not isinstance(weights, sparse_tensor.SparseTensor):
weights = ragged_tensor.convert_to_tensor_or_ragged_tensor(
weights, name="weights")
if weights is not None and binary_output:
raise ValueError("binary_output and weights are mutually exclusive.")
if weights is not None and binary_count:
raise ValueError("binary_count and weights are mutually exclusive.")
if weights is None:
weights = []
output_type = dtypes.int64
else:
output_type = dtypes.float32
if axis is None:
axis = 0
@ -185,114 +162,38 @@ def sparse_bincount(values,
maxlength_value = maxlength if maxlength is not None else -1
if axis == 0:
if isinstance(values, sparse_tensor.SparseTensor):
if weights is not None:
weights = validate_sparse_weights(values, weights)
values = values.values
elif isinstance(values, ragged_tensor.RaggedTensor):
if weights is not None:
weights = validate_ragged_weights(values, weights)
if isinstance(values,
(sparse_tensor.SparseTensor, ragged_tensor.RaggedTensor)):
values = values.values
else:
if weights is not None:
weights = array_ops.reshape(weights, [-1])
values = array_ops.reshape(values, [-1])
if isinstance(values, sparse_tensor.SparseTensor):
weights = validate_sparse_weights(values, weights)
c_ind, c_val, c_shape = gen_count_ops.sparse_count_sparse_output(
values.indices,
values.values,
values.dense_shape,
weights,
weights=weights,
minlength=minlength_value,
maxlength=maxlength_value,
binary_output=binary_output)
binary_count=binary_count,
output_type=output_type)
elif isinstance(values, ragged_tensor.RaggedTensor):
weights = validate_ragged_weights(values, weights)
c_ind, c_val, c_shape = gen_count_ops.ragged_count_sparse_output(
values.row_splits,
values.values,
weights,
weights=weights,
minlength=minlength_value,
maxlength=maxlength_value,
binary_output=binary_output)
binary_count=binary_count,
output_type=output_type)
else:
weights = validate_dense_weights(values, weights)
c_ind, c_val, c_shape = gen_count_ops.dense_count_sparse_output(
values,
weights=weights,
minlength=minlength_value,
maxlength=maxlength_value,
binary_output=binary_output)
binary_count=binary_count,
output_type=output_type)
return sparse_tensor.SparseTensor(c_ind, c_val, c_shape)
def validate_dense_weights(values, weights):
"""Validates the passed weight tensor or creates an empty one."""
if weights is None:
return array_ops.constant([], dtype=values.dtype)
if not isinstance(weights, ops.Tensor):
raise ValueError(
"`weights` must be a tf.Tensor if `values` is a tf.Tensor.")
return weights
def validate_sparse_weights(values, weights):
"""Validates the passed weight tensor or creates an empty one."""
if weights is None:
return array_ops.constant([], dtype=values.values.dtype)
if not isinstance(weights, sparse_tensor.SparseTensor):
raise ValueError(
"`weights` must be a SparseTensor if `values` is a SparseTensor.")
checks = []
if weights.dense_shape is not values.dense_shape:
checks.append(
check_ops.assert_equal(
weights.dense_shape,
values.dense_shape,
message="'weights' and 'values' must have the same dense shape."))
if weights.indices is not values.indices:
checks.append(
check_ops.assert_equal(
weights.indices,
values.indices,
message="'weights' and 'values' must have the same indices.")
)
if checks:
with ops.control_dependencies(checks):
weights = array_ops.identity(weights.values)
else:
weights = weights.values
return weights
def validate_ragged_weights(values, weights):
"""Validates the passed weight tensor or creates an empty one."""
if weights is None:
return array_ops.constant([], dtype=values.values.dtype)
if not isinstance(weights, ragged_tensor.RaggedTensor):
raise ValueError(
"`weights` must be a RaggedTensor if `values` is a RaggedTensor.")
checks = []
if weights.row_splits is not values.row_splits:
checks.append(
check_ops.assert_equal(
weights.row_splits,
values.row_splits,
message="'weights' and 'values' must have the same row splits."))
if checks:
with ops.control_dependencies(checks):
weights = array_ops.identity(weights.values)
else:
weights = weights.values
return weights

View File

@ -21,8 +21,6 @@ from __future__ import print_function
from absl.testing import parameterized
import numpy as np
from tensorflow.python.eager import context
from tensorflow.python.framework import errors
from tensorflow.python.ops import bincount
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops.ragged import ragged_factory_ops
@ -67,7 +65,7 @@ class TestSparseCount(test.TestCase, parameterized.TestCase):
"expected_indices": [[0, 1], [0, 2], [0, 3], [1, 4], [1, 5]],
"expected_values": [1, 1, 1, 1, 1],
"expected_shape": [2, 6],
"binary_output": True,
"binary_count": True,
}, {
"testcase_name": "_maxlength_binary",
"x": np.array([[3, 2, 1, 7], [7, 0, 4, 4]], dtype=np.int32),
@ -75,7 +73,7 @@ class TestSparseCount(test.TestCase, parameterized.TestCase):
"expected_indices": [[0, 1], [0, 2], [0, 3], [1, 0], [1, 4]],
"expected_values": [1, 1, 1, 1, 1],
"expected_shape": [2, 7],
"binary_output": True,
"binary_count": True,
}, {
"testcase_name": "_minlength_binary",
"x": np.array([[3, 2, 1, 7], [7, 0, 4, 4]], dtype=np.int32),
@ -84,7 +82,7 @@ class TestSparseCount(test.TestCase, parameterized.TestCase):
[1, 7]],
"expected_values": [1, 1, 1, 1, 1, 1, 1],
"expected_shape": [2, 9],
"binary_output": True,
"binary_count": True,
}, {
"testcase_name": "_minlength_larger_values_binary",
"x": np.array([[3, 2, 1, 7], [7, 0, 4, 4]], dtype=np.int32),
@ -93,40 +91,40 @@ class TestSparseCount(test.TestCase, parameterized.TestCase):
[1, 7]],
"expected_values": [1, 1, 1, 1, 1, 1, 1],
"expected_shape": [2, 8],
"binary_output": True,
"binary_count": True,
}, {
"testcase_name": "_no_maxlength_weights",
"x": np.array([[3, 2, 1], [5, 4, 4]], dtype=np.int32),
"expected_indices": [[0, 1], [0, 2], [0, 3], [1, 4], [1, 5]],
"expected_values": [2, 1, 0.5, 9, 3],
"expected_values": [1, 2, 3, 8, 5],
"expected_shape": [2, 6],
"weights": [[0.5, 1, 2], [3, 4, 5]]
"weights": [0.5, 1, 2, 3, 4, 5]
}, {
"testcase_name": "_maxlength_weights",
"x": np.array([[3, 2, 1, 7], [7, 0, 4, 4]], dtype=np.int32),
"maxlength": 7,
"expected_indices": [[0, 1], [0, 2], [0, 3], [1, 0], [1, 4]],
"expected_values": [2, 1, 0.5, 3, 9],
"expected_values": [1, 2, 3, 0.5, 8],
"expected_shape": [2, 7],
"weights": [[0.5, 1, 2, 11], [7, 3, 4, 5]]
"weights": [0.5, 1, 2, 3, 4, 5, 6]
}, {
"testcase_name": "_minlength_weights",
"x": np.array([[3, 2, 1, 7], [7, 0, 4, 4]], dtype=np.int32),
"minlength": 9,
"expected_indices": [[0, 1], [0, 2], [0, 3], [0, 7], [1, 0], [1, 4],
[1, 7]],
"expected_values": [2, 1, 0.5, 3, 5, 13, 4],
"expected_values": [1, 2, 3, 7, 0.5, 8, 7],
"expected_shape": [2, 9],
"weights": [[0.5, 1, 2, 3], [4, 5, 6, 7]]
"weights": [0.5, 1, 2, 3, 4, 5, 6, 7, 8]
}, {
"testcase_name": "_minlength_larger_values_weights",
"x": np.array([[3, 2, 1, 7], [7, 0, 4, 4]], dtype=np.int32),
"minlength": 3,
"expected_indices": [[0, 1], [0, 2], [0, 3], [0, 7], [1, 0], [1, 4],
[1, 7]],
"expected_values": [2, 1, 0.5, 3, 5, 13, 4],
"expected_values": [1, 2, 3, 7, 0.5, 8, 7],
"expected_shape": [2, 8],
"weights": [[0.5, 1, 2, 3], [4, 5, 6, 7]]
"weights": [0.5, 1, 2, 3, 4, 5, 6, 7, 8]
}, {
"testcase_name": "_1d",
"x": np.array([3, 2, 1, 1], dtype=np.int32),
@ -148,7 +146,7 @@ class TestSparseCount(test.TestCase, parameterized.TestCase):
expected_shape,
minlength=None,
maxlength=None,
binary_output=False,
binary_count=False,
weights=None,
axis=-1):
y = bincount.sparse_bincount(
@ -156,7 +154,7 @@ class TestSparseCount(test.TestCase, parameterized.TestCase):
weights=weights,
minlength=minlength,
maxlength=maxlength,
binary_output=binary_output,
binary_count=binary_count,
axis=axis)
self.assertAllEqual(expected_indices, y.indices)
self.assertAllEqual(expected_values, y.values)
@ -218,7 +216,7 @@ class TestSparseCount(test.TestCase, parameterized.TestCase):
"expected_indices": [[0, 1], [0, 3], [2, 4], [2, 5]],
"expected_values": [1, 1, 1, 1],
"expected_shape": [3, 6],
"binary_output":
"binary_count":
True,
},
{
@ -232,7 +230,7 @@ class TestSparseCount(test.TestCase, parameterized.TestCase):
"expected_shape": [3, 7],
"maxlength":
7,
"binary_output":
"binary_count":
True,
},
{
@ -246,7 +244,7 @@ class TestSparseCount(test.TestCase, parameterized.TestCase):
"expected_shape": [3, 9],
"minlength":
9,
"binary_output":
"binary_count":
True,
},
{
@ -260,7 +258,7 @@ class TestSparseCount(test.TestCase, parameterized.TestCase):
"expected_shape": [3, 8],
"minlength":
3,
"binary_output":
"binary_count":
True,
},
{
@ -270,10 +268,9 @@ class TestSparseCount(test.TestCase, parameterized.TestCase):
np.array([[3, 0, 1, 0], [0, 0, 0, 0], [5, 0, 4, 4]],
dtype=np.int32),
"expected_indices": [[0, 1], [0, 3], [2, 4], [2, 5]],
"expected_values": [2, 6, 7, 10],
"expected_values": [1, 3, 8, 5],
"expected_shape": [3, 6],
"weights":
np.array([[6, 0, 2, 0], [0, 0, 0, 0], [10, 0, 3.5, 3.5]]),
"weights": [0.5, 1, 2, 3, 4, 5]
},
{
"testcase_name":
@ -282,12 +279,11 @@ class TestSparseCount(test.TestCase, parameterized.TestCase):
np.array([[3, 0, 1, 0], [0, 0, 7, 0], [5, 0, 4, 4]],
dtype=np.int32),
"expected_indices": [[0, 1], [0, 3], [2, 4], [2, 5]],
"expected_values": [2, 6, 7, 10],
"expected_values": [1, 3, 8, 5],
"expected_shape": [3, 7],
"maxlength":
7,
"weights":
np.array([[6, 0, 2, 0], [0, 0, 14, 0], [10, 0, 3.5, 3.5]]),
"weights": [0.5, 1, 2, 3, 4, 5, 6]
},
{
"testcase_name":
@ -296,12 +292,11 @@ class TestSparseCount(test.TestCase, parameterized.TestCase):
np.array([[3, 0, 1, 0], [7, 0, 0, 0], [5, 0, 4, 4]],
dtype=np.int32),
"expected_indices": [[0, 1], [0, 3], [1, 7], [2, 4], [2, 5]],
"expected_values": [2, 6, 14, 6.5, 10],
"expected_values": [1, 3, 7, 8, 5],
"expected_shape": [3, 9],
"minlength":
9,
"weights":
np.array([[6, 0, 2, 0], [14, 0, 0, 0], [10, 0, 3, 3.5]]),
"weights": [0.5, 1, 2, 3, 4, 5, 6, 7, 8]
},
{
"testcase_name":
@ -310,12 +305,11 @@ class TestSparseCount(test.TestCase, parameterized.TestCase):
np.array([[3, 0, 1, 0], [7, 0, 0, 0], [5, 0, 4, 4]],
dtype=np.int32),
"expected_indices": [[0, 1], [0, 3], [1, 7], [2, 4], [2, 5]],
"expected_values": [2, 6, 14, 6.5, 10],
"expected_values": [1, 3, 7, 8, 5],
"expected_shape": [3, 8],
"minlength":
3,
"weights":
np.array([[6, 0, 2, 0], [14, 0, 0, 0], [10, 0, 3, 3.5]]),
"weights": [0.5, 1, 2, 3, 4, 5, 6, 7, 8]
},
{
"testcase_name": "_1d",
@ -344,17 +338,16 @@ class TestSparseCount(test.TestCase, parameterized.TestCase):
expected_shape,
maxlength=None,
minlength=None,
binary_output=False,
binary_count=False,
weights=None,
axis=-1):
x_sparse = sparse_ops.from_dense(x)
w_sparse = sparse_ops.from_dense(weights) if weights is not None else None
y = bincount.sparse_bincount(
x_sparse,
weights=w_sparse,
weights=weights,
minlength=minlength,
maxlength=maxlength,
binary_output=binary_output,
binary_count=binary_count,
axis=axis)
self.assertAllEqual(expected_indices, y.indices)
self.assertAllEqual(expected_values, y.values)
@ -400,7 +393,7 @@ class TestSparseCount(test.TestCase, parameterized.TestCase):
"expected_indices": [[2, 0], [2, 1], [2, 3], [4, 0], [4, 4], [4, 5]],
"expected_values": [1, 1, 1, 1, 1, 1],
"expected_shape": [5, 6],
"binary_output": True,
"binary_count": True,
},
{
"testcase_name": "_maxlength_binary",
@ -409,7 +402,7 @@ class TestSparseCount(test.TestCase, parameterized.TestCase):
"expected_indices": [[2, 0], [2, 1], [2, 3], [4, 0], [4, 4], [4, 5]],
"expected_values": [1, 1, 1, 1, 1, 1],
"expected_shape": [5, 7],
"binary_output": True,
"binary_count": True,
},
{
"testcase_name": "_minlength_binary",
@ -419,13 +412,13 @@ class TestSparseCount(test.TestCase, parameterized.TestCase):
[4, 5]],
"expected_values": [1, 1, 1, 1, 1, 1, 1],
"expected_shape": [5, 9],
"binary_output": True,
"binary_count": True,
},
{
"testcase_name": "_minlength_larger_values_binary",
"x": [[], [], [3, 0, 1], [7], [5, 0, 4, 4]],
"minlength": 3,
"binary_output": True,
"binary_count": True,
"expected_indices": [[2, 0], [2, 1], [2, 3], [3, 7], [4, 0], [4, 4],
[4, 5]],
"expected_values": [1, 1, 1, 1, 1, 1, 1],
@ -435,18 +428,18 @@ class TestSparseCount(test.TestCase, parameterized.TestCase):
"testcase_name": "_no_maxlength_weights",
"x": [[], [], [3, 0, 1], [], [5, 0, 4, 4]],
"expected_indices": [[2, 0], [2, 1], [2, 3], [4, 0], [4, 4], [4, 5]],
"expected_values": [0.5, 2, 6, 0.25, 8, 10],
"expected_values": [0.5, 1, 3, 0.5, 8, 5],
"expected_shape": [5, 6],
"weights": [[], [], [6, 0.5, 2], [], [10, 0.25, 5, 3]],
"weights": [0.5, 1, 2, 3, 4, 5]
},
{
"testcase_name": "_maxlength_weights",
"x": [[], [], [3, 0, 1], [7], [5, 0, 4, 4]],
"maxlength": 7,
"expected_indices": [[2, 0], [2, 1], [2, 3], [4, 0], [4, 4], [4, 5]],
"expected_values": [0.5, 2, 6, 0.25, 8, 10],
"expected_values": [0.5, 1, 3, 0.5, 8, 5],
"expected_shape": [5, 7],
"weights": [[], [], [6, 0.5, 2], [14], [10, 0.25, 5, 3]],
"weights": [0.5, 1, 2, 3, 4, 5, 6]
},
{
"testcase_name": "_minlength_weights",
@ -454,9 +447,9 @@ class TestSparseCount(test.TestCase, parameterized.TestCase):
"minlength": 9,
"expected_indices": [[2, 0], [2, 1], [2, 3], [3, 7], [4, 0], [4, 4],
[4, 5]],
"expected_values": [0.5, 2, 6, 14, 0.25, 8, 10],
"expected_values": [0.5, 1, 3, 7, 0.5, 8, 5],
"expected_shape": [5, 9],
"weights": [[], [], [6, 0.5, 2], [14], [10, 0.25, 5, 3]],
"weights": [0.5, 1, 2, 3, 4, 5, 6, 7, 8]
},
{
"testcase_name": "_minlength_larger_values_weights",
@ -464,9 +457,9 @@ class TestSparseCount(test.TestCase, parameterized.TestCase):
"minlength": 3,
"expected_indices": [[2, 0], [2, 1], [2, 3], [3, 7], [4, 0], [4, 4],
[4, 5]],
"expected_values": [0.5, 2, 6, 14, 0.25, 8, 10],
"expected_values": [0.5, 1, 3, 7, 0.5, 8, 5],
"expected_shape": [5, 8],
"weights": [[], [], [6, 0.5, 2], [14], [10, 0.25, 5, 3]],
"weights": [0.5, 1, 2, 3, 4, 5, 6, 7, 8]
},
{
"testcase_name": "_1d",
@ -491,114 +484,21 @@ class TestSparseCount(test.TestCase, parameterized.TestCase):
expected_shape,
maxlength=None,
minlength=None,
binary_output=False,
binary_count=False,
weights=None,
axis=-1):
x_ragged = ragged_factory_ops.constant(x)
w = ragged_factory_ops.constant(weights) if weights is not None else None
y = bincount.sparse_bincount(
x_ragged,
weights=w,
weights=weights,
minlength=minlength,
maxlength=maxlength,
binary_output=binary_output,
binary_count=binary_count,
axis=axis)
self.assertAllEqual(expected_indices, y.indices)
self.assertAllEqual(expected_values, y.values)
self.assertAllEqual(expected_shape, y.dense_shape)
class TestSparseCountFailureModes(test.TestCase):
def test_dense_input_sparse_weights_fails(self):
x = np.array([[3, 2, 1], [5, 4, 4]], dtype=np.int32)
weights = sparse_ops.from_dense(
np.array([[3, 0, 1, 0], [0, 0, 0, 0], [5, 0, 4, 4]], dtype=np.int32))
with self.assertRaisesRegexp(ValueError, "must be a tf.Tensor"):
self.evaluate(bincount.sparse_bincount(x, weights=weights, axis=-1))
def test_dense_input_ragged_weights_fails(self):
x = np.array([[3, 2, 1], [5, 4, 4]], dtype=np.int32)
weights = ragged_factory_ops.constant([[6, 0.5, 2], [14], [10, 0.25, 5, 3]])
with self.assertRaisesRegexp(ValueError, "must be a tf.Tensor"):
self.evaluate(bincount.sparse_bincount(x, weights=weights, axis=-1))
def test_dense_input_wrong_shape_fails(self):
x = np.array([[3, 2, 1], [5, 4, 4]], dtype=np.int32)
weights = np.array([[3, 2], [5, 4], [4, 3]])
# Note: Eager mode and graph mode throw different errors here. Graph mode
# will fail with a ValueError from the shape checking logic, while Eager
# will fail with an InvalidArgumentError from the kernel itself.
if context.executing_eagerly():
with self.assertRaisesRegexp(errors.InvalidArgumentError,
"must have the same shape"):
self.evaluate(bincount.sparse_bincount(x, weights=weights, axis=-1))
else:
with self.assertRaisesRegexp(ValueError, "both shapes must be equal"):
self.evaluate(bincount.sparse_bincount(x, weights=weights, axis=-1))
def test_sparse_input_dense_weights_fails(self):
x = sparse_ops.from_dense(
np.array([[3, 0, 1, 0], [0, 0, 0, 0], [5, 0, 4, 4]], dtype=np.int32))
weights = np.array([[3, 2, 1], [5, 4, 4]], dtype=np.int32)
with self.assertRaisesRegexp(ValueError, "must be a SparseTensor"):
self.evaluate(bincount.sparse_bincount(x, weights=weights, axis=-1))
def test_sparse_input_ragged_weights_fails(self):
x = sparse_ops.from_dense(
np.array([[3, 0, 1, 0], [0, 0, 0, 0], [5, 0, 4, 4]], dtype=np.int32))
weights = ragged_factory_ops.constant([[6, 0.5, 2], [14], [10, 0.25, 5, 3]])
with self.assertRaisesRegexp(ValueError, "must be a SparseTensor"):
self.evaluate(bincount.sparse_bincount(x, weights=weights, axis=-1))
def test_sparse_input_wrong_indices_fails(self):
x = sparse_ops.from_dense(
np.array([[3, 0, 1, 0], [0, 0, 0, 0], [5, 0, 4, 4]], dtype=np.int32))
weights = sparse_ops.from_dense(
np.array([[3, 1, 0, 0], [0, 0, 0, 0], [5, 0, 4, 4]], dtype=np.int32))
with self.assertRaisesRegexp(errors.InvalidArgumentError,
"must have the same indices"):
self.evaluate(bincount.sparse_bincount(x, weights=weights, axis=-1))
def test_sparse_input_too_many_indices_fails(self):
x = sparse_ops.from_dense(
np.array([[3, 0, 1, 0], [0, 0, 0, 0], [5, 0, 4, 4]], dtype=np.int32))
weights = sparse_ops.from_dense(
np.array([[3, 1, 1, 0], [0, 0, 0, 0], [5, 0, 4, 4]], dtype=np.int32))
with self.assertRaisesRegexp(errors.InvalidArgumentError,
"Incompatible shapes"):
self.evaluate(bincount.sparse_bincount(x, weights=weights, axis=-1))
def test_sparse_input_wrong_shape_fails(self):
x = sparse_ops.from_dense(
np.array([[3, 0, 1, 0], [0, 0, 0, 0], [5, 0, 4, 4]], dtype=np.int32))
weights = sparse_ops.from_dense(
np.array([[3, 0, 1, 0], [0, 0, 0, 0], [5, 0, 4, 4], [0, 0, 0, 0]],
dtype=np.int32))
with self.assertRaisesRegexp(errors.InvalidArgumentError,
"must have the same dense shape"):
self.evaluate(bincount.sparse_bincount(x, weights=weights, axis=-1))
def test_ragged_input_dense_weights_fails(self):
x = ragged_factory_ops.constant([[6, 1, 2], [14], [10, 1, 5, 3]])
weights = np.array([[3, 2, 1], [5, 4, 4]], dtype=np.int32)
with self.assertRaisesRegexp(ValueError, "must be a RaggedTensor"):
self.evaluate(bincount.sparse_bincount(x, weights=weights, axis=-1))
def test_ragged_input_sparse_weights_fails(self):
x = ragged_factory_ops.constant([[6, 1, 2], [14], [10, 1, 5, 3]])
weights = sparse_ops.from_dense(
np.array([[3, 0, 1, 0], [0, 0, 0, 0], [5, 0, 4, 4]], dtype=np.int32))
with self.assertRaisesRegexp(ValueError, "must be a RaggedTensor"):
self.evaluate(bincount.sparse_bincount(x, weights=weights, axis=-1))
def test_ragged_input_different_shape_fails(self):
x = ragged_factory_ops.constant([[6, 1, 2], [14], [10, 1, 5, 3]])
weights = ragged_factory_ops.constant([[6, 0.5, 2], [], [10, 0.25, 5, 3]])
with self.assertRaisesRegexp(errors.InvalidArgumentError,
"must have the same row splits"):
self.evaluate(bincount.sparse_bincount(x, weights=weights, axis=-1))
if __name__ == "__main__":
test.main()

View File

@ -1078,7 +1078,7 @@ tf_module {
}
member_method {
name: "DenseCountSparseOutput"
argspec: "args=[\'values\', \'weights\', \'binary_output\', \'minlength\', \'maxlength\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'-1\', \'None\'], "
argspec: "args=[\'values\', \'weights\', \'binary_count\', \'output_type\', \'minlength\', \'maxlength\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'-1\', \'None\'], "
}
member_method {
name: "DenseToCSRSparseMatrix"
@ -3074,7 +3074,7 @@ tf_module {
}
member_method {
name: "RaggedCountSparseOutput"
argspec: "args=[\'splits\', \'values\', \'weights\', \'binary_output\', \'minlength\', \'maxlength\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'-1\', \'None\'], "
argspec: "args=[\'splits\', \'values\', \'weights\', \'binary_count\', \'output_type\', \'minlength\', \'maxlength\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'-1\', \'None\'], "
}
member_method {
name: "RaggedCross"
@ -4094,7 +4094,7 @@ tf_module {
}
member_method {
name: "SparseCountSparseOutput"
argspec: "args=[\'indices\', \'values\', \'dense_shape\', \'weights\', \'binary_output\', \'minlength\', \'maxlength\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'-1\', \'None\'], "
argspec: "args=[\'indices\', \'values\', \'dense_shape\', \'weights\', \'binary_count\', \'output_type\', \'minlength\', \'maxlength\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'-1\', \'None\'], "
}
member_method {
name: "SparseCross"

View File

@ -14,7 +14,7 @@ tf_module {
}
member_method {
name: "bincount"
argspec: "args=[\'values\', \'weights\', \'axis\', \'minlength\', \'maxlength\', \'binary_output\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\', \'None\', \'False\', \'None\'], "
argspec: "args=[\'values\', \'weights\', \'axis\', \'minlength\', \'maxlength\', \'binary_count\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\', \'None\', \'False\', \'None\'], "
}
member_method {
name: "concat"

View File

@ -1078,7 +1078,7 @@ tf_module {
}
member_method {
name: "DenseCountSparseOutput"
argspec: "args=[\'values\', \'weights\', \'binary_output\', \'minlength\', \'maxlength\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'-1\', \'None\'], "
argspec: "args=[\'values\', \'weights\', \'binary_count\', \'output_type\', \'minlength\', \'maxlength\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'-1\', \'None\'], "
}
member_method {
name: "DenseToCSRSparseMatrix"
@ -3074,7 +3074,7 @@ tf_module {
}
member_method {
name: "RaggedCountSparseOutput"
argspec: "args=[\'splits\', \'values\', \'weights\', \'binary_output\', \'minlength\', \'maxlength\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'-1\', \'None\'], "
argspec: "args=[\'splits\', \'values\', \'weights\', \'binary_count\', \'output_type\', \'minlength\', \'maxlength\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'-1\', \'None\'], "
}
member_method {
name: "RaggedCross"
@ -4094,7 +4094,7 @@ tf_module {
}
member_method {
name: "SparseCountSparseOutput"
argspec: "args=[\'indices\', \'values\', \'dense_shape\', \'weights\', \'binary_output\', \'minlength\', \'maxlength\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'-1\', \'None\'], "
argspec: "args=[\'indices\', \'values\', \'dense_shape\', \'weights\', \'binary_count\', \'output_type\', \'minlength\', \'maxlength\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'-1\', \'None\'], "
}
member_method {
name: "SparseCross"

View File

@ -10,7 +10,7 @@ tf_module {
}
member_method {
name: "bincount"
argspec: "args=[\'values\', \'weights\', \'axis\', \'minlength\', \'maxlength\', \'binary_output\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\', \'None\', \'False\', \'None\'], "
argspec: "args=[\'values\', \'weights\', \'axis\', \'minlength\', \'maxlength\', \'binary_count\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\', \'None\', \'False\', \'None\'], "
}
member_method {
name: "concat"