Added new op: tf.strings.unsorted_segment_join.

PiperOrigin-RevId: 251690479
Mehrdad Khatir 2019-06-05 12:05:56 -07:00 committed by TensorFlower Gardener
parent bb5ae484d0
commit f0f318aa40
14 changed files with 626 additions and 35 deletions

View File

@ -0,0 +1,62 @@
op {
graph_op_name: "UnsortedSegmentJoin"
in_arg {
name: "inputs"
description: <<END
The input to be joined.
END
}
in_arg {
name: "segment_ids"
description: <<END
A tensor whose shape is a prefix of `inputs.shape`. Negative segment ids are
not supported.
END
}
in_arg {
name: "num_segments"
description: <<END
A scalar.
END
}
out_arg {
name: "output"
description: <<END
The joined strings, one per segment.
END
}
attr {
name: "separator"
description: <<END
The separator to use when joining.
END
}
summary: "Joins the elements of `inputs` based on `segment_ids`."
description: <<END
Computes the string join along segments of a tensor.
Given `segment_ids` with rank `N` and `inputs` with rank `N+M`:

    `output[i, k1...kM] = strings.join([inputs[j1...jN, k1...kM]])`

where the join is over all `[j1...jN]` such that `segment_ids[j1...jN] = i`.
Strings are joined in row-major order.
For example:
```python
inputs = [['Y', 'q', 'c'], ['Y', '6', '6'], ['p', 'G', 'a']]
output_array = string_ops.unsorted_segment_join(inputs=inputs,
                                                segment_ids=[1, 0, 1],
                                                num_segments=2,
                                                separator=':')
# output_array ==> [['Y', '6', '6'], ['Y:p', 'q:G', 'c:a']]
inputs = ['this', 'is', 'a', 'test']
output_array = string_ops.unsorted_segment_join(inputs=inputs,
                                                segment_ids=[0, 0, 0, 0],
                                                num_segments=1,
                                                separator=':')
# output_array ==> ['this:is:a:test']
```
END
}
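The row-major join rule documented above can be restated as a small reference function. The following is a minimal NumPy sketch of those documented semantics, not the TensorFlow kernel; the helper name `unsorted_segment_join_ref` is hypothetical.

```python
# Minimal NumPy sketch of the documented semantics; `unsorted_segment_join_ref`
# is a hypothetical helper, not part of TensorFlow.
import numpy as np

def unsorted_segment_join_ref(inputs, segment_ids, num_segments, separator=''):
  inputs = np.asarray(inputs, dtype=object)
  segment_ids = np.asarray(segment_ids)
  # segment_ids covers a leading prefix of inputs.shape; the remaining
  # (suffix) dimensions carry through to the output unchanged.
  suffix = inputs.shape[segment_ids.ndim:]
  output = np.full((num_segments,) + suffix, '', dtype=object)
  flat_ids = segment_ids.reshape(-1)
  flat_in = inputs.reshape((flat_ids.size,) + suffix)
  for row, seg in enumerate(flat_ids):  # row-major order over segment_ids
    for idx in np.ndindex(*suffix):
      cur = output[(seg,) + idx]
      val = flat_in[(row,) + idx]
      output[(seg,) + idx] = val if not cur else cur + separator + val
  return output

print(unsorted_segment_join_ref([['Y', 'q', 'c'], ['Y', '6', '6'], ['p', 'G', 'a']],
                                segment_ids=[1, 0, 1], num_segments=2,
                                separator=':'))
# ==> [['Y' '6' '6']
#      ['Y:p' 'q:G' 'c:a']]
```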

View File

@ -0,0 +1,6 @@
op {
graph_op_name: "UnsortedSegmentJoin"
endpoint {
name: "strings.unsorted_segment_join"
}
}

View File

@ -1443,6 +1443,37 @@ Status RandomShape(shape_inference::InferenceContext* c) {
return Status::OK();
}
Status UnsortedSegmentReductionShapeFn(InferenceContext* c) {
ShapeHandle s_data = c->input(0);
ShapeHandle s_segment_ids = c->input(1);
ShapeHandle s_num_segments = c->input(2);
TF_RETURN_IF_ERROR(c->WithRank(s_num_segments, 0, &s_num_segments));
ShapeHandle out;
// Leading dimensions of data must be compatible with dimensions of
// <s_segment_ids>.
if (c->RankKnown(s_segment_ids)) {
TF_RETURN_IF_ERROR(
c->MergePrefix(s_data, s_segment_ids, &s_data, &s_segment_ids));
// Get the value of the num_segments input tensor.
DimensionHandle num_segments_dim;
TF_RETURN_IF_ERROR(c->MakeDimForScalarInput(2, &num_segments_dim));
// Output is {segment_id_rank} + s_data[segment_id_rank:].
ShapeHandle s_data_suffix;
TF_RETURN_IF_ERROR(
c->Subshape(s_data, c->Rank(s_segment_ids), &s_data_suffix));
TF_RETURN_IF_ERROR(
c->Concatenate(c->Vector(num_segments_dim), s_data_suffix, &out));
} else {
out = c->UnknownShape();
}
c->set_output(0, out);
return Status::OK();
}
namespace {
// This SliceHelper processes the output shape of the `slice`
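The shape function moved into this file boils down to a simple rule: the leading dimensions of `data` must match `segment_ids`, and the output shape is `[num_segments]` followed by the remaining suffix of `data`'s shape. Below is a minimal Python sketch of that rule, ignoring the unknown-rank and partially-known-dimension cases that the C++ handles via `MergePrefix` and `UnknownShape`; `infer_unsorted_segment_shape` is an illustrative name, not a TensorFlow function.

```python
# Illustrative sketch of the shape-inference rule only; not TensorFlow code.
def infer_unsorted_segment_shape(data_shape, segment_ids_shape, num_segments):
  rank = len(segment_ids_shape)
  # Leading dims of data must be compatible with segment_ids' dims.
  assert list(data_shape[:rank]) == list(segment_ids_shape)
  # Output is [num_segments] + data_shape[rank:].
  return [num_segments] + list(data_shape[rank:])

print(infer_unsorted_segment_shape([3, 3], [3], num_segments=2))        # [2, 3]
print(infer_unsorted_segment_shape([4, 2, 5], [4, 2], num_segments=3))  # [3, 5]
```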

View File

@ -276,6 +276,9 @@ Status UnknownShape(shape_inference::InferenceContext* c);
// Shape function for reduction operations.
Status ReductionShape(shape_inference::InferenceContext* c);
// Shape function for unsorted segment operations.
Status UnsortedSegmentReductionShapeFn(InferenceContext* c);
// Shape function for concat operations.
// <num_inputs_to_concat> is the number of inputs to concatenate and are taken
// from inputs

View File

@ -5183,6 +5183,7 @@ cc_library(
":substr_op", ":substr_op",
":unicode_ops", ":unicode_ops",
":unicode_script_op", ":unicode_script_op",
":unsorted_segment_join_op",
],
)
@ -5219,6 +5220,12 @@ tf_kernel_library(
deps = STRING_DEPS,
)
tf_kernel_library(
name = "unsorted_segment_join_op",
prefix = "unsorted_segment_join_op",
deps = STRING_DEPS,
)
tf_kernel_library(
name = "string_format_op",
prefix = "string_format_op",

View File

@ -0,0 +1,166 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// See docs in ../ops/string_ops.cc.
#include <string>
#include <utility>
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/strings/str_util.h"
namespace tensorflow {
namespace {
template <typename INDICES_TYPE>
gtl::InlinedVector<INDICES_TYPE, 8> GetFlattenedRelativeOffsets(
INDICES_TYPE small_stride, INDICES_TYPE big_stride) {
gtl::InlinedVector<INDICES_TYPE, 8> flattened_offsets(small_stride);
for (auto i = 0; i < small_stride; i++) {
flattened_offsets[i] = i * big_stride;
}
return flattened_offsets;
}
template <typename INDICES_TYPE>
std::pair<INDICES_TYPE, INDICES_TYPE> GetStrides(
const TensorShape& input_shape, const TensorShape& segment_id_shape) {
int64 small_stride = 1;
int64 big_stride = 1;
for (auto i = 0; i < input_shape.dims(); i++) {
if (i < segment_id_shape.dims()) {
small_stride *= segment_id_shape.dim_size(i);
} else {
big_stride *= input_shape.dim_size(i);
}
}
return std::make_pair(big_stride, small_stride);
}
TensorShape GetOutputShape(const TensorShape& input_shape,
const TensorShape& segment_id_shape,
const int64 num_segments) {
TensorShape output_shape;
output_shape.AddDim(num_segments);
for (size_t index = segment_id_shape.dims(); index < input_shape.dims();
++index) {
output_shape.AddDim(input_shape.dim_size(index));
}
return output_shape;
}
} // namespace
template <typename INDICES_TYPE, typename NUM_SEGMENTS_TYPE>
class UnsortedSegmentJoinOp : public OpKernel {
public:
using OpKernel::OpKernel;
explicit UnsortedSegmentJoinOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("separator", &separator_));
}
void Compute(OpKernelContext* context) override {
const Tensor& input = context->input(0);
const TensorShape& input_shape = input.shape();
const int32 input_dims = input_shape.dims();
const Tensor& segment_id = context->input(1);
const TensorShape& segment_id_shape = segment_id.shape();
const int32 segment_dims = segment_id_shape.dims();
const Tensor& num_segments_tensor = context->input(2);
auto num_segments = num_segments_tensor.scalar<NUM_SEGMENTS_TYPE>()();
OP_REQUIRES(context, segment_dims != 0,
errors::InvalidArgument("Segment_id cannot have rank 0"));
OP_REQUIRES(
context, segment_dims <= input_dims,
errors::OutOfRange("Invalid segment_id rank ", segment_dims,
" for input with ", input_dims, " dimension(s)"));
for (auto i = 0; i < segment_dims; i++) {
OP_REQUIRES(
context, segment_id_shape.dim_size(i) == input_shape.dim_size(i),
errors::InvalidArgument(
"Segment dimension is ", segment_id_shape.dim_size(i),
" while input dimension is ", input_dims, " in rank ", i));
}
// Making output tensor.
Tensor* output_tensor = nullptr;
TensorShape output_shape =
GetOutputShape(input_shape, segment_id_shape, num_segments);
OP_REQUIRES_OK(context, context->allocate_output("output", output_shape,
&output_tensor));
// Preparing flat tensors.
auto output_flat = output_tensor->flat<string>();
auto flat_segment_id = segment_id.flat<INDICES_TYPE>();
auto flat_input = input.flat<string>();
for (int i = 0; i < flat_segment_id.size(); i++) {
OP_REQUIRES(
context,
((flat_segment_id(i) < num_segments) && (flat_segment_id(i) >= 0)),
errors::InvalidArgument(
"segment_ids are not allowed to exceed num_segments or"
" to have negative values."));
}
int64 big_stride;
int64 small_stride;
std::tie(big_stride, small_stride) =
GetStrides<INDICES_TYPE>(input_shape, segment_id_shape);
auto relative_offset_set =
GetFlattenedRelativeOffsets<INDICES_TYPE>(small_stride, big_stride);
for (auto start_offset = 0; start_offset < big_stride; start_offset++) {
for (auto i = 0; i < relative_offset_set.size(); i++) {
auto output_index = start_offset + flat_segment_id(i) * big_stride;
auto offset = start_offset + relative_offset_set[i];
if (output_flat(output_index).length() != 0)
output_flat(output_index).append(separator_.c_str());
output_flat(output_index).append(flat_input(offset));
}
}
}
private:
string separator_;
};
#define REGISTER_CPU_KERNEL(indices_type, num_segments_type) \
REGISTER_KERNEL_BUILDER( \
Name("UnsortedSegmentJoin") \
.Device(DEVICE_CPU) \
.TypeConstraint<indices_type>("Tindices") \
.TypeConstraint<num_segments_type>("Tnumsegments"), \
UnsortedSegmentJoinOp<indices_type, num_segments_type>);
REGISTER_CPU_KERNEL(int32, int32);
REGISTER_CPU_KERNEL(int32, int64);
REGISTER_CPU_KERNEL(int64, int32);
REGISTER_CPU_KERNEL(int64, int64);
#undef REGISTER_CPU_KERNEL
} // namespace tensorflow
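The kernel above never indexes multidimensionally: `GetStrides` computes `small_stride` (product of the `segment_ids` dimensions) and `big_stride` (product of the remaining input dimensions), and each flattened input element at `start_offset + i * big_stride` is appended to flattened output position `start_offset + segment_ids[i] * big_stride`. A rough Python restatement of that indexing scheme follows, for illustration only; the function name is made up and this is not how the op is invoked.

```python
# Rough Python restatement of the kernel's flattened indexing; illustrative only.
import numpy as np

def join_by_flat_offsets(inputs, segment_ids, num_segments, separator):
  inputs = np.asarray(inputs, dtype=object)
  ids = np.asarray(segment_ids).reshape(-1)
  small_stride = ids.size                    # product of segment_ids dims
  big_stride = inputs.size // small_stride   # product of the suffix dims
  flat_in = inputs.reshape(-1)               # row-major flattening
  out = [''] * (num_segments * big_stride)
  for start_offset in range(big_stride):
    for i in range(small_stride):
      output_index = start_offset + ids[i] * big_stride
      offset = start_offset + i * big_stride
      if out[output_index]:
        out[output_index] += separator
      out[output_index] += flat_in[offset]
  suffix = inputs.shape[np.asarray(segment_ids).ndim:]
  return np.array(out, dtype=object).reshape((num_segments,) + suffix)

print(join_by_flat_offsets([['Y', 'q', 'c'], ['Y', '6', '6'], ['p', 'G', 'a']],
                           segment_ids=[1, 0, 1], num_segments=2, separator=':'))
# ==> [['Y' '6' '6']
#      ['Y:p' 'q:G' 'c:a']]
```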

View File

@ -1176,37 +1176,6 @@ Status SparseSegmentReductionWithNumSegmentsShapeFn(InferenceContext* c) {
c->set_output(0, out);
return Status::OK();
}
Status UnsortedSegmentReductionShapeFn(InferenceContext* c) {
ShapeHandle s_data = c->input(0);
ShapeHandle s_segment_ids = c->input(1);
ShapeHandle s_num_segments = c->input(2);
TF_RETURN_IF_ERROR(c->WithRank(s_num_segments, 0, &s_num_segments));
ShapeHandle out;
// Leading dimensions of data must be compatible with dimensions of
// <s_segment_ids>.
if (c->RankKnown(s_segment_ids)) {
TF_RETURN_IF_ERROR(
c->MergePrefix(s_data, s_segment_ids, &s_data, &s_segment_ids));
// Get the value of the num_segments input tensor.
DimensionHandle num_segments_dim;
TF_RETURN_IF_ERROR(c->MakeDimForScalarInput(2, &num_segments_dim));
// Output is {segment_id_rank} + s_data[segment_id_rank:].
ShapeHandle s_data_suffix;
TF_RETURN_IF_ERROR(
c->Subshape(s_data, c->Rank(s_segment_ids), &s_data_suffix));
TF_RETURN_IF_ERROR(
c->Concatenate(c->Vector(num_segments_dim), s_data_suffix, &out));
} else {
out = c->UnknownShape();
}
c->set_output(0, out);
return Status::OK();
}
} // namespace
REGISTER_OP("SegmentSum")
@ -1257,7 +1226,7 @@ REGISTER_OP("UnsortedSegmentSum")
.Attr("T: numbertype") .Attr("T: numbertype")
.Attr("Tindices: {int32,int64}") .Attr("Tindices: {int32,int64}")
.Attr("Tnumsegments: {int32,int64} = DT_INT32") .Attr("Tnumsegments: {int32,int64} = DT_INT32")
.SetShapeFn(UnsortedSegmentReductionShapeFn); .SetShapeFn(shape_inference::UnsortedSegmentReductionShapeFn);
REGISTER_OP("UnsortedSegmentMax") REGISTER_OP("UnsortedSegmentMax")
.Input("data: T") .Input("data: T")
@ -1267,7 +1236,7 @@ REGISTER_OP("UnsortedSegmentMax")
.Attr("T: realnumbertype") .Attr("T: realnumbertype")
.Attr("Tindices: {int32,int64}") .Attr("Tindices: {int32,int64}")
.Attr("Tnumsegments: {int32,int64} = DT_INT32") .Attr("Tnumsegments: {int32,int64} = DT_INT32")
.SetShapeFn(UnsortedSegmentReductionShapeFn); .SetShapeFn(shape_inference::UnsortedSegmentReductionShapeFn);
REGISTER_OP("UnsortedSegmentMin") REGISTER_OP("UnsortedSegmentMin")
.Input("data: T") .Input("data: T")
@ -1277,7 +1246,7 @@ REGISTER_OP("UnsortedSegmentMin")
.Attr("T: realnumbertype") .Attr("T: realnumbertype")
.Attr("Tindices: {int32,int64}") .Attr("Tindices: {int32,int64}")
.Attr("Tnumsegments: {int32,int64} = DT_INT32") .Attr("Tnumsegments: {int32,int64} = DT_INT32")
.SetShapeFn(UnsortedSegmentReductionShapeFn); .SetShapeFn(shape_inference::UnsortedSegmentReductionShapeFn);
REGISTER_OP("UnsortedSegmentProd") REGISTER_OP("UnsortedSegmentProd")
.Input("data: T") .Input("data: T")
@ -1287,7 +1256,7 @@ REGISTER_OP("UnsortedSegmentProd")
.Attr("T: numbertype") .Attr("T: numbertype")
.Attr("Tindices: {int32,int64}") .Attr("Tindices: {int32,int64}")
.Attr("Tnumsegments: {int32,int64} = DT_INT32") .Attr("Tnumsegments: {int32,int64} = DT_INT32")
.SetShapeFn(UnsortedSegmentReductionShapeFn); .SetShapeFn(shape_inference::UnsortedSegmentReductionShapeFn);
REGISTER_OP("SparseSegmentSum") REGISTER_OP("SparseSegmentSum")
.Input("data: T") .Input("data: T")

View File

@ -101,6 +101,16 @@ REGISTER_OP("ReduceJoin")
.Output("output: string") .Output("output: string")
.SetShapeFn(shape_inference::ReductionShape); .SetShapeFn(shape_inference::ReductionShape);
REGISTER_OP("UnsortedSegmentJoin")
.Input("inputs: string")
.Input("segment_ids: Tindices")
.Input("num_segments: Tnumsegments")
.Attr("separator: string = ''")
.Attr("Tindices: {int32,int64}")
.Attr("Tnumsegments: {int32,int64} = DT_INT32")
.Output("output: string")
.SetShapeFn(shape_inference::UnsortedSegmentReductionShapeFn);
REGISTER_OP("AsString") REGISTER_OP("AsString")
.Input("input: T") .Input("input: T")
.Output("output: string") .Output("output: string")

View File

@ -2308,6 +2308,20 @@ cuda_py_test(
xla_enable_strict_auto_jit = True,
)
cuda_py_test(
name = "unsorted_segment_join_op_test",
size = "small",
srcs = ["unsorted_segment_join_op_test.py"],
additional_deps = [
"//third_party/py/numpy",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:string_ops",
],
xla_enable_strict_auto_jit = True,
)
cuda_py_test(
name = "reduction_ops_test",
size = "medium",

View File

@ -0,0 +1,307 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for unsorted_segment_join_op."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
import numpy as np
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import test_util
from tensorflow.python.ops import string_ops
from tensorflow.python.platform import test
class UnicodeTestCase(test.TestCase):
"""Test case with Python3-compatible string comparator."""
def assertAllEqualUnicode(self, truth, actual):
self.assertAllEqual(
np.array(truth).astype('U'),
np.array(actual).astype('U'))
@test_util.run_all_in_graph_and_eager_modes
class UnsortedSegmentJoinOpTest(UnicodeTestCase, parameterized.TestCase):
def test_basic_np_array(self):
inputs = [['Y', 'q', 'c'], ['Y', '6', '6'], ['p', 'G', 'a']]
segment_ids = [1, 0, 1]
num_segments = 2
separator = ':'
output_array = [['Y', '6', '6'], ['Y:p', 'q:G', 'c:a']]
res = self.evaluate(
string_ops.unsorted_segment_join(
inputs=inputs,
segment_ids=segment_ids,
num_segments=num_segments,
separator=separator))
self.assertAllEqual(res.shape, np.array(output_array).shape)
self.assertAllEqualUnicode(res, output_array)
def test_segment_id_and_input_empty(self):
inputs = np.array([], dtype=np.string_)
segment_ids = np.array([], dtype=np.int32)
num_segments = 3
separator = ':'
output_array = ['', '', '']
res = self.evaluate(
string_ops.unsorted_segment_join(
inputs=inputs,
segment_ids=segment_ids,
num_segments=num_segments,
separator=separator))
self.assertAllEqual(res.shape, np.array(output_array).shape)
self.assertAllEqualUnicode(res, output_array)
def test_type_check(self):
inputs = [['Y', 'q', 'c'], ['Y', '6', '6'], ['p', 'G', 'a']]
segment_ids = np.array([1, 0, 1], dtype=np.int32)
num_segments = np.array(2, dtype=np.int32)
separator = ':'
output_array = [['Y', '6', '6'], ['Y:p', 'q:G', 'c:a']]
res = self.evaluate(
string_ops.unsorted_segment_join(
inputs=inputs,
segment_ids=segment_ids,
num_segments=num_segments,
separator=separator))
self.assertAllEqual(res.shape, np.array(output_array).shape)
self.assertAllEqualUnicode(res, output_array)
segment_ids = np.array([1, 0, 1], dtype=np.int64)
num_segments = np.array(2, dtype=np.int64)
res = self.evaluate(
string_ops.unsorted_segment_join(
inputs=inputs,
segment_ids=segment_ids,
num_segments=num_segments,
separator=separator))
self.assertAllEqual(res.shape, np.array(output_array).shape)
self.assertAllEqualUnicode(res, output_array)
def test_basic_tensor(self):
inputs = constant_op.constant([['Y', 'q', 'c'], ['Y', '6', '6'],
['p', 'G', 'a']])
segment_ids = constant_op.constant([1, 0, 1])
num_segments = 2
separator = ':'
output_array = constant_op.constant([['Y', '6', '6'], ['Y:p', 'q:G',
'c:a']])
res = self.evaluate(
string_ops.unsorted_segment_join(
inputs=inputs,
segment_ids=segment_ids,
num_segments=num_segments,
separator=separator))
self.assertAllEqual(res, output_array)
self.assertAllEqual(res.shape, output_array.get_shape())
def test_multiple_segment_join(self):
inputs = [['Y', 'q', 'c'], ['Y', '6', '6'], ['p', 'G', 'a']]
segment_ids_1 = [1, 0, 1]
num_segments_1 = 2
separator_1 = ':'
output_array_1 = [['Y', '6', '6'], ['Y:p', 'q:G', 'c:a']]
res = self.evaluate(
string_ops.unsorted_segment_join(
inputs=inputs,
segment_ids=segment_ids_1,
num_segments=num_segments_1,
separator=separator_1))
self.assertAllEqualUnicode(res, output_array_1)
self.assertAllEqual(res.shape, np.array(output_array_1).shape)
segment_ids_2 = [1, 1]
num_segments_2 = 2
separator_2 = ''
output_array_2 = [['', '', ''], ['YY:p', '6q:G', '6c:a']]
res = self.evaluate(
string_ops.unsorted_segment_join(
inputs=res,
segment_ids=segment_ids_2,
num_segments=num_segments_2,
separator=separator_2))
self.assertAllEqualUnicode(res, output_array_2)
self.assertAllEqual(res.shape, np.array(output_array_2).shape)
@parameterized.parameters([
{
'inputs': [[[['q'], ['s']], [['f'], ['F']], [['h'], ['0']]],
[[['E'], ['j']], [['2'], ['k']], [['N'], ['d']]],
[[['G'], ['M']], [['1'], ['S']], [['N'], ['7']]],
[[['8'], ['W']], [['W'], ['G']], [['j'], ['d']]]],
'segment_ids': [1, 1, 0, 2],
'num_segments':
3,
'separator':
':',
'output_array': [[[['G'], ['M']], [['1'], ['S']], [['N'], ['7']]],
[[['q:E'], ['s:j']], [['f:2'], ['F:k']],
[['h:N'], ['0:d']]],
[[['8'], ['W']], [['W'], ['G']], [['j'], ['d']]]],
},
{
'inputs': [[['Q', 'b'], ['c', 'p']], [['i', '9'], ['n', 'b']],
[['T', 'h'], ['g', 'z']]],
'segment_ids': [[0, 1], [1, 0], [1, 0]],
'num_segments': 2,
'separator': ':',
'output_array': [['Q:n:g', 'b:b:z'], ['c:i:T', 'p:9:h']]
},
{
'inputs': [[['Q', 'b'], ['b', 'p']], [['i', '9'], ['n', 'b']],
[['T', 'h'], ['g', 'z']]],
'segment_ids': [[[2, 1], [0, 0]], [[2, 0], [2, 2]], [[0, 2], [1, 0]]],
'num_segments': 3,
'separator': ':',
'output_array': ['b:p:9:T:z', 'b:g', 'Q:i:n:b:h']
},
{
'inputs': [[['z'], ['h']], [['c'], ['z']], [['V'], ['T']]],
'segment_ids': [0, 1, 1],
'num_segments': 3,
'separator': ':',
'output_array': [[['z'], ['h']], [['c:V'], ['z:T']], [[''], ['']]]
},
])
def test_multiple_cases_with_different_dims(self, inputs, segment_ids,
num_segments, separator,
output_array):
res = self.evaluate(
string_ops.unsorted_segment_join(
inputs=inputs,
segment_ids=segment_ids,
num_segments=num_segments,
separator=separator))
self.assertAllEqualUnicode(res, output_array)
self.assertAllEqual(res.shape, np.array(output_array).shape)
@parameterized.parameters([
{
'separator': '',
'output_array': ['thisisatest']
},
{
'separator': ':',
'output_array': ['this:is:a:test']
},
{
'separator': 'UNK',
'output_array': ['thisUNKisUNKaUNKtest']
},
])
def testSeparator(self, separator, output_array):
inputs = ['this', 'is', 'a', 'test']
segment_ids = [0, 0, 0, 0]
num_segments = 1
res = self.evaluate(
string_ops.unsorted_segment_join(
inputs=inputs,
segment_ids=segment_ids,
num_segments=num_segments,
separator=separator))
self.assertAllEqual(res.shape, np.array(output_array).shape)
self.assertAllEqualUnicode(res, output_array)
def test_fail_segment_id_exceeds_segment_nums(self):
inputs = [['Y', 'q', 'c'], ['Y', '6', '6'], ['p', 'G', 'a']]
segment_ids = [1, 0, 1]
num_segments = 1
separator = ':'
with self.assertRaises(errors_impl.InvalidArgumentError):
self.evaluate(
string_ops.unsorted_segment_join(
inputs=inputs,
segment_ids=segment_ids,
num_segments=num_segments,
separator=separator))
def test_fail_segment_id_dim_does_not_match(self):
inputs = [['Y', 'q', 'c'], ['Y', '6', '6'], ['p', 'G', 'a']]
segment_ids = [1, 0, 1, 1]
num_segments = 2
separator = ':'
if not context.executing_eagerly():
with self.assertRaises(ValueError):
self.evaluate(
string_ops.unsorted_segment_join(
inputs=inputs,
segment_ids=segment_ids,
num_segments=num_segments,
separator=separator))
else:
with self.assertRaises(errors_impl.InvalidArgumentError):
self.evaluate(
string_ops.unsorted_segment_join(
inputs=inputs,
segment_ids=segment_ids,
num_segments=num_segments,
separator=separator))
def test_fail_segment_id_empty_input_non_empty(self):
inputs = [['Y', 'q', 'c'], ['Y', '6', '6'], ['p', 'G', 'a']]
segment_ids = np.array([], dtype=np.int32)
num_segments = 2
separator = ':'
with self.assertRaises((ValueError, errors_impl.InvalidArgumentError)):
self.evaluate(
string_ops.unsorted_segment_join(
inputs=inputs,
segment_ids=segment_ids,
num_segments=num_segments,
separator=separator))
def test_empty_input(self):
inputs = np.array([], dtype=np.string_)
segment_ids = [1, 0, 1]
num_segments = 2
separator = ':'
with self.assertRaises((ValueError, errors_impl.InvalidArgumentError)):
self.evaluate(
string_ops.unsorted_segment_join(
inputs=inputs,
segment_ids=segment_ids,
num_segments=num_segments,
separator=separator))
def test_fail_negative_segment_id(self):
inputs = [['Y', 'q', 'c'], ['Y', '6', '6'], ['p', 'G', 'a']]
segment_ids = [-1, 0, -1]
num_segments = 1
separator = ':'
with self.assertRaises(errors_impl.InvalidArgumentError):
self.evaluate(
string_ops.unsorted_segment_join(
inputs=inputs,
segment_ids=segment_ids,
num_segments=num_segments,
separator=separator))
if __name__ == '__main__':
test.main()

View File

@ -4392,6 +4392,10 @@ tf_module {
name: "UnravelIndex" name: "UnravelIndex"
argspec: "args=[\'indices\', \'dims\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " argspec: "args=[\'indices\', \'dims\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
} }
member_method {
name: "UnsortedSegmentJoin"
argspec: "args=[\'inputs\', \'segment_ids\', \'num_segments\', \'separator\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], "
}
member_method {
name: "UnsortedSegmentMax"
argspec: "args=[\'data\', \'segment_ids\', \'num_segments\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "

View File

@ -92,6 +92,10 @@ tf_module {
name: "unicode_transcode" name: "unicode_transcode"
argspec: "args=[\'input\', \'input_encoding\', \'output_encoding\', \'errors\', \'replacement_char\', \'replace_control_characters\', \'name\'], varargs=None, keywords=None, defaults=[\'replace\', \'65533\', \'False\', \'None\'], " argspec: "args=[\'input\', \'input_encoding\', \'output_encoding\', \'errors\', \'replacement_char\', \'replace_control_characters\', \'name\'], varargs=None, keywords=None, defaults=[\'replace\', \'65533\', \'False\', \'None\'], "
} }
member_method {
name: "unsorted_segment_join"
argspec: "args=[\'inputs\', \'segment_ids\', \'num_segments\', \'separator\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], "
}
member_method {
name: "upper"
argspec: "args=[\'input\', \'encoding\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], "

View File

@ -4392,6 +4392,10 @@ tf_module {
name: "UnravelIndex" name: "UnravelIndex"
argspec: "args=[\'indices\', \'dims\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " argspec: "args=[\'indices\', \'dims\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
} }
member_method {
name: "UnsortedSegmentJoin"
argspec: "args=[\'inputs\', \'segment_ids\', \'num_segments\', \'separator\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], "
}
member_method {
name: "UnsortedSegmentMax"
argspec: "args=[\'data\', \'segment_ids\', \'num_segments\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "

View File

@ -92,6 +92,10 @@ tf_module {
name: "unicode_transcode" name: "unicode_transcode"
argspec: "args=[\'input\', \'input_encoding\', \'output_encoding\', \'errors\', \'replacement_char\', \'replace_control_characters\', \'name\'], varargs=None, keywords=None, defaults=[\'replace\', \'65533\', \'False\', \'None\'], " argspec: "args=[\'input\', \'input_encoding\', \'output_encoding\', \'errors\', \'replacement_char\', \'replace_control_characters\', \'name\'], varargs=None, keywords=None, defaults=[\'replace\', \'65533\', \'False\', \'None\'], "
} }
member_method {
name: "unsorted_segment_join"
argspec: "args=[\'inputs\', \'segment_ids\', \'num_segments\', \'separator\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], "
}
member_method {
name: "upper"
argspec: "args=[\'input\', \'encoding\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], "