Update attr name for Dense version Bincount.
PiperOrigin-RevId: 311423709 Change-Id: Ief7c901477be8e06b1d3f98613c7390c12e9680b
This commit is contained in:
parent
0ac3572e8d
commit
0c9e56e931
@ -28,7 +28,7 @@ The counts or summed weights for each value in the range [0, size).
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "binary_count"
|
||||
name: "binary_output"
|
||||
description: <<END
|
||||
bool; Whether the kernel should count the appearance or number of occurrences.
|
||||
END
|
||||
|
@ -34,7 +34,7 @@ The counts or summed weights for each value in the range [0, size).
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "binary_count"
|
||||
name: "binary_output"
|
||||
description: <<END
|
||||
bool; Whether the kernel should count the appearance or number of occurrences.
|
||||
END
|
||||
|
@ -40,7 +40,7 @@ The counts or summed weights for each value in the range [0, size).
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "binary_count"
|
||||
name: "binary_output"
|
||||
description: <<END
|
||||
bool; Whether the kernel should count the appearance or number of occurrences.
|
||||
END
|
||||
|
@ -130,8 +130,8 @@ struct BincountFunctor<CPUDevice, Tidx, T, false> {
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Tidx, typename T, bool binary_count>
|
||||
struct BincountReduceFunctor<CPUDevice, Tidx, T, binary_count> {
|
||||
template <typename Tidx, typename T, bool binary_output>
|
||||
struct BincountReduceFunctor<CPUDevice, Tidx, T, binary_output> {
|
||||
static Status Compute(OpKernelContext* context,
|
||||
const typename TTypes<Tidx, 2>::ConstTensor& in,
|
||||
const typename TTypes<T, 2>::ConstTensor& weights,
|
||||
@ -148,7 +148,7 @@ struct BincountReduceFunctor<CPUDevice, Tidx, T, binary_count> {
|
||||
for (int64 j = 0; j < num_cols; ++j) {
|
||||
Tidx value = in(i, j);
|
||||
if (value < num_bins) {
|
||||
if (binary_count) {
|
||||
if (binary_output) {
|
||||
out(i, value) = T(1);
|
||||
} else {
|
||||
if (weights.size()) {
|
||||
@ -221,7 +221,7 @@ template <typename Device, typename Tidx, typename T>
|
||||
class DenseBincountOp : public OpKernel {
|
||||
public:
|
||||
explicit DenseBincountOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("binary_count", &binary_count_));
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("binary_output", &binary_output_));
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
@ -240,7 +240,7 @@ class DenseBincountOp : public OpKernel {
|
||||
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({size}), &out_t));
|
||||
auto out = out_t->flat<T>();
|
||||
fill(ctx->eigen_device<Device>(), out);
|
||||
if (binary_count_) {
|
||||
if (binary_output_) {
|
||||
OP_REQUIRES_OK(
|
||||
ctx, functor::BincountFunctor<Device, Tidx, T, true>::Compute(
|
||||
ctx, data.flat<Tidx>(), weights.flat<T>(), out, size));
|
||||
@ -259,7 +259,7 @@ class DenseBincountOp : public OpKernel {
|
||||
ctx, ctx->allocate_output(0, TensorShape({num_rows, size}), &out_t));
|
||||
auto out = out_t->matrix<T>();
|
||||
fill(ctx->eigen_device<Device>(), out_t->flat<T>());
|
||||
if (binary_count_) {
|
||||
if (binary_output_) {
|
||||
OP_REQUIRES_OK(
|
||||
ctx, functor::BincountReduceFunctor<Device, Tidx, T, true>::Compute(
|
||||
ctx, data.matrix<Tidx>(), weight_matrix, out, size));
|
||||
@ -273,7 +273,7 @@ class DenseBincountOp : public OpKernel {
|
||||
}
|
||||
|
||||
private:
|
||||
bool binary_count_;
|
||||
bool binary_output_;
|
||||
};
|
||||
|
||||
#define REGISTER_KERNELS(Tidx, T) \
|
||||
@ -314,7 +314,7 @@ template <typename Device, typename Tidx, typename T>
|
||||
class SparseBincountOp : public OpKernel {
|
||||
public:
|
||||
explicit SparseBincountOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("binary_count", &binary_count_));
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("binary_output", &binary_output_));
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
@ -338,7 +338,7 @@ class SparseBincountOp : public OpKernel {
|
||||
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({size}), &out_t));
|
||||
auto out = out_t->flat<T>();
|
||||
fill(ctx->eigen_device<Device>(), out);
|
||||
if (binary_count_) {
|
||||
if (binary_output_) {
|
||||
OP_REQUIRES_OK(ctx,
|
||||
functor::BincountFunctor<Device, Tidx, T, true>::Compute(
|
||||
ctx, values, weights, out, size));
|
||||
@ -359,7 +359,7 @@ class SparseBincountOp : public OpKernel {
|
||||
const int64 batch = indices_mat(i, 0);
|
||||
const Tidx bin = values(i);
|
||||
if (bin < size) {
|
||||
if (binary_count_) {
|
||||
if (binary_output_) {
|
||||
out(batch, bin) = T(1);
|
||||
} else {
|
||||
if (weights_size) {
|
||||
@ -374,7 +374,7 @@ class SparseBincountOp : public OpKernel {
|
||||
}
|
||||
|
||||
private:
|
||||
bool binary_count_;
|
||||
bool binary_output_;
|
||||
};
|
||||
|
||||
#define REGISTER_KERNELS(Tidx, T) \
|
||||
@ -395,7 +395,7 @@ template <typename Device, typename Tidx, typename T>
|
||||
class RaggedBincountOp : public OpKernel {
|
||||
public:
|
||||
explicit RaggedBincountOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("binary_count", &binary_count_));
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("binary_output", &binary_output_));
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
@ -429,7 +429,7 @@ class RaggedBincountOp : public OpKernel {
|
||||
OP_REQUIRES(ctx, bin >= 0,
|
||||
errors::InvalidArgument("Input must be non-negative"));
|
||||
if (bin < size) {
|
||||
if (binary_count_) {
|
||||
if (binary_output_) {
|
||||
out(batch_idx - 1, bin) = T(1);
|
||||
} else {
|
||||
T value = (weights_size > 0) ? weights(idx) : T(1);
|
||||
@ -440,7 +440,7 @@ class RaggedBincountOp : public OpKernel {
|
||||
}
|
||||
|
||||
private:
|
||||
bool binary_count_;
|
||||
bool binary_output_;
|
||||
};
|
||||
|
||||
#define REGISTER_KERNELS(Tidx, T) \
|
||||
|
@ -1657,7 +1657,7 @@ REGISTER_OP("DenseBincount")
|
||||
.Input("weights: T")
|
||||
.Attr("Tidx: {int32, int64}")
|
||||
.Attr("T: {int32, int64, float32, float64}")
|
||||
.Attr("binary_count: bool = false")
|
||||
.Attr("binary_output: bool = false")
|
||||
.Output("output: T")
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
ShapeHandle unused;
|
||||
@ -1704,7 +1704,7 @@ REGISTER_OP("SparseBincount")
|
||||
.Input("weights: T")
|
||||
.Attr("Tidx: {int32, int64}")
|
||||
.Attr("T: {int32, int64, float32, float64}")
|
||||
.Attr("binary_count: bool = false")
|
||||
.Attr("binary_output: bool = false")
|
||||
.Output("output: T")
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
const Tensor* size_tensor = c->input_tensor(3);
|
||||
@ -1754,7 +1754,7 @@ REGISTER_OP("RaggedBincount")
|
||||
.Input("weights: T")
|
||||
.Attr("Tidx: {int32, int64}")
|
||||
.Attr("T: {int32, int64, float32, float64}")
|
||||
.Attr("binary_count: bool = false")
|
||||
.Attr("binary_output: bool = false")
|
||||
.Output("output: T")
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
c->set_output(0, c->UnknownShape());
|
||||
|
@ -183,7 +183,7 @@ class BincountOpTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||
np_out,
|
||||
self.evaluate(
|
||||
gen_math_ops.dense_bincount(
|
||||
input=inp, weights=[], size=size, binary_count=True)))
|
||||
input=inp, weights=[], size=size, binary_output=True)))
|
||||
|
||||
@parameterized.parameters([{
|
||||
"dtype": np.int32,
|
||||
@ -201,7 +201,7 @@ class BincountOpTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||
np_out,
|
||||
self.evaluate(
|
||||
gen_math_ops.dense_bincount(
|
||||
input=inp, weights=np_weight, size=size, binary_count=True)))
|
||||
input=inp, weights=np_weight, size=size, binary_output=True)))
|
||||
|
||||
def _test_bincount_col_count(self, num_rows, num_cols, size, dtype):
|
||||
np.random.seed(42)
|
||||
@ -230,7 +230,7 @@ class BincountOpTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||
np_out,
|
||||
self.evaluate(
|
||||
gen_math_ops.dense_bincount(
|
||||
input=inp, weights=[], size=size, binary_count=True)))
|
||||
input=inp, weights=[], size=size, binary_output=True)))
|
||||
|
||||
def _test_bincount_col_count_with_weights(self, num_rows, num_cols, size,
|
||||
dtype):
|
||||
@ -401,7 +401,7 @@ class SparseBincountOpTest(test_util.TensorFlowTestCase,
|
||||
dense_shape=[num_rows],
|
||||
size=size,
|
||||
weights=[],
|
||||
binary_count=True)))
|
||||
binary_output=True)))
|
||||
|
||||
@parameterized.parameters([{
|
||||
"dtype": np.int32,
|
||||
@ -427,7 +427,7 @@ class SparseBincountOpTest(test_util.TensorFlowTestCase,
|
||||
dense_shape=[num_rows],
|
||||
size=size,
|
||||
weights=inp_weight,
|
||||
binary_count=True)))
|
||||
binary_output=True)))
|
||||
|
||||
@parameterized.parameters([{
|
||||
"dtype": np.int32,
|
||||
@ -490,7 +490,7 @@ class SparseBincountOpTest(test_util.TensorFlowTestCase,
|
||||
dense_shape=inp_sparse.dense_shape,
|
||||
size=size,
|
||||
weights=[],
|
||||
binary_count=True)))
|
||||
binary_output=True)))
|
||||
|
||||
|
||||
class RaggedBincountOpTest(test_util.TensorFlowTestCase,
|
||||
@ -530,7 +530,7 @@ class RaggedBincountOpTest(test_util.TensorFlowTestCase,
|
||||
values=x.values,
|
||||
weights=[],
|
||||
size=6,
|
||||
binary_count=True)))
|
||||
binary_output=True)))
|
||||
|
||||
@parameterized.parameters([{
|
||||
"dtype": np.int32,
|
||||
@ -629,7 +629,7 @@ class RaggedBincountOpTest(test_util.TensorFlowTestCase,
|
||||
values=x.values,
|
||||
weights=[],
|
||||
size=size,
|
||||
binary_count=True)))
|
||||
binary_output=True)))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -1074,7 +1074,7 @@ tf_module {
|
||||
}
|
||||
member_method {
|
||||
name: "DenseBincount"
|
||||
argspec: "args=[\'input\', \'size\', \'weights\', \'binary_count\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
|
||||
argspec: "args=[\'input\', \'size\', \'weights\', \'binary_output\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "DenseCountSparseOutput"
|
||||
@ -3070,7 +3070,7 @@ tf_module {
|
||||
}
|
||||
member_method {
|
||||
name: "RaggedBincount"
|
||||
argspec: "args=[\'splits\', \'values\', \'size\', \'weights\', \'binary_count\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
|
||||
argspec: "args=[\'splits\', \'values\', \'size\', \'weights\', \'binary_output\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "RaggedCountSparseOutput"
|
||||
@ -4082,7 +4082,7 @@ tf_module {
|
||||
}
|
||||
member_method {
|
||||
name: "SparseBincount"
|
||||
argspec: "args=[\'indices\', \'values\', \'dense_shape\', \'size\', \'weights\', \'binary_count\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
|
||||
argspec: "args=[\'indices\', \'values\', \'dense_shape\', \'size\', \'weights\', \'binary_output\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "SparseConcat"
|
||||
|
@ -1074,7 +1074,7 @@ tf_module {
|
||||
}
|
||||
member_method {
|
||||
name: "DenseBincount"
|
||||
argspec: "args=[\'input\', \'size\', \'weights\', \'binary_count\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
|
||||
argspec: "args=[\'input\', \'size\', \'weights\', \'binary_output\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "DenseCountSparseOutput"
|
||||
@ -3070,7 +3070,7 @@ tf_module {
|
||||
}
|
||||
member_method {
|
||||
name: "RaggedBincount"
|
||||
argspec: "args=[\'splits\', \'values\', \'size\', \'weights\', \'binary_count\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
|
||||
argspec: "args=[\'splits\', \'values\', \'size\', \'weights\', \'binary_output\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "RaggedCountSparseOutput"
|
||||
@ -4082,7 +4082,7 @@ tf_module {
|
||||
}
|
||||
member_method {
|
||||
name: "SparseBincount"
|
||||
argspec: "args=[\'indices\', \'values\', \'dense_shape\', \'size\', \'weights\', \'binary_count\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
|
||||
argspec: "args=[\'indices\', \'values\', \'dense_shape\', \'size\', \'weights\', \'binary_output\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "SparseConcat"
|
||||
|
Loading…
Reference in New Issue
Block a user