Added new op: tf.strings.unsorted_segment_join.
PiperOrigin-RevId: 251690479
parent bb5ae484d0
commit f0f318aa40

@@ -0,0 +1,62 @@
op {
  graph_op_name: "UnsortedSegmentJoin"
  in_arg {
    name: "inputs"
    description: <<END
The input to be joined.
END
  }
  in_arg {
    name: "segment_ids"
    description: <<END
A tensor whose shape is a prefix of `inputs.shape`. Negative segment ids are
not supported.
END
  }
  in_arg {
    name: "num_segments"
    description: <<END
A scalar.
END
  }
  out_arg {
    name: "output"
    description: <<END
END
  }
  attr {
    name: "separator"
    description: <<END
The separator to use when joining.
END
  }
  summary: "Joins the elements of `inputs` based on `segment_ids`."
  description: <<END
Computes the string join along segments of a tensor.

Given `segment_ids` with rank `N` and `inputs` with rank `N+M`:

`output[i, k1...kM] = strings.join([inputs[j1...jN, k1...kM]])`

where the join is over all [j1...jN] such that segment_ids[j1...jN] = i.
Strings are joined in row-major order.

For example:

```python
inputs = [['Y', 'q', 'c'], ['Y', '6', '6'], ['p', 'G', 'a']]
output_array = string_ops.unsorted_segment_join(inputs=inputs,
                                                segment_ids=[1, 0, 1],
                                                num_segments=2,
                                                separator=':')
# output_array ==> [['Y', '6', '6'], ['Y:p', 'q:G', 'c:a']]

inputs = ['this', 'is', 'a', 'test']
output_array = string_ops.unsorted_segment_join(inputs=inputs,
                                                segment_ids=[0, 0, 0, 0],
                                                num_segments=1,
                                                separator=':')
# output_array ==> ['this:is:a:test']
```
END
}
@@ -0,0 +1,6 @@
op {
  graph_op_name: "UnsortedSegmentJoin"
  endpoint {
    name: "strings.unsorted_segment_join"
  }
}
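For context, a minimal usage sketch of the endpoint exported above, written against the public `tf.strings` namespace (the values mirror the example in the op's `api_def`; this assumes a TensorFlow build that already contains the op):

```python
import tensorflow as tf

inputs = [['Y', 'q', 'c'], ['Y', '6', '6'], ['p', 'G', 'a']]
joined = tf.strings.unsorted_segment_join(
    inputs=inputs,
    segment_ids=[1, 0, 1],
    num_segments=2,
    separator=':')
# Expected result (byte strings):
# [[b'Y', b'6', b'6'], [b'Y:p', b'q:G', b'c:a']]
```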
@@ -1443,6 +1443,37 @@ Status RandomShape(shape_inference::InferenceContext* c) {
   return Status::OK();
 }
 
+Status UnsortedSegmentReductionShapeFn(InferenceContext* c) {
+  ShapeHandle s_data = c->input(0);
+  ShapeHandle s_segment_ids = c->input(1);
+  ShapeHandle s_num_segments = c->input(2);
+  TF_RETURN_IF_ERROR(c->WithRank(s_num_segments, 0, &s_num_segments));
+
+  ShapeHandle out;
+
+  // Leading dimensions of data must be compatible with dimensions of
+  // <s_segment_ids>.
+  if (c->RankKnown(s_segment_ids)) {
+    TF_RETURN_IF_ERROR(
+        c->MergePrefix(s_data, s_segment_ids, &s_data, &s_segment_ids));
+
+    // Get the value of the num_segments input tensor.
+    DimensionHandle num_segments_dim;
+    TF_RETURN_IF_ERROR(c->MakeDimForScalarInput(2, &num_segments_dim));
+
+    // Output is {segment_id_rank} + s_data[segment_id_rank:].
+    ShapeHandle s_data_suffix;
+    TF_RETURN_IF_ERROR(
+        c->Subshape(s_data, c->Rank(s_segment_ids), &s_data_suffix));
+    TF_RETURN_IF_ERROR(
+        c->Concatenate(c->Vector(num_segments_dim), s_data_suffix, &out));
+  } else {
+    out = c->UnknownShape();
+  }
+  c->set_output(0, out);
+  return Status::OK();
+}
+
 namespace {
 
 // This SliceHelper processes the output shape of the `slice`
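In words, the shape rule implemented above is: `num_segments` must be a scalar, the leading dimensions of the data input must match the shape of `segment_ids`, and the output shape is `[num_segments]` followed by the remaining data dimensions. A plain-Python sketch of that rule (illustrative only; the helper name is made up and this is not the C++ shape-inference API):

```python
def unsorted_segment_output_shape(data_shape, segment_ids_shape, num_segments):
  """Sketch of the rule: [num_segments] + data_shape[rank(segment_ids):]."""
  rank = len(segment_ids_shape)
  assert tuple(data_shape[:rank]) == tuple(segment_ids_shape), (
      'leading dims of data must match segment_ids')
  return (num_segments,) + tuple(data_shape[rank:])

# A [3, 2, 4] input with [3, 2] segment ids and 5 segments yields a [5, 4] output.
print(unsorted_segment_output_shape((3, 2, 4), (3, 2), 5))  # (5, 4)
```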
@@ -276,6 +276,9 @@ Status UnknownShape(shape_inference::InferenceContext* c);
 // Shape function for reduction operations.
 Status ReductionShape(shape_inference::InferenceContext* c);
 
+// Shape function for unsorted segment operations.
+Status UnsortedSegmentReductionShapeFn(InferenceContext* c);
+
 // Shape function for concat operations.
 // <num_inputs_to_concat> is the number of inputs to concatenate and are taken
 // from inputs
@@ -5183,6 +5183,7 @@ cc_library(
         ":substr_op",
         ":unicode_ops",
         ":unicode_script_op",
+        ":unsorted_segment_join_op",
     ],
 )
 
@@ -5219,6 +5220,12 @@ tf_kernel_library(
     deps = STRING_DEPS,
 )
 
+tf_kernel_library(
+    name = "unsorted_segment_join_op",
+    prefix = "unsorted_segment_join_op",
+    deps = STRING_DEPS,
+)
+
 tf_kernel_library(
     name = "string_format_op",
     prefix = "string_format_op",
tensorflow/core/kernels/unsorted_segment_join_op.cc (new file, 166 lines)
@@ -0,0 +1,166 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/string_ops.cc.

#include <string>
#include <utility>

#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/strings/str_util.h"

namespace tensorflow {

namespace {

template <typename INDICES_TYPE>
gtl::InlinedVector<INDICES_TYPE, 8> GetFlattenedRelativeOffsets(
    INDICES_TYPE small_stride, INDICES_TYPE big_stride) {
  gtl::InlinedVector<INDICES_TYPE, 8> flattened_offsets(small_stride);
  for (auto i = 0; i < small_stride; i++) {
    flattened_offsets[i] = i * big_stride;
  }
  return flattened_offsets;
}

template <typename INDICES_TYPE>
std::pair<INDICES_TYPE, INDICES_TYPE> GetStrides(
    const TensorShape& input_shape, const TensorShape& segment_id_shape) {
  int64 small_stride = 1;
  int64 big_stride = 1;
  for (auto i = 0; i < input_shape.dims(); i++) {
    if (i < segment_id_shape.dims()) {
      small_stride *= segment_id_shape.dim_size(i);
    } else {
      big_stride *= input_shape.dim_size(i);
    }
  }
  return std::make_pair(big_stride, small_stride);
}

TensorShape GetOutputShape(const TensorShape& input_shape,
                           const TensorShape& segment_id_shape,
                           const int64 num_segments) {
  TensorShape output_shape;
  output_shape.AddDim(num_segments);
  for (size_t index = segment_id_shape.dims(); index < input_shape.dims();
       ++index) {
    output_shape.AddDim(input_shape.dim_size(index));
  }
  return output_shape;
}

}  // namespace

template <typename INDICES_TYPE, typename NUM_SEGMENTS_TYPE>
class UnsortedSegmentJoinOp : public OpKernel {
 public:
  using OpKernel::OpKernel;

  explicit UnsortedSegmentJoinOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("separator", &separator_));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    const TensorShape& input_shape = input.shape();
    const int32 input_dims = input_shape.dims();

    const Tensor& segment_id = context->input(1);
    const TensorShape& segment_id_shape = segment_id.shape();
    const int32 segment_dims = segment_id_shape.dims();

    const Tensor& num_segments_tensor = context->input(2);
    auto num_segments = num_segments_tensor.scalar<NUM_SEGMENTS_TYPE>()();

    OP_REQUIRES(context, segment_dims != 0,
                errors::InvalidArgument("Segment_id cannot have rank 0"));

    OP_REQUIRES(
        context, segment_dims <= input_dims,
        errors::OutOfRange("Invalid segment_id rank ", segment_dims,
                           " for input with ", input_dims, " dimension(s)"));
    for (auto i = 0; i < segment_dims; i++) {
      OP_REQUIRES(
          context, segment_id_shape.dim_size(i) == input_shape.dim_size(i),
          errors::InvalidArgument(
              "Segment dimension is ", segment_id_shape.dim_size(i),
              " while input dimension is ", input_shape.dim_size(i),
              " in rank ", i));
    }

    // Making output tensor.
    Tensor* output_tensor = nullptr;
    TensorShape output_shape =
        GetOutputShape(input_shape, segment_id_shape, num_segments);
    OP_REQUIRES_OK(context, context->allocate_output("output", output_shape,
                                                     &output_tensor));

    // Preparing flat tensors.
    auto output_flat = output_tensor->flat<string>();
    auto flat_segment_id = segment_id.flat<INDICES_TYPE>();
    auto flat_input = input.flat<string>();

    for (int i = 0; i < flat_segment_id.size(); i++) {
      OP_REQUIRES(
          context,
          ((flat_segment_id(i) < num_segments) && (flat_segment_id(i) >= 0)),
          errors::InvalidArgument(
              "segment_ids are not allowed to exceed num_segments or"
              " to have negative values."));
    }

    int64 big_stride;
    int64 small_stride;
    std::tie(big_stride, small_stride) =
        GetStrides<INDICES_TYPE>(input_shape, segment_id_shape);
    auto relative_offset_set =
        GetFlattenedRelativeOffsets<INDICES_TYPE>(small_stride, big_stride);
    for (auto start_offset = 0; start_offset < big_stride; start_offset++) {
      for (auto i = 0; i < relative_offset_set.size(); i++) {
        auto output_index = start_offset + flat_segment_id(i) * big_stride;
        auto offset = start_offset + relative_offset_set[i];
        if (output_flat(output_index).length() != 0)
          output_flat(output_index).append(separator_.c_str());
        output_flat(output_index).append(flat_input(offset));
      }
    }
  }

 private:
  string separator_;
};

#define REGISTER_CPU_KERNEL(indices_type, num_segments_type)  \
  REGISTER_KERNEL_BUILDER(                                    \
      Name("UnsortedSegmentJoin")                             \
          .Device(DEVICE_CPU)                                 \
          .TypeConstraint<indices_type>("Tindices")           \
          .TypeConstraint<num_segments_type>("Tnumsegments"), \
      UnsortedSegmentJoinOp<indices_type, num_segments_type>);

REGISTER_CPU_KERNEL(int32, int32);
REGISTER_CPU_KERNEL(int32, int64);
REGISTER_CPU_KERNEL(int64, int32);
REGISTER_CPU_KERNEL(int64, int64);
#undef REGISTER_CPU_KERNEL

}  // namespace tensorflow
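The kernel above works on flattened tensors with two strides: `small_stride` is the number of elements addressed by `segment_ids` (the joined leading dimensions) and `big_stride` is the size of each trailing slice that is carried through unchanged, so a flat index is leading_index * big_stride + position_in_slice. A rough NumPy re-statement of that indexing scheme (an illustrative sketch of the same loop structure, not the kernel itself; the function name is made up):

```python
import numpy as np

def unsorted_segment_join_reference(inputs, segment_ids, num_segments, separator=''):
  inputs = np.asarray(inputs, dtype=object)
  segment_ids = np.asarray(segment_ids)
  trailing_shape = inputs.shape[segment_ids.ndim:]
  small_stride = int(np.prod(segment_ids.shape, dtype=np.int64))  # joined leading elements
  big_stride = int(np.prod(trailing_shape, dtype=np.int64))       # size of one trailing slice
  flat_input = inputs.reshape(-1)
  flat_ids = segment_ids.reshape(-1)
  out = np.full(num_segments * big_stride, '', dtype=object)
  for start_offset in range(big_stride):   # position inside a trailing slice
    for i in range(small_stride):          # which leading element (and its segment id)
      out_idx = start_offset + flat_ids[i] * big_stride
      in_idx = start_offset + i * big_stride
      if out[out_idx]:
        out[out_idx] += separator
      out[out_idx] += flat_input[in_idx]
  return out.reshape((num_segments,) + trailing_shape)

# Matches the documented example:
print(unsorted_segment_join_reference(
    [['Y', 'q', 'c'], ['Y', '6', '6'], ['p', 'G', 'a']], [1, 0, 1], 2, ':'))
# [['Y' '6' '6'] ['Y:p' 'q:G' 'c:a']]
```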
@@ -1176,37 +1176,6 @@ Status SparseSegmentReductionWithNumSegmentsShapeFn(InferenceContext* c) {
   c->set_output(0, out);
   return Status::OK();
 }
 
-Status UnsortedSegmentReductionShapeFn(InferenceContext* c) {
-  ShapeHandle s_data = c->input(0);
-  ShapeHandle s_segment_ids = c->input(1);
-  ShapeHandle s_num_segments = c->input(2);
-  TF_RETURN_IF_ERROR(c->WithRank(s_num_segments, 0, &s_num_segments));
-
-  ShapeHandle out;
-
-  // Leading dimensions of data must be compatible with dimensions of
-  // <s_segment_ids>.
-  if (c->RankKnown(s_segment_ids)) {
-    TF_RETURN_IF_ERROR(
-        c->MergePrefix(s_data, s_segment_ids, &s_data, &s_segment_ids));
-
-    // Get the value of the num_segments input tensor.
-    DimensionHandle num_segments_dim;
-    TF_RETURN_IF_ERROR(c->MakeDimForScalarInput(2, &num_segments_dim));
-
-    // Output is {segment_id_rank} + s_data[segment_id_rank:].
-    ShapeHandle s_data_suffix;
-    TF_RETURN_IF_ERROR(
-        c->Subshape(s_data, c->Rank(s_segment_ids), &s_data_suffix));
-    TF_RETURN_IF_ERROR(
-        c->Concatenate(c->Vector(num_segments_dim), s_data_suffix, &out));
-  } else {
-    out = c->UnknownShape();
-  }
-  c->set_output(0, out);
-  return Status::OK();
-}
-
 }  // namespace
 
 REGISTER_OP("SegmentSum")

@@ -1257,7 +1226,7 @@ REGISTER_OP("UnsortedSegmentSum")
     .Attr("T: numbertype")
     .Attr("Tindices: {int32,int64}")
     .Attr("Tnumsegments: {int32,int64} = DT_INT32")
-    .SetShapeFn(UnsortedSegmentReductionShapeFn);
+    .SetShapeFn(shape_inference::UnsortedSegmentReductionShapeFn);
 
 REGISTER_OP("UnsortedSegmentMax")
     .Input("data: T")

@@ -1267,7 +1236,7 @@ REGISTER_OP("UnsortedSegmentMax")
     .Attr("T: realnumbertype")
     .Attr("Tindices: {int32,int64}")
     .Attr("Tnumsegments: {int32,int64} = DT_INT32")
-    .SetShapeFn(UnsortedSegmentReductionShapeFn);
+    .SetShapeFn(shape_inference::UnsortedSegmentReductionShapeFn);
 
 REGISTER_OP("UnsortedSegmentMin")
     .Input("data: T")

@@ -1277,7 +1246,7 @@ REGISTER_OP("UnsortedSegmentMin")
     .Attr("T: realnumbertype")
     .Attr("Tindices: {int32,int64}")
     .Attr("Tnumsegments: {int32,int64} = DT_INT32")
-    .SetShapeFn(UnsortedSegmentReductionShapeFn);
+    .SetShapeFn(shape_inference::UnsortedSegmentReductionShapeFn);
 
 REGISTER_OP("UnsortedSegmentProd")
     .Input("data: T")

@@ -1287,7 +1256,7 @@ REGISTER_OP("UnsortedSegmentProd")
     .Attr("T: numbertype")
     .Attr("Tindices: {int32,int64}")
     .Attr("Tnumsegments: {int32,int64} = DT_INT32")
-    .SetShapeFn(UnsortedSegmentReductionShapeFn);
+    .SetShapeFn(shape_inference::UnsortedSegmentReductionShapeFn);
 
 REGISTER_OP("SparseSegmentSum")
     .Input("data: T")
@@ -101,6 +101,16 @@ REGISTER_OP("ReduceJoin")
     .Output("output: string")
     .SetShapeFn(shape_inference::ReductionShape);
 
+REGISTER_OP("UnsortedSegmentJoin")
+    .Input("inputs: string")
+    .Input("segment_ids: Tindices")
+    .Input("num_segments: Tnumsegments")
+    .Attr("separator: string = ''")
+    .Attr("Tindices: {int32,int64}")
+    .Attr("Tnumsegments: {int32,int64} = DT_INT32")
+    .Output("output: string")
+    .SetShapeFn(shape_inference::UnsortedSegmentReductionShapeFn);
+
 REGISTER_OP("AsString")
     .Input("input: T")
     .Output("output: string")
@@ -2308,6 +2308,20 @@ cuda_py_test(
     xla_enable_strict_auto_jit = True,
 )
 
+cuda_py_test(
+    name = "unsorted_segment_join_op_test",
+    size = "small",
+    srcs = ["unsorted_segment_join_op_test.py"],
+    additional_deps = [
+        "//third_party/py/numpy",
+        "//tensorflow/python:array_ops",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:framework_for_generated_wrappers",
+        "//tensorflow/python:string_ops",
+    ],
+    xla_enable_strict_auto_jit = True,
+)
+
 cuda_py_test(
     name = "reduction_ops_test",
     size = "medium",
tensorflow/python/kernel_tests/unsorted_segment_join_op_test.py (new file, 307 lines)
@@ -0,0 +1,307 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for unsorted_segment_join_op."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl.testing import parameterized
import numpy as np

from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import test_util
from tensorflow.python.ops import string_ops
from tensorflow.python.platform import test


class UnicodeTestCase(test.TestCase):
  """Test case with Python3-compatible string comparator."""

  def assertAllEqualUnicode(self, truth, actual):
    self.assertAllEqual(
        np.array(truth).astype('U'), np.array(actual).astype('U'))


@test_util.run_all_in_graph_and_eager_modes
class UnsortedSegmentJoinOpTest(UnicodeTestCase, parameterized.TestCase):

  def test_basic_np_array(self):
    inputs = [['Y', 'q', 'c'], ['Y', '6', '6'], ['p', 'G', 'a']]
    segment_ids = [1, 0, 1]
    num_segments = 2
    separator = ':'
    output_array = [['Y', '6', '6'], ['Y:p', 'q:G', 'c:a']]

    res = self.evaluate(
        string_ops.unsorted_segment_join(
            inputs=inputs,
            segment_ids=segment_ids,
            num_segments=num_segments,
            separator=separator))
    self.assertAllEqual(res.shape, np.array(output_array).shape)
    self.assertAllEqualUnicode(res, output_array)

  def test_segment_id_and_input_empty(self):
    inputs = np.array([], dtype=np.string_)
    segment_ids = np.array([], dtype=np.int32)
    num_segments = 3
    separator = ':'
    output_array = ['', '', '']
    res = self.evaluate(
        string_ops.unsorted_segment_join(
            inputs=inputs,
            segment_ids=segment_ids,
            num_segments=num_segments,
            separator=separator))
    self.assertAllEqual(res.shape, np.array(output_array).shape)
    self.assertAllEqualUnicode(res, output_array)

  def test_type_check(self):
    inputs = [['Y', 'q', 'c'], ['Y', '6', '6'], ['p', 'G', 'a']]
    segment_ids = np.array([1, 0, 1], dtype=np.int32)
    num_segments = np.array(2, dtype=np.int32)
    separator = ':'
    output_array = [['Y', '6', '6'], ['Y:p', 'q:G', 'c:a']]

    res = self.evaluate(
        string_ops.unsorted_segment_join(
            inputs=inputs,
            segment_ids=segment_ids,
            num_segments=num_segments,
            separator=separator))
    self.assertAllEqual(res.shape, np.array(output_array).shape)
    self.assertAllEqualUnicode(res, output_array)

    segment_ids = np.array([1, 0, 1], dtype=np.int64)
    num_segments = np.array(2, dtype=np.int64)
    res = self.evaluate(
        string_ops.unsorted_segment_join(
            inputs=inputs,
            segment_ids=segment_ids,
            num_segments=num_segments,
            separator=separator))
    self.assertAllEqual(res.shape, np.array(output_array).shape)
    self.assertAllEqualUnicode(res, output_array)

  def test_basic_tensor(self):
    inputs = constant_op.constant([['Y', 'q', 'c'], ['Y', '6', '6'],
                                   ['p', 'G', 'a']])
    segment_ids = constant_op.constant([1, 0, 1])
    num_segments = 2
    separator = ':'
    output_array = constant_op.constant([['Y', '6', '6'],
                                         ['Y:p', 'q:G', 'c:a']])

    res = self.evaluate(
        string_ops.unsorted_segment_join(
            inputs=inputs,
            segment_ids=segment_ids,
            num_segments=num_segments,
            separator=separator))
    self.assertAllEqual(res, output_array)
    self.assertAllEqual(res.shape, output_array.get_shape())

  def test_multiple_segment_join(self):
    inputs = [['Y', 'q', 'c'], ['Y', '6', '6'], ['p', 'G', 'a']]
    segment_ids_1 = [1, 0, 1]
    num_segments_1 = 2
    separator_1 = ':'
    output_array_1 = [['Y', '6', '6'], ['Y:p', 'q:G', 'c:a']]

    res = self.evaluate(
        string_ops.unsorted_segment_join(
            inputs=inputs,
            segment_ids=segment_ids_1,
            num_segments=num_segments_1,
            separator=separator_1))
    self.assertAllEqualUnicode(res, output_array_1)
    self.assertAllEqual(res.shape, np.array(output_array_1).shape)

    segment_ids_2 = [1, 1]
    num_segments_2 = 2
    separator_2 = ''
    output_array_2 = [['', '', ''], ['YY:p', '6q:G', '6c:a']]

    res = self.evaluate(
        string_ops.unsorted_segment_join(
            inputs=res,
            segment_ids=segment_ids_2,
            num_segments=num_segments_2,
            separator=separator_2))
    self.assertAllEqualUnicode(res, output_array_2)
    self.assertAllEqual(res.shape, np.array(output_array_2).shape)

  @parameterized.parameters([
      {
          'inputs': [[[['q'], ['s']], [['f'], ['F']], [['h'], ['0']]],
                     [[['E'], ['j']], [['2'], ['k']], [['N'], ['d']]],
                     [[['G'], ['M']], [['1'], ['S']], [['N'], ['7']]],
                     [[['8'], ['W']], [['W'], ['G']], [['j'], ['d']]]],
          'segment_ids': [1, 1, 0, 2],
          'num_segments': 3,
          'separator': ':',
          'output_array': [[[['G'], ['M']], [['1'], ['S']], [['N'], ['7']]],
                           [[['q:E'], ['s:j']], [['f:2'], ['F:k']],
                            [['h:N'], ['0:d']]],
                           [[['8'], ['W']], [['W'], ['G']], [['j'], ['d']]]],
      },
      {
          'inputs': [[['Q', 'b'], ['c', 'p']], [['i', '9'], ['n', 'b']],
                     [['T', 'h'], ['g', 'z']]],
          'segment_ids': [[0, 1], [1, 0], [1, 0]],
          'num_segments': 2,
          'separator': ':',
          'output_array': [['Q:n:g', 'b:b:z'], ['c:i:T', 'p:9:h']]
      },
      {
          'inputs': [[['Q', 'b'], ['b', 'p']], [['i', '9'], ['n', 'b']],
                     [['T', 'h'], ['g', 'z']]],
          'segment_ids': [[[2, 1], [0, 0]], [[2, 0], [2, 2]], [[0, 2], [1, 0]]],
          'num_segments': 3,
          'separator': ':',
          'output_array': ['b:p:9:T:z', 'b:g', 'Q:i:n:b:h']
      },
      {
          'inputs': [[['z'], ['h']], [['c'], ['z']], [['V'], ['T']]],
          'segment_ids': [0, 1, 1],
          'num_segments': 3,
          'separator': ':',
          'output_array': [[['z'], ['h']], [['c:V'], ['z:T']], [[''], ['']]]
      },
  ])
  def test_multiple_cases_with_different_dims(self, inputs, segment_ids,
                                              num_segments, separator,
                                              output_array):
    res = self.evaluate(
        string_ops.unsorted_segment_join(
            inputs=inputs,
            segment_ids=segment_ids,
            num_segments=num_segments,
            separator=separator))
    self.assertAllEqualUnicode(res, output_array)
    self.assertAllEqual(res.shape, np.array(output_array).shape)

  @parameterized.parameters([
      {
          'separator': '',
          'output_array': ['thisisatest']
      },
      {
          'separator': ':',
          'output_array': ['this:is:a:test']
      },
      {
          'separator': 'UNK',
          'output_array': ['thisUNKisUNKaUNKtest']
      },
  ])
  def testSeparator(self, separator, output_array):
    inputs = ['this', 'is', 'a', 'test']
    segment_ids = [0, 0, 0, 0]
    num_segments = 1
    res = self.evaluate(
        string_ops.unsorted_segment_join(
            inputs=inputs,
            segment_ids=segment_ids,
            num_segments=num_segments,
            separator=separator))
    self.assertAllEqual(res.shape, np.array(output_array).shape)
    self.assertAllEqualUnicode(res, output_array)

  def test_fail_segment_id_exceeds_segment_nums(self):
    inputs = [['Y', 'q', 'c'], ['Y', '6', '6'], ['p', 'G', 'a']]
    segment_ids = [1, 0, 1]
    num_segments = 1
    separator = ':'

    with self.assertRaises(errors_impl.InvalidArgumentError):
      self.evaluate(
          string_ops.unsorted_segment_join(
              inputs=inputs,
              segment_ids=segment_ids,
              num_segments=num_segments,
              separator=separator))

  def test_fail_segment_id_dim_does_not_match(self):
    inputs = [['Y', 'q', 'c'], ['Y', '6', '6'], ['p', 'G', 'a']]
    segment_ids = [1, 0, 1, 1]
    num_segments = 2
    separator = ':'

    if not context.executing_eagerly():
      with self.assertRaises(ValueError):
        self.evaluate(
            string_ops.unsorted_segment_join(
                inputs=inputs,
                segment_ids=segment_ids,
                num_segments=num_segments,
                separator=separator))
    else:
      with self.assertRaises(errors_impl.InvalidArgumentError):
        self.evaluate(
            string_ops.unsorted_segment_join(
                inputs=inputs,
                segment_ids=segment_ids,
                num_segments=num_segments,
                separator=separator))

  def test_fail_segment_id_empty_input_non_empty(self):
    inputs = [['Y', 'q', 'c'], ['Y', '6', '6'], ['p', 'G', 'a']]
    segment_ids = np.array([], dtype=np.int32)
    num_segments = 2
    separator = ':'
    with self.assertRaises((ValueError, errors_impl.InvalidArgumentError)):
      self.evaluate(
          string_ops.unsorted_segment_join(
              inputs=inputs,
              segment_ids=segment_ids,
              num_segments=num_segments,
              separator=separator))

  def test_empty_input(self):
    inputs = np.array([], dtype=np.string_)
    segment_ids = [1, 0, 1]
    num_segments = 2
    separator = ':'
    with self.assertRaises((ValueError, errors_impl.InvalidArgumentError)):
      self.evaluate(
          string_ops.unsorted_segment_join(
              inputs=inputs,
              segment_ids=segment_ids,
              num_segments=num_segments,
              separator=separator))

  def test_fail_negative_segment_id(self):
    inputs = [['Y', 'q', 'c'], ['Y', '6', '6'], ['p', 'G', 'a']]
    segment_ids = [-1, 0, -1]
    num_segments = 1
    separator = ':'
    with self.assertRaises(errors_impl.InvalidArgumentError):
      self.evaluate(
          string_ops.unsorted_segment_join(
              inputs=inputs,
              segment_ids=segment_ids,
              num_segments=num_segments,
              separator=separator))


if __name__ == '__main__':
  test.main()
@@ -4392,6 +4392,10 @@ tf_module {
     name: "UnravelIndex"
     argspec: "args=[\'indices\', \'dims\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
+  member_method {
+    name: "UnsortedSegmentJoin"
+    argspec: "args=[\'inputs\', \'segment_ids\', \'num_segments\', \'separator\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], "
+  }
   member_method {
     name: "UnsortedSegmentMax"
     argspec: "args=[\'data\', \'segment_ids\', \'num_segments\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "

@@ -92,6 +92,10 @@ tf_module {
     name: "unicode_transcode"
     argspec: "args=[\'input\', \'input_encoding\', \'output_encoding\', \'errors\', \'replacement_char\', \'replace_control_characters\', \'name\'], varargs=None, keywords=None, defaults=[\'replace\', \'65533\', \'False\', \'None\'], "
   }
+  member_method {
+    name: "unsorted_segment_join"
+    argspec: "args=[\'inputs\', \'segment_ids\', \'num_segments\', \'separator\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], "
+  }
   member_method {
     name: "upper"
     argspec: "args=[\'input\', \'encoding\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], "

@@ -4392,6 +4392,10 @@ tf_module {
     name: "UnravelIndex"
     argspec: "args=[\'indices\', \'dims\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
+  member_method {
+    name: "UnsortedSegmentJoin"
+    argspec: "args=[\'inputs\', \'segment_ids\', \'num_segments\', \'separator\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], "
+  }
   member_method {
     name: "UnsortedSegmentMax"
     argspec: "args=[\'data\', \'segment_ids\', \'num_segments\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "

@@ -92,6 +92,10 @@ tf_module {
     name: "unicode_transcode"
     argspec: "args=[\'input\', \'input_encoding\', \'output_encoding\', \'errors\', \'replacement_char\', \'replace_control_characters\', \'name\'], varargs=None, keywords=None, defaults=[\'replace\', \'65533\', \'False\', \'None\'], "
   }
+  member_method {
+    name: "unsorted_segment_join"
+    argspec: "args=[\'inputs\', \'segment_ids\', \'num_segments\', \'separator\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], "
+  }
   member_method {
     name: "upper"
     argspec: "args=[\'input\', \'encoding\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], "
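The argspec entries above imply a generated Python wrapper whose signature is roughly the following (a sketch inferred from the argspec, not the generated code itself; `separator` defaults to the empty string and `name` to None):

```python
def unsorted_segment_join(inputs, segment_ids, num_segments, separator='', name=None):
  ...  # dispatches to the UnsortedSegmentJoin op registered in string_ops.cc
```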