diff --git a/tensorflow/core/api_def/base_api/api_def_DenseCountSparseOutput.pbtxt b/tensorflow/core/api_def/base_api/api_def_DenseCountSparseOutput.pbtxt index 416da1ccaab..8296bfe6d7b 100644 --- a/tensorflow/core/api_def/base_api/api_def_DenseCountSparseOutput.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_DenseCountSparseOutput.pbtxt @@ -4,61 +4,62 @@ op { in_arg { name: "values" description: <>; +template +using BatchedMap = std::vector>; namespace { // TODO(momernick): Extend this function to work with outputs of rank > 2. -Status OutputSparse(const BatchedIntMap& per_batch_counts, int num_values, +template +Status OutputSparse(const BatchedMap& per_batch_counts, int num_values, bool is_1d, OpKernelContext* context) { int total_values = 0; int num_batches = per_batch_counts.size(); @@ -44,12 +47,12 @@ Status OutputSparse(const BatchedIntMap& per_batch_counts, int num_values, context->allocate_output(1, TensorShape({total_values}), &values)); auto output_indices = indices->matrix(); - auto output_values = values->flat(); + auto output_values = values->flat(); int64 value_loc = 0; for (int b = 0; b < num_batches; ++b) { const auto& per_batch_count = per_batch_counts[b]; - std::vector> pairs(per_batch_count.begin(), - per_batch_count.end()); + std::vector> pairs(per_batch_count.begin(), + per_batch_count.end()); std::sort(pairs.begin(), pairs.end()); for (const auto& x : pairs) { if (is_1d) { @@ -77,85 +80,19 @@ Status OutputSparse(const BatchedIntMap& per_batch_counts, int num_values, return Status::OK(); } -Status OutputWeightedSparse(const BatchedIntMap& per_batch_counts, - int num_values, const Tensor& weights, bool is_1d, - OpKernelContext* context) { - if (!TensorShapeUtils::IsVector(weights.shape())) { - return errors::InvalidArgument( - "Weights must be a 1-dimensional tensor. Got: ", - weights.shape().DebugString()); - } - - if (num_values > weights.dim_size(0)) { - return errors::InvalidArgument("The maximum array value was ", num_values, - ", but the weight array has size ", - weights.shape().DebugString()); - } - auto weight_values = weights.flat(); - - int total_values = 0; - int num_batches = per_batch_counts.size(); - for (const auto& per_batch_count : per_batch_counts) { - total_values += per_batch_count.size(); - } - - Tensor* indices; - int inner_dim = is_1d ? 1 : 2; - TF_RETURN_IF_ERROR(context->allocate_output( - 0, TensorShape({total_values, inner_dim}), &indices)); - - Tensor* values; - TF_RETURN_IF_ERROR( - context->allocate_output(1, TensorShape({total_values}), &values)); - - auto output_indices = indices->matrix(); - auto output_values = values->flat(); - int64 value_loc = 0; - for (int b = 0; b < num_batches; ++b) { - const auto& per_batch_count = per_batch_counts[b]; - std::vector> pairs(per_batch_count.begin(), - per_batch_count.end()); - std::sort(pairs.begin(), pairs.end()); - for (const auto& x : pairs) { - if (is_1d) { - output_indices(value_loc, 0) = x.first; - } else { - output_indices(value_loc, 0) = b; - output_indices(value_loc, 1) = x.first; - } - output_values(value_loc) = x.second * weight_values(x.first); - ++value_loc; - } - } - - Tensor* dense_shape; - if (is_1d) { - TF_RETURN_IF_ERROR( - context->allocate_output(2, TensorShape({1}), &dense_shape)); - dense_shape->flat().data()[0] = num_values; - } else { - TF_RETURN_IF_ERROR( - context->allocate_output(2, TensorShape({2}), &dense_shape)); - dense_shape->flat().data()[0] = num_batches; - dense_shape->flat().data()[1] = num_values; - } - return Status::OK(); -} - -template -T GetOutputSize(T max_seen, T max_length, T min_length) { +int GetOutputSize(int max_seen, int max_length, int min_length) { return max_length > 0 ? max_length : std::max((max_seen + 1), min_length); } } // namespace -template +template class DenseCount : public OpKernel { public: explicit DenseCount(OpKernelConstruction* context) : OpKernel(context) { OP_REQUIRES_OK(context, context->GetAttr("minlength", &minlength_)); OP_REQUIRES_OK(context, context->GetAttr("maxlength", &maxlength_)); - OP_REQUIRES_OK(context, context->GetAttr("binary_count", &binary_count_)); + OP_REQUIRES_OK(context, context->GetAttr("binary_output", &binary_output_)); } void Compute(OpKernelContext* context) override { @@ -170,6 +107,15 @@ class DenseCount : public OpKernel { "Input must be a 1 or 2-dimensional tensor. Got: ", data.shape().DebugString())); + if (use_weights) { + OP_REQUIRES( + context, weights.shape() == data.shape(), + errors::InvalidArgument( + "Weights and data must have the same shape. Weight shape: ", + weights.shape().DebugString(), + "; data shape: ", data.shape().DebugString())); + } + bool is_1d = TensorShapeUtils::IsVector(data.shape()); int negative_valued_axis = -1; int num_batch_dimensions = (data.shape().dims() + negative_valued_axis); @@ -179,19 +125,23 @@ class DenseCount : public OpKernel { num_batch_elements *= data.shape().dim_size(i); } int num_value_elements = data.shape().num_elements() / num_batch_elements; - auto per_batch_counts = BatchedIntMap(num_batch_elements); + auto per_batch_counts = BatchedMap(num_batch_elements); + T max_value = 0; const auto data_values = data.flat(); + const auto weight_values = weights.flat(); int i = 0; for (int b = 0; b < num_batch_elements; ++b) { for (int v = 0; v < num_value_elements; ++v) { const auto& value = data_values(i); if (value >= 0 && (maxlength_ <= 0 || value < maxlength_)) { - if (binary_count_) { - (per_batch_counts[b])[value] = 1; + if (binary_output_) { + per_batch_counts[b][value] = 1; + } else if (use_weights) { + per_batch_counts[b][value] += weight_values(i); } else { - (per_batch_counts[b])[value]++; + per_batch_counts[b][value]++; } if (value > max_value) { max_value = value; @@ -201,30 +151,24 @@ class DenseCount : public OpKernel { } } - T num_output_values = GetOutputSize(max_value, maxlength_, minlength_); - if (use_weights) { - OP_REQUIRES_OK(context, - OutputWeightedSparse(per_batch_counts, num_output_values, - weights, is_1d, context)); - } else { - OP_REQUIRES_OK(context, OutputSparse(per_batch_counts, num_output_values, - is_1d, context)); - } + int num_output_values = GetOutputSize(max_value, maxlength_, minlength_); + OP_REQUIRES_OK(context, OutputSparse(per_batch_counts, num_output_values, + is_1d, context)); } private: - T minlength_; - T maxlength_; - bool binary_count_; + int maxlength_; + int minlength_; + bool binary_output_; }; -template +template class SparseCount : public OpKernel { public: explicit SparseCount(OpKernelConstruction* context) : OpKernel(context) { OP_REQUIRES_OK(context, context->GetAttr("minlength", &minlength_)); OP_REQUIRES_OK(context, context->GetAttr("maxlength", &maxlength_)); - OP_REQUIRES_OK(context, context->GetAttr("binary_count", &binary_count_)); + OP_REQUIRES_OK(context, context->GetAttr("binary_output", &binary_output_)); } void Compute(OpKernelContext* context) override { @@ -235,23 +179,27 @@ class SparseCount : public OpKernel { bool use_weights = weights.NumElements() > 0; bool is_1d = shape.NumElements() == 1; - const auto indices_values = indices.matrix(); - const auto values_values = values.flat(); - int num_batches = is_1d ? 1 : shape.flat()(0); int num_values = values.NumElements(); - auto per_batch_counts = BatchedIntMap(num_batches); + const auto indices_values = indices.matrix(); + const auto values_values = values.flat(); + const auto weight_values = weights.flat(); + + auto per_batch_counts = BatchedMap(num_batches); + T max_value = 0; for (int idx = 0; idx < num_values; ++idx) { int batch = is_1d ? 0 : indices_values(idx, 0); const auto& value = values_values(idx); if (value >= 0 && (maxlength_ <= 0 || value < maxlength_)) { - if (binary_count_) { - (per_batch_counts[batch])[value] = 1; + if (binary_output_) { + per_batch_counts[batch][value] = 1; + } else if (use_weights) { + per_batch_counts[batch][value] += weight_values(idx); } else { - (per_batch_counts[batch])[value]++; + per_batch_counts[batch][value]++; } if (value > max_value) { max_value = value; @@ -259,30 +207,25 @@ class SparseCount : public OpKernel { } } - T num_output_values = GetOutputSize(max_value, maxlength_, minlength_); - if (use_weights) { - OP_REQUIRES_OK(context, - OutputWeightedSparse(per_batch_counts, num_output_values, - weights, is_1d, context)); - } else { - OP_REQUIRES_OK(context, OutputSparse(per_batch_counts, num_output_values, - is_1d, context)); - } + int num_output_values = GetOutputSize(max_value, maxlength_, minlength_); + OP_REQUIRES_OK(context, OutputSparse(per_batch_counts, num_output_values, + is_1d, context)); } private: - T minlength_; - T maxlength_; - bool binary_count_; + int maxlength_; + int minlength_; + bool binary_output_; + bool validate_; }; -template +template class RaggedCount : public OpKernel { public: explicit RaggedCount(OpKernelConstruction* context) : OpKernel(context) { OP_REQUIRES_OK(context, context->GetAttr("minlength", &minlength_)); OP_REQUIRES_OK(context, context->GetAttr("maxlength", &maxlength_)); - OP_REQUIRES_OK(context, context->GetAttr("binary_count", &binary_count_)); + OP_REQUIRES_OK(context, context->GetAttr("binary_output", &binary_output_)); } void Compute(OpKernelContext* context) override { @@ -290,13 +233,15 @@ class RaggedCount : public OpKernel { const Tensor& values = context->input(1); const Tensor& weights = context->input(2); bool use_weights = weights.NumElements() > 0; + bool is_1d = false; const auto splits_values = splits.flat(); const auto values_values = values.flat(); + const auto weight_values = weights.flat(); int num_batches = splits.NumElements() - 1; int num_values = values.NumElements(); - auto per_batch_counts = BatchedIntMap(num_batches); + auto per_batch_counts = BatchedMap(num_batches); T max_value = 0; int batch_idx = 0; @@ -306,10 +251,12 @@ class RaggedCount : public OpKernel { } const auto& value = values_values(idx); if (value >= 0 && (maxlength_ <= 0 || value < maxlength_)) { - if (binary_count_) { - (per_batch_counts[batch_idx - 1])[value] = 1; + if (binary_output_) { + per_batch_counts[batch_idx - 1][value] = 1; + } else if (use_weights) { + per_batch_counts[batch_idx - 1][value] += weight_values(idx); } else { - (per_batch_counts[batch_idx - 1])[value]++; + per_batch_counts[batch_idx - 1][value]++; } if (value > max_value) { max_value = value; @@ -317,42 +264,47 @@ class RaggedCount : public OpKernel { } } - T num_output_values = GetOutputSize(max_value, maxlength_, minlength_); - if (use_weights) { - OP_REQUIRES_OK(context, - OutputWeightedSparse(per_batch_counts, num_output_values, - weights, false, context)); - } else { - OP_REQUIRES_OK(context, OutputSparse(per_batch_counts, num_output_values, - false, context)); - } + int num_output_values = GetOutputSize(max_value, maxlength_, minlength_); + OP_REQUIRES_OK(context, OutputSparse(per_batch_counts, num_output_values, + is_1d, context)); } private: - T minlength_; - T maxlength_; - bool binary_count_; + int maxlength_; + int minlength_; + bool binary_output_; + bool validate_; }; -#define REGISTER(TYPE) \ - \ - REGISTER_KERNEL_BUILDER(Name("DenseCountSparseOutput") \ - .TypeConstraint("T") \ - .Device(DEVICE_CPU), \ - DenseCount) \ - \ - REGISTER_KERNEL_BUILDER(Name("SparseCountSparseOutput") \ - .TypeConstraint("T") \ - .Device(DEVICE_CPU), \ - SparseCount) \ - \ - REGISTER_KERNEL_BUILDER(Name("RaggedCountSparseOutput") \ - .TypeConstraint("T") \ - .Device(DEVICE_CPU), \ - RaggedCount) +#define REGISTER_W(W_TYPE) \ + REGISTER(int32, W_TYPE) \ + REGISTER(int64, W_TYPE) -REGISTER(int32); -REGISTER(int64); +#define REGISTER(I_TYPE, W_TYPE) \ + \ + REGISTER_KERNEL_BUILDER(Name("DenseCountSparseOutput") \ + .TypeConstraint("T") \ + .TypeConstraint("output_type") \ + .Device(DEVICE_CPU), \ + DenseCount) \ + \ + REGISTER_KERNEL_BUILDER(Name("SparseCountSparseOutput") \ + .TypeConstraint("T") \ + .TypeConstraint("output_type") \ + .Device(DEVICE_CPU), \ + SparseCount) \ + \ + REGISTER_KERNEL_BUILDER(Name("RaggedCountSparseOutput") \ + .TypeConstraint("T") \ + .TypeConstraint("output_type") \ + .Device(DEVICE_CPU), \ + RaggedCount) + +TF_CALL_INTEGRAL_TYPES(REGISTER_W); +TF_CALL_float(REGISTER_W); +TF_CALL_double(REGISTER_W); + +#undef REGISTER_W #undef REGISTER } // namespace tensorflow diff --git a/tensorflow/core/ops/count_ops.cc b/tensorflow/core/ops/count_ops.cc index c9fbe1f8d8e..8de0a2ef954 100644 --- a/tensorflow/core/ops/count_ops.cc +++ b/tensorflow/core/ops/count_ops.cc @@ -19,12 +19,21 @@ limitations under the License. namespace tensorflow { -using shape_inference::DimensionHandle; using shape_inference::InferenceContext; +using shape_inference::ShapeHandle; Status DenseCountSparseOutputShapeFn(InferenceContext *c) { - int32 rank = c->Rank(c->input(0)); - DimensionHandle nvals = c->UnknownDim(); + auto values = c->input(0); + auto weights = c->input(1); + ShapeHandle output; + auto num_weights = c->NumElements(weights); + if (c->ValueKnown(num_weights) && c->Value(num_weights) == 0) { + output = values; + } else { + TF_RETURN_IF_ERROR(c->Merge(weights, values, &output)); + } + auto rank = c->Rank(output); + auto nvals = c->UnknownDim(); c->set_output(0, c->Matrix(nvals, rank)); // out.indices c->set_output(1, c->Vector(nvals)); // out.values c->set_output(2, c->Vector(rank)); // out.dense_shape @@ -32,8 +41,8 @@ Status DenseCountSparseOutputShapeFn(InferenceContext *c) { } Status SparseCountSparseOutputShapeFn(InferenceContext *c) { - DimensionHandle rank = c->Dim(c->input(0), 1); - DimensionHandle nvals = c->UnknownDim(); + auto rank = c->Dim(c->input(0), 1); + auto nvals = c->UnknownDim(); c->set_output(0, c->Matrix(nvals, rank)); // out.indices c->set_output(1, c->Vector(nvals)); // out.values c->set_output(2, c->Vector(rank)); // out.dense_shape @@ -45,7 +54,7 @@ Status RaggedCountSparseOutputShapeFn(InferenceContext *c) { if (rank != c->kUnknownRank) { ++rank; // Add the ragged dimension } - DimensionHandle nvals = c->UnknownDim(); + auto nvals = c->UnknownDim(); c->set_output(0, c->Matrix(nvals, rank)); // out.indices c->set_output(1, c->Vector(nvals)); // out.values c->set_output(2, c->Vector(rank)); // out.dense_shape @@ -54,12 +63,12 @@ Status RaggedCountSparseOutputShapeFn(InferenceContext *c) { REGISTER_OP("DenseCountSparseOutput") .Input("values: T") - .Input("weights: float") + .Input("weights: output_type") .Attr("T: {int32, int64}") .Attr("minlength: int >= -1 = -1") .Attr("maxlength: int >= -1 = -1") - .Attr("binary_count: bool") - .Attr("output_type: {int64, float}") + .Attr("binary_output: bool") + .Attr("output_type: {int32, int64, float, double}") .SetShapeFn(DenseCountSparseOutputShapeFn) .Output("output_indices: int64") .Output("output_values: output_type") @@ -69,12 +78,12 @@ REGISTER_OP("SparseCountSparseOutput") .Input("indices: int64") .Input("values: T") .Input("dense_shape: int64") - .Input("weights: float") + .Input("weights: output_type") .Attr("T: {int32, int64}") .Attr("minlength: int >= -1 = -1") .Attr("maxlength: int >= -1 = -1") - .Attr("binary_count: bool") - .Attr("output_type: {int64, float}") + .Attr("binary_output: bool") + .Attr("output_type: {int32, int64, float, double}") .SetShapeFn(SparseCountSparseOutputShapeFn) .Output("output_indices: int64") .Output("output_values: output_type") @@ -83,12 +92,12 @@ REGISTER_OP("SparseCountSparseOutput") REGISTER_OP("RaggedCountSparseOutput") .Input("splits: int64") .Input("values: T") - .Input("weights: float") + .Input("weights: output_type") .Attr("T: {int32, int64}") .Attr("minlength: int >= -1 = -1") .Attr("maxlength: int >= -1 = -1") - .Attr("binary_count: bool") - .Attr("output_type: {int64, float}") + .Attr("binary_output: bool") + .Attr("output_type: {int32, int64, float, double}") .SetShapeFn(RaggedCountSparseOutputShapeFn) .Output("output_indices: int64") .Output("output_values: output_type") diff --git a/tensorflow/python/ops/bincount.py b/tensorflow/python/ops/bincount.py index e1b3bebaaaa..68950eaf596 100644 --- a/tensorflow/python/ops/bincount.py +++ b/tensorflow/python/ops/bincount.py @@ -18,10 +18,10 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops from tensorflow.python.ops import gen_count_ops from tensorflow.python.ops.ragged import ragged_tensor from tensorflow.python.util.tf_export import tf_export @@ -33,7 +33,7 @@ def sparse_bincount(values, axis=0, minlength=None, maxlength=None, - binary_count=False, + binary_output=False, name=None): """Count the number of times an integer value appears in a tensor. @@ -58,8 +58,9 @@ def sparse_bincount(values, maxlength: If given, skips `values` that are greater than or equal to `maxlength`, and ensures that the output has a `dense_shape` of at most `maxlength` in the inner dimension. - binary_count: Whether to do a binary count. When True, this op will return 1 - for any value that exists instead of counting the number of occurrences. + binary_output: If True, this op will output 1 instead of the number of times + a token appears (equivalent to one_hot + reduce_any instead of one_hot + + reduce_add). Defaults to False. name: A name for this op. Returns: @@ -78,7 +79,7 @@ def sparse_bincount(values, SparseTensor) and returns a SparseTensor where the value of (i,j) is the number of times value j appears in batch i. - >>> data = [[10, 20, 30, 20], [11, 101, 11, 10001]] + >>> data = np.array([[10, 20, 30, 20], [11, 101, 11, 10001]], dtype=np.int64) >>> output = tf.sparse.bincount(data, axis=-1) >>> print(output) SparseTensor(indices=tf.Tensor( @@ -102,7 +103,7 @@ def sparse_bincount(values, dense shape is [2, 500] instead of [2,10002] or [2, 102]. >>> minlength = maxlength = 500 - >>> data = [[10, 20, 30, 20], [11, 101, 11, 10001]] + >>> data = np.array([[10, 20, 30, 20], [11, 101, 11, 10001]], dtype=np.int64) >>> output = tf.sparse.bincount( ... data, axis=-1, minlength=minlength, maxlength=maxlength) >>> print(output) @@ -123,8 +124,8 @@ def sparse_bincount(values, some values (like 20 in batch 1 and 11 in batch 2) appear more than once, the 'values' tensor is all 1s. - >>> dense = [[10, 20, 30, 20], [11, 101, 11, 10001]] - >>> output = tf.sparse.bincount(dense, binary_count=True, axis=-1) + >>> data = np.array([[10, 20, 30, 20], [11, 101, 11, 10001]], dtype=np.int64) + >>> output = tf.sparse.bincount(data, binary_output=True, axis=-1) >>> print(output) SparseTensor(indices=tf.Tensor( [[ 0 10] @@ -136,20 +137,42 @@ def sparse_bincount(values, values=tf.Tensor([1 1 1 1 1 1], shape=(6,), dtype=int64), dense_shape=tf.Tensor([ 2 10002], shape=(2,), dtype=int64)) + **Weighted bin-counting** + + This example takes two inputs - a values tensor and a weights tensor. These + tensors must be identically shaped, and have the same row splits or indices + in the case of RaggedTensors or SparseTensors. When performing a weighted + count, the op will output a SparseTensor where the value of (i, j) is the + sum of the values in the weight tensor's batch i in the locations where + the values tensor has the value j. In this case, the output dtype is the + same as the dtype of the weights tensor. + + >>> data = np.array([[10, 20, 30, 20], [11, 101, 11, 10001]], dtype=np.int64) + >>> weights = [[2, 0.25, 15, 0.5], [2, 17, 3, 0.9]] + >>> output = tf.sparse.bincount(data, weights=weights, axis=-1) + >>> print(output) + SparseTensor(indices=tf.Tensor( + [[ 0 10] + [ 0 20] + [ 0 30] + [ 1 11] + [ 1 101] + [ 1 10001]], shape=(6, 2), dtype=int64), + values=tf.Tensor([2. 0.75 15. 5. 17. 0.9], shape=(6,), dtype=float32), + dense_shape=tf.Tensor([ 2 10002], shape=(2,), dtype=int64)) + """ with ops.name_scope(name, "count", [values, weights]): if not isinstance(values, sparse_tensor.SparseTensor): values = ragged_tensor.convert_to_tensor_or_ragged_tensor( values, name="values") + if weights is not None: + if not isinstance(weights, sparse_tensor.SparseTensor): + weights = ragged_tensor.convert_to_tensor_or_ragged_tensor( + weights, name="weights") - if weights is not None and binary_count: - raise ValueError("binary_count and weights are mutually exclusive.") - - if weights is None: - weights = [] - output_type = dtypes.int64 - else: - output_type = dtypes.float32 + if weights is not None and binary_output: + raise ValueError("binary_output and weights are mutually exclusive.") if axis is None: axis = 0 @@ -162,38 +185,114 @@ def sparse_bincount(values, maxlength_value = maxlength if maxlength is not None else -1 if axis == 0: - if isinstance(values, - (sparse_tensor.SparseTensor, ragged_tensor.RaggedTensor)): + if isinstance(values, sparse_tensor.SparseTensor): + if weights is not None: + weights = validate_sparse_weights(values, weights) + values = values.values + elif isinstance(values, ragged_tensor.RaggedTensor): + if weights is not None: + weights = validate_ragged_weights(values, weights) values = values.values else: + if weights is not None: + weights = array_ops.reshape(weights, [-1]) values = array_ops.reshape(values, [-1]) if isinstance(values, sparse_tensor.SparseTensor): + weights = validate_sparse_weights(values, weights) c_ind, c_val, c_shape = gen_count_ops.sparse_count_sparse_output( values.indices, values.values, values.dense_shape, - weights=weights, + weights, minlength=minlength_value, maxlength=maxlength_value, - binary_count=binary_count, - output_type=output_type) + binary_output=binary_output) elif isinstance(values, ragged_tensor.RaggedTensor): + weights = validate_ragged_weights(values, weights) c_ind, c_val, c_shape = gen_count_ops.ragged_count_sparse_output( values.row_splits, values.values, - weights=weights, + weights, minlength=minlength_value, maxlength=maxlength_value, - binary_count=binary_count, - output_type=output_type) + binary_output=binary_output) else: + weights = validate_dense_weights(values, weights) c_ind, c_val, c_shape = gen_count_ops.dense_count_sparse_output( values, weights=weights, minlength=minlength_value, maxlength=maxlength_value, - binary_count=binary_count, - output_type=output_type) + binary_output=binary_output) return sparse_tensor.SparseTensor(c_ind, c_val, c_shape) + + +def validate_dense_weights(values, weights): + """Validates the passed weight tensor or creates an empty one.""" + if weights is None: + return array_ops.constant([], dtype=values.dtype) + + if not isinstance(weights, ops.Tensor): + raise ValueError( + "`weights` must be a tf.Tensor if `values` is a tf.Tensor.") + + return weights + + +def validate_sparse_weights(values, weights): + """Validates the passed weight tensor or creates an empty one.""" + if weights is None: + return array_ops.constant([], dtype=values.values.dtype) + + if not isinstance(weights, sparse_tensor.SparseTensor): + raise ValueError( + "`weights` must be a SparseTensor if `values` is a SparseTensor.") + + checks = [] + if weights.dense_shape is not values.dense_shape: + checks.append( + check_ops.assert_equal( + weights.dense_shape, + values.dense_shape, + message="'weights' and 'values' must have the same dense shape.")) + if weights.indices is not values.indices: + checks.append( + check_ops.assert_equal( + weights.indices, + values.indices, + message="'weights' and 'values' must have the same indices.") + ) + if checks: + with ops.control_dependencies(checks): + weights = array_ops.identity(weights.values) + else: + weights = weights.values + + return weights + + +def validate_ragged_weights(values, weights): + """Validates the passed weight tensor or creates an empty one.""" + if weights is None: + return array_ops.constant([], dtype=values.values.dtype) + + if not isinstance(weights, ragged_tensor.RaggedTensor): + raise ValueError( + "`weights` must be a RaggedTensor if `values` is a RaggedTensor.") + + checks = [] + if weights.row_splits is not values.row_splits: + checks.append( + check_ops.assert_equal( + weights.row_splits, + values.row_splits, + message="'weights' and 'values' must have the same row splits.")) + if checks: + with ops.control_dependencies(checks): + weights = array_ops.identity(weights.values) + else: + weights = weights.values + + return weights diff --git a/tensorflow/python/ops/bincount_test.py b/tensorflow/python/ops/bincount_test.py index 776b65b72d0..839af8dcc35 100644 --- a/tensorflow/python/ops/bincount_test.py +++ b/tensorflow/python/ops/bincount_test.py @@ -21,6 +21,8 @@ from __future__ import print_function from absl.testing import parameterized import numpy as np +from tensorflow.python.eager import context +from tensorflow.python.framework import errors from tensorflow.python.ops import bincount from tensorflow.python.ops import sparse_ops from tensorflow.python.ops.ragged import ragged_factory_ops @@ -65,7 +67,7 @@ class TestSparseCount(test.TestCase, parameterized.TestCase): "expected_indices": [[0, 1], [0, 2], [0, 3], [1, 4], [1, 5]], "expected_values": [1, 1, 1, 1, 1], "expected_shape": [2, 6], - "binary_count": True, + "binary_output": True, }, { "testcase_name": "_maxlength_binary", "x": np.array([[3, 2, 1, 7], [7, 0, 4, 4]], dtype=np.int32), @@ -73,7 +75,7 @@ class TestSparseCount(test.TestCase, parameterized.TestCase): "expected_indices": [[0, 1], [0, 2], [0, 3], [1, 0], [1, 4]], "expected_values": [1, 1, 1, 1, 1], "expected_shape": [2, 7], - "binary_count": True, + "binary_output": True, }, { "testcase_name": "_minlength_binary", "x": np.array([[3, 2, 1, 7], [7, 0, 4, 4]], dtype=np.int32), @@ -82,7 +84,7 @@ class TestSparseCount(test.TestCase, parameterized.TestCase): [1, 7]], "expected_values": [1, 1, 1, 1, 1, 1, 1], "expected_shape": [2, 9], - "binary_count": True, + "binary_output": True, }, { "testcase_name": "_minlength_larger_values_binary", "x": np.array([[3, 2, 1, 7], [7, 0, 4, 4]], dtype=np.int32), @@ -91,40 +93,40 @@ class TestSparseCount(test.TestCase, parameterized.TestCase): [1, 7]], "expected_values": [1, 1, 1, 1, 1, 1, 1], "expected_shape": [2, 8], - "binary_count": True, + "binary_output": True, }, { "testcase_name": "_no_maxlength_weights", "x": np.array([[3, 2, 1], [5, 4, 4]], dtype=np.int32), "expected_indices": [[0, 1], [0, 2], [0, 3], [1, 4], [1, 5]], - "expected_values": [1, 2, 3, 8, 5], + "expected_values": [2, 1, 0.5, 9, 3], "expected_shape": [2, 6], - "weights": [0.5, 1, 2, 3, 4, 5] + "weights": [[0.5, 1, 2], [3, 4, 5]] }, { "testcase_name": "_maxlength_weights", "x": np.array([[3, 2, 1, 7], [7, 0, 4, 4]], dtype=np.int32), "maxlength": 7, "expected_indices": [[0, 1], [0, 2], [0, 3], [1, 0], [1, 4]], - "expected_values": [1, 2, 3, 0.5, 8], + "expected_values": [2, 1, 0.5, 3, 9], "expected_shape": [2, 7], - "weights": [0.5, 1, 2, 3, 4, 5, 6] + "weights": [[0.5, 1, 2, 11], [7, 3, 4, 5]] }, { "testcase_name": "_minlength_weights", "x": np.array([[3, 2, 1, 7], [7, 0, 4, 4]], dtype=np.int32), "minlength": 9, "expected_indices": [[0, 1], [0, 2], [0, 3], [0, 7], [1, 0], [1, 4], [1, 7]], - "expected_values": [1, 2, 3, 7, 0.5, 8, 7], + "expected_values": [2, 1, 0.5, 3, 5, 13, 4], "expected_shape": [2, 9], - "weights": [0.5, 1, 2, 3, 4, 5, 6, 7, 8] + "weights": [[0.5, 1, 2, 3], [4, 5, 6, 7]] }, { "testcase_name": "_minlength_larger_values_weights", "x": np.array([[3, 2, 1, 7], [7, 0, 4, 4]], dtype=np.int32), "minlength": 3, "expected_indices": [[0, 1], [0, 2], [0, 3], [0, 7], [1, 0], [1, 4], [1, 7]], - "expected_values": [1, 2, 3, 7, 0.5, 8, 7], + "expected_values": [2, 1, 0.5, 3, 5, 13, 4], "expected_shape": [2, 8], - "weights": [0.5, 1, 2, 3, 4, 5, 6, 7, 8] + "weights": [[0.5, 1, 2, 3], [4, 5, 6, 7]] }, { "testcase_name": "_1d", "x": np.array([3, 2, 1, 1], dtype=np.int32), @@ -146,7 +148,7 @@ class TestSparseCount(test.TestCase, parameterized.TestCase): expected_shape, minlength=None, maxlength=None, - binary_count=False, + binary_output=False, weights=None, axis=-1): y = bincount.sparse_bincount( @@ -154,7 +156,7 @@ class TestSparseCount(test.TestCase, parameterized.TestCase): weights=weights, minlength=minlength, maxlength=maxlength, - binary_count=binary_count, + binary_output=binary_output, axis=axis) self.assertAllEqual(expected_indices, y.indices) self.assertAllEqual(expected_values, y.values) @@ -216,7 +218,7 @@ class TestSparseCount(test.TestCase, parameterized.TestCase): "expected_indices": [[0, 1], [0, 3], [2, 4], [2, 5]], "expected_values": [1, 1, 1, 1], "expected_shape": [3, 6], - "binary_count": + "binary_output": True, }, { @@ -230,7 +232,7 @@ class TestSparseCount(test.TestCase, parameterized.TestCase): "expected_shape": [3, 7], "maxlength": 7, - "binary_count": + "binary_output": True, }, { @@ -244,7 +246,7 @@ class TestSparseCount(test.TestCase, parameterized.TestCase): "expected_shape": [3, 9], "minlength": 9, - "binary_count": + "binary_output": True, }, { @@ -258,7 +260,7 @@ class TestSparseCount(test.TestCase, parameterized.TestCase): "expected_shape": [3, 8], "minlength": 3, - "binary_count": + "binary_output": True, }, { @@ -268,9 +270,10 @@ class TestSparseCount(test.TestCase, parameterized.TestCase): np.array([[3, 0, 1, 0], [0, 0, 0, 0], [5, 0, 4, 4]], dtype=np.int32), "expected_indices": [[0, 1], [0, 3], [2, 4], [2, 5]], - "expected_values": [1, 3, 8, 5], + "expected_values": [2, 6, 7, 10], "expected_shape": [3, 6], - "weights": [0.5, 1, 2, 3, 4, 5] + "weights": + np.array([[6, 0, 2, 0], [0, 0, 0, 0], [10, 0, 3.5, 3.5]]), }, { "testcase_name": @@ -279,11 +282,12 @@ class TestSparseCount(test.TestCase, parameterized.TestCase): np.array([[3, 0, 1, 0], [0, 0, 7, 0], [5, 0, 4, 4]], dtype=np.int32), "expected_indices": [[0, 1], [0, 3], [2, 4], [2, 5]], - "expected_values": [1, 3, 8, 5], + "expected_values": [2, 6, 7, 10], "expected_shape": [3, 7], "maxlength": 7, - "weights": [0.5, 1, 2, 3, 4, 5, 6] + "weights": + np.array([[6, 0, 2, 0], [0, 0, 14, 0], [10, 0, 3.5, 3.5]]), }, { "testcase_name": @@ -292,11 +296,12 @@ class TestSparseCount(test.TestCase, parameterized.TestCase): np.array([[3, 0, 1, 0], [7, 0, 0, 0], [5, 0, 4, 4]], dtype=np.int32), "expected_indices": [[0, 1], [0, 3], [1, 7], [2, 4], [2, 5]], - "expected_values": [1, 3, 7, 8, 5], + "expected_values": [2, 6, 14, 6.5, 10], "expected_shape": [3, 9], "minlength": 9, - "weights": [0.5, 1, 2, 3, 4, 5, 6, 7, 8] + "weights": + np.array([[6, 0, 2, 0], [14, 0, 0, 0], [10, 0, 3, 3.5]]), }, { "testcase_name": @@ -305,11 +310,12 @@ class TestSparseCount(test.TestCase, parameterized.TestCase): np.array([[3, 0, 1, 0], [7, 0, 0, 0], [5, 0, 4, 4]], dtype=np.int32), "expected_indices": [[0, 1], [0, 3], [1, 7], [2, 4], [2, 5]], - "expected_values": [1, 3, 7, 8, 5], + "expected_values": [2, 6, 14, 6.5, 10], "expected_shape": [3, 8], "minlength": 3, - "weights": [0.5, 1, 2, 3, 4, 5, 6, 7, 8] + "weights": + np.array([[6, 0, 2, 0], [14, 0, 0, 0], [10, 0, 3, 3.5]]), }, { "testcase_name": "_1d", @@ -338,16 +344,17 @@ class TestSparseCount(test.TestCase, parameterized.TestCase): expected_shape, maxlength=None, minlength=None, - binary_count=False, + binary_output=False, weights=None, axis=-1): x_sparse = sparse_ops.from_dense(x) + w_sparse = sparse_ops.from_dense(weights) if weights is not None else None y = bincount.sparse_bincount( x_sparse, - weights=weights, + weights=w_sparse, minlength=minlength, maxlength=maxlength, - binary_count=binary_count, + binary_output=binary_output, axis=axis) self.assertAllEqual(expected_indices, y.indices) self.assertAllEqual(expected_values, y.values) @@ -393,7 +400,7 @@ class TestSparseCount(test.TestCase, parameterized.TestCase): "expected_indices": [[2, 0], [2, 1], [2, 3], [4, 0], [4, 4], [4, 5]], "expected_values": [1, 1, 1, 1, 1, 1], "expected_shape": [5, 6], - "binary_count": True, + "binary_output": True, }, { "testcase_name": "_maxlength_binary", @@ -402,7 +409,7 @@ class TestSparseCount(test.TestCase, parameterized.TestCase): "expected_indices": [[2, 0], [2, 1], [2, 3], [4, 0], [4, 4], [4, 5]], "expected_values": [1, 1, 1, 1, 1, 1], "expected_shape": [5, 7], - "binary_count": True, + "binary_output": True, }, { "testcase_name": "_minlength_binary", @@ -412,13 +419,13 @@ class TestSparseCount(test.TestCase, parameterized.TestCase): [4, 5]], "expected_values": [1, 1, 1, 1, 1, 1, 1], "expected_shape": [5, 9], - "binary_count": True, + "binary_output": True, }, { "testcase_name": "_minlength_larger_values_binary", "x": [[], [], [3, 0, 1], [7], [5, 0, 4, 4]], "minlength": 3, - "binary_count": True, + "binary_output": True, "expected_indices": [[2, 0], [2, 1], [2, 3], [3, 7], [4, 0], [4, 4], [4, 5]], "expected_values": [1, 1, 1, 1, 1, 1, 1], @@ -428,18 +435,18 @@ class TestSparseCount(test.TestCase, parameterized.TestCase): "testcase_name": "_no_maxlength_weights", "x": [[], [], [3, 0, 1], [], [5, 0, 4, 4]], "expected_indices": [[2, 0], [2, 1], [2, 3], [4, 0], [4, 4], [4, 5]], - "expected_values": [0.5, 1, 3, 0.5, 8, 5], + "expected_values": [0.5, 2, 6, 0.25, 8, 10], "expected_shape": [5, 6], - "weights": [0.5, 1, 2, 3, 4, 5] + "weights": [[], [], [6, 0.5, 2], [], [10, 0.25, 5, 3]], }, { "testcase_name": "_maxlength_weights", "x": [[], [], [3, 0, 1], [7], [5, 0, 4, 4]], "maxlength": 7, "expected_indices": [[2, 0], [2, 1], [2, 3], [4, 0], [4, 4], [4, 5]], - "expected_values": [0.5, 1, 3, 0.5, 8, 5], + "expected_values": [0.5, 2, 6, 0.25, 8, 10], "expected_shape": [5, 7], - "weights": [0.5, 1, 2, 3, 4, 5, 6] + "weights": [[], [], [6, 0.5, 2], [14], [10, 0.25, 5, 3]], }, { "testcase_name": "_minlength_weights", @@ -447,9 +454,9 @@ class TestSparseCount(test.TestCase, parameterized.TestCase): "minlength": 9, "expected_indices": [[2, 0], [2, 1], [2, 3], [3, 7], [4, 0], [4, 4], [4, 5]], - "expected_values": [0.5, 1, 3, 7, 0.5, 8, 5], + "expected_values": [0.5, 2, 6, 14, 0.25, 8, 10], "expected_shape": [5, 9], - "weights": [0.5, 1, 2, 3, 4, 5, 6, 7, 8] + "weights": [[], [], [6, 0.5, 2], [14], [10, 0.25, 5, 3]], }, { "testcase_name": "_minlength_larger_values_weights", @@ -457,9 +464,9 @@ class TestSparseCount(test.TestCase, parameterized.TestCase): "minlength": 3, "expected_indices": [[2, 0], [2, 1], [2, 3], [3, 7], [4, 0], [4, 4], [4, 5]], - "expected_values": [0.5, 1, 3, 7, 0.5, 8, 5], + "expected_values": [0.5, 2, 6, 14, 0.25, 8, 10], "expected_shape": [5, 8], - "weights": [0.5, 1, 2, 3, 4, 5, 6, 7, 8] + "weights": [[], [], [6, 0.5, 2], [14], [10, 0.25, 5, 3]], }, { "testcase_name": "_1d", @@ -484,21 +491,114 @@ class TestSparseCount(test.TestCase, parameterized.TestCase): expected_shape, maxlength=None, minlength=None, - binary_count=False, + binary_output=False, weights=None, axis=-1): x_ragged = ragged_factory_ops.constant(x) + w = ragged_factory_ops.constant(weights) if weights is not None else None y = bincount.sparse_bincount( x_ragged, - weights=weights, + weights=w, minlength=minlength, maxlength=maxlength, - binary_count=binary_count, + binary_output=binary_output, axis=axis) self.assertAllEqual(expected_indices, y.indices) self.assertAllEqual(expected_values, y.values) self.assertAllEqual(expected_shape, y.dense_shape) +class TestSparseCountFailureModes(test.TestCase): + + def test_dense_input_sparse_weights_fails(self): + x = np.array([[3, 2, 1], [5, 4, 4]], dtype=np.int32) + weights = sparse_ops.from_dense( + np.array([[3, 0, 1, 0], [0, 0, 0, 0], [5, 0, 4, 4]], dtype=np.int32)) + with self.assertRaisesRegexp(ValueError, "must be a tf.Tensor"): + self.evaluate(bincount.sparse_bincount(x, weights=weights, axis=-1)) + + def test_dense_input_ragged_weights_fails(self): + x = np.array([[3, 2, 1], [5, 4, 4]], dtype=np.int32) + weights = ragged_factory_ops.constant([[6, 0.5, 2], [14], [10, 0.25, 5, 3]]) + with self.assertRaisesRegexp(ValueError, "must be a tf.Tensor"): + self.evaluate(bincount.sparse_bincount(x, weights=weights, axis=-1)) + + def test_dense_input_wrong_shape_fails(self): + x = np.array([[3, 2, 1], [5, 4, 4]], dtype=np.int32) + weights = np.array([[3, 2], [5, 4], [4, 3]]) + # Note: Eager mode and graph mode throw different errors here. Graph mode + # will fail with a ValueError from the shape checking logic, while Eager + # will fail with an InvalidArgumentError from the kernel itself. + if context.executing_eagerly(): + with self.assertRaisesRegexp(errors.InvalidArgumentError, + "must have the same shape"): + self.evaluate(bincount.sparse_bincount(x, weights=weights, axis=-1)) + else: + with self.assertRaisesRegexp(ValueError, "both shapes must be equal"): + self.evaluate(bincount.sparse_bincount(x, weights=weights, axis=-1)) + + def test_sparse_input_dense_weights_fails(self): + x = sparse_ops.from_dense( + np.array([[3, 0, 1, 0], [0, 0, 0, 0], [5, 0, 4, 4]], dtype=np.int32)) + weights = np.array([[3, 2, 1], [5, 4, 4]], dtype=np.int32) + with self.assertRaisesRegexp(ValueError, "must be a SparseTensor"): + self.evaluate(bincount.sparse_bincount(x, weights=weights, axis=-1)) + + def test_sparse_input_ragged_weights_fails(self): + x = sparse_ops.from_dense( + np.array([[3, 0, 1, 0], [0, 0, 0, 0], [5, 0, 4, 4]], dtype=np.int32)) + weights = ragged_factory_ops.constant([[6, 0.5, 2], [14], [10, 0.25, 5, 3]]) + with self.assertRaisesRegexp(ValueError, "must be a SparseTensor"): + self.evaluate(bincount.sparse_bincount(x, weights=weights, axis=-1)) + + def test_sparse_input_wrong_indices_fails(self): + x = sparse_ops.from_dense( + np.array([[3, 0, 1, 0], [0, 0, 0, 0], [5, 0, 4, 4]], dtype=np.int32)) + weights = sparse_ops.from_dense( + np.array([[3, 1, 0, 0], [0, 0, 0, 0], [5, 0, 4, 4]], dtype=np.int32)) + with self.assertRaisesRegexp(errors.InvalidArgumentError, + "must have the same indices"): + self.evaluate(bincount.sparse_bincount(x, weights=weights, axis=-1)) + + def test_sparse_input_too_many_indices_fails(self): + x = sparse_ops.from_dense( + np.array([[3, 0, 1, 0], [0, 0, 0, 0], [5, 0, 4, 4]], dtype=np.int32)) + weights = sparse_ops.from_dense( + np.array([[3, 1, 1, 0], [0, 0, 0, 0], [5, 0, 4, 4]], dtype=np.int32)) + with self.assertRaisesRegexp(errors.InvalidArgumentError, + "Incompatible shapes"): + self.evaluate(bincount.sparse_bincount(x, weights=weights, axis=-1)) + + def test_sparse_input_wrong_shape_fails(self): + x = sparse_ops.from_dense( + np.array([[3, 0, 1, 0], [0, 0, 0, 0], [5, 0, 4, 4]], dtype=np.int32)) + weights = sparse_ops.from_dense( + np.array([[3, 0, 1, 0], [0, 0, 0, 0], [5, 0, 4, 4], [0, 0, 0, 0]], + dtype=np.int32)) + with self.assertRaisesRegexp(errors.InvalidArgumentError, + "must have the same dense shape"): + self.evaluate(bincount.sparse_bincount(x, weights=weights, axis=-1)) + + def test_ragged_input_dense_weights_fails(self): + x = ragged_factory_ops.constant([[6, 1, 2], [14], [10, 1, 5, 3]]) + weights = np.array([[3, 2, 1], [5, 4, 4]], dtype=np.int32) + with self.assertRaisesRegexp(ValueError, "must be a RaggedTensor"): + self.evaluate(bincount.sparse_bincount(x, weights=weights, axis=-1)) + + def test_ragged_input_sparse_weights_fails(self): + x = ragged_factory_ops.constant([[6, 1, 2], [14], [10, 1, 5, 3]]) + weights = sparse_ops.from_dense( + np.array([[3, 0, 1, 0], [0, 0, 0, 0], [5, 0, 4, 4]], dtype=np.int32)) + with self.assertRaisesRegexp(ValueError, "must be a RaggedTensor"): + self.evaluate(bincount.sparse_bincount(x, weights=weights, axis=-1)) + + def test_ragged_input_different_shape_fails(self): + x = ragged_factory_ops.constant([[6, 1, 2], [14], [10, 1, 5, 3]]) + weights = ragged_factory_ops.constant([[6, 0.5, 2], [], [10, 0.25, 5, 3]]) + with self.assertRaisesRegexp(errors.InvalidArgumentError, + "must have the same row splits"): + self.evaluate(bincount.sparse_bincount(x, weights=weights, axis=-1)) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt index 05b8842be66..44fb74ac63a 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt @@ -1078,7 +1078,7 @@ tf_module { } member_method { name: "DenseCountSparseOutput" - argspec: "args=[\'values\', \'weights\', \'binary_count\', \'output_type\', \'minlength\', \'maxlength\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'-1\', \'None\'], " + argspec: "args=[\'values\', \'weights\', \'binary_output\', \'minlength\', \'maxlength\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'-1\', \'None\'], " } member_method { name: "DenseToCSRSparseMatrix" @@ -3074,7 +3074,7 @@ tf_module { } member_method { name: "RaggedCountSparseOutput" - argspec: "args=[\'splits\', \'values\', \'weights\', \'binary_count\', \'output_type\', \'minlength\', \'maxlength\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'-1\', \'None\'], " + argspec: "args=[\'splits\', \'values\', \'weights\', \'binary_output\', \'minlength\', \'maxlength\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'-1\', \'None\'], " } member_method { name: "RaggedCross" @@ -4094,7 +4094,7 @@ tf_module { } member_method { name: "SparseCountSparseOutput" - argspec: "args=[\'indices\', \'values\', \'dense_shape\', \'weights\', \'binary_count\', \'output_type\', \'minlength\', \'maxlength\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'-1\', \'None\'], " + argspec: "args=[\'indices\', \'values\', \'dense_shape\', \'weights\', \'binary_output\', \'minlength\', \'maxlength\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'-1\', \'None\'], " } member_method { name: "SparseCross" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.sparse.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.sparse.pbtxt index 4c4f6c62291..f8f8edb26a8 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.sparse.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.sparse.pbtxt @@ -14,7 +14,7 @@ tf_module { } member_method { name: "bincount" - argspec: "args=[\'values\', \'weights\', \'axis\', \'minlength\', \'maxlength\', \'binary_count\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\', \'None\', \'False\', \'None\'], " + argspec: "args=[\'values\', \'weights\', \'axis\', \'minlength\', \'maxlength\', \'binary_output\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\', \'None\', \'False\', \'None\'], " } member_method { name: "concat" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt index 05b8842be66..44fb74ac63a 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt @@ -1078,7 +1078,7 @@ tf_module { } member_method { name: "DenseCountSparseOutput" - argspec: "args=[\'values\', \'weights\', \'binary_count\', \'output_type\', \'minlength\', \'maxlength\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'-1\', \'None\'], " + argspec: "args=[\'values\', \'weights\', \'binary_output\', \'minlength\', \'maxlength\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'-1\', \'None\'], " } member_method { name: "DenseToCSRSparseMatrix" @@ -3074,7 +3074,7 @@ tf_module { } member_method { name: "RaggedCountSparseOutput" - argspec: "args=[\'splits\', \'values\', \'weights\', \'binary_count\', \'output_type\', \'minlength\', \'maxlength\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'-1\', \'None\'], " + argspec: "args=[\'splits\', \'values\', \'weights\', \'binary_output\', \'minlength\', \'maxlength\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'-1\', \'None\'], " } member_method { name: "RaggedCross" @@ -4094,7 +4094,7 @@ tf_module { } member_method { name: "SparseCountSparseOutput" - argspec: "args=[\'indices\', \'values\', \'dense_shape\', \'weights\', \'binary_count\', \'output_type\', \'minlength\', \'maxlength\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'-1\', \'None\'], " + argspec: "args=[\'indices\', \'values\', \'dense_shape\', \'weights\', \'binary_output\', \'minlength\', \'maxlength\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'-1\', \'None\'], " } member_method { name: "SparseCross" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.sparse.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.sparse.pbtxt index a9ad81920dd..67235bb2cf2 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.sparse.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.sparse.pbtxt @@ -10,7 +10,7 @@ tf_module { } member_method { name: "bincount" - argspec: "args=[\'values\', \'weights\', \'axis\', \'minlength\', \'maxlength\', \'binary_count\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\', \'None\', \'False\', \'None\'], " + argspec: "args=[\'values\', \'weights\', \'axis\', \'minlength\', \'maxlength\', \'binary_output\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\', \'None\', \'False\', \'None\'], " } member_method { name: "concat"