Yunxing Dai d45ca04c30 Add a fast-path in tf.unique where only the first output is used.
The fast path uses native sort ops which is ~1000x faster than the while version. We rely on DCE to trim the slow version if second output is not used.

PiperOrigin-RevId: 348683799
Change-Id: Iad6e297c71f82f2c61b87a4cddc7b27d62e5ae7a
2020-12-22 13:27:31 -08:00

259 lines
12 KiB
C++

/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/tf2xla/literal_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/lib/comparators.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/comparison_util.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/ops_util.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/tpu/tpu_defs.h"
namespace tensorflow {
namespace {
class UniqueOp : public XlaOpKernel {
public:
// Forwards the construction context to the XlaOpKernel base class; this
// kernel declares no data members of its own.
explicit UniqueOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
// We use a two-level loop algorithm to calculate the indices output of
// unique.
//
//   i = 0
//   output_size = 0
//   output_indices = broadcast(0, {input_size})
//   while (i < input_size) {
//     search_result_index = output_size
//     j = 0
//     while (j < output_size) {
//       if (input[j] == input[i]) {
//         search_result_index = j
//       }
//       ++j
//     }
//     input[search_result_index] = input[i]
//     output_indices[i] = search_result_index
//     if (search_result_index == output_size) {
//       // Not found
//       output_size++;
//     }
//     i++;
//   }
//
// The algorithm is then functionalized into XLA while loops. Outer-scoped
// variables are captured as inputs and outputs to the while loop.
// Conditionals are rewritten into XLA selects for simplicity.
xla::XlaComputation BuildInnerLoopCond(XlaOpKernelContext* ctx,
                                       xla::Shape inner_loop_shape) {
  // Condition of the inner (search) loop. The loop state is a tuple shaped
  // like `inner_loop_shape`; element 2 is the scan cursor j and element 3 is
  // the number of unique elements found so far. The loop keeps running while
  // j is smaller than that count.
  std::unique_ptr<xla::XlaBuilder> cond_builder =
      ctx->builder()->CreateSubBuilder("inner_loop_cond");
  xla::XlaBuilder* const b = cond_builder.get();
  const auto state = xla::Parameter(b, 0, inner_loop_shape, "param");
  const auto cursor = xla::GetTupleElement(state, 2);
  const auto unique_count = xla::GetTupleElement(state, 3);
  // The comparison is intentionally not captured: it is the last op added to
  // the builder before Build() is called.
  xla::Lt(cursor, unique_count);
  return cond_builder->Build().ConsumeValueOrDie();
}
// Builds the body computation of the inner (search) while loop.
//
// Loop state tuple (shaped like `inner_loop_shape`):
//   0: input data, 1: scalar target value, 2: scan cursor j,
//   3: number of unique elements found so far, 4: index of the match.
// Each iteration compares input[j] with the target; on equality the match
// index becomes j, otherwise the previous match index is carried over. The
// cursor is then advanced by one.
xla::XlaComputation BuildInnerLoopBody(XlaOpKernelContext* ctx,
                                       xla::Shape inner_loop_shape,
                                       xla::Shape single_element_shape) {
  std::unique_ptr<xla::XlaBuilder> body_builder =
      ctx->builder()->CreateSubBuilder("inner_loop_body");
  xla::XlaBuilder* const b = body_builder.get();

  const auto state = xla::Parameter(b, 0, inner_loop_shape, "param");
  const auto data = xla::GetTupleElement(state, 0);
  const auto wanted = xla::GetTupleElement(state, 1);
  const auto cursor = xla::GetTupleElement(state, 2);
  const auto unique_count = xla::GetTupleElement(state, 3);
  const auto match_index = xla::GetTupleElement(state, 4);

  // Read input[j] as a length-1 slice, then reshape it to
  // `single_element_shape` so it can be compared against the target.
  const auto candidate = xla::Reshape(single_element_shape,
                                      xla::DynamicSlice(data, {cursor}, {1}));
  const auto is_match = xla::Eq(candidate, wanted);
  const auto new_match_index = xla::Select(is_match, cursor, match_index);

  const auto one = xla::One(b, xla::S32);
  const auto next_cursor = xla::Add(cursor, one);

  // The tuple must have the same layout as the loop parameter.
  xla::Tuple(b, {data, wanted, next_cursor, unique_count, new_match_index});
  return body_builder->Build().ConsumeValueOrDie();
}
// Builds the condition computation of the outer while loop.
//
// The loop state is a tuple shaped like `outer_loop_shape`; element 2 is the
// iteration counter i. The loop keeps running while i < list_size.
xla::XlaComputation BuildOuterLoopCond(XlaOpKernelContext* ctx,
                                       xla::Shape outer_loop_shape,
                                       int64 list_size) {
  // Use a name that matches the role of this computation; the body of the
  // outer loop is built under "outer_loop_body".
  std::unique_ptr<xla::XlaBuilder> builder =
      ctx->builder()->CreateSubBuilder("outer_loop_cond");
  auto param =
      xla::Parameter(builder.get(), 0, outer_loop_shape, "outer_loop_param");
  auto i = xla::GetTupleElement(param, 2);
  // The bound is emitted as an int32 constant to be compared with the
  // counter, so the narrowing from int64 is spelled out explicitly.
  auto bound =
      xla::ConstantR0<int32>(builder.get(), static_cast<int32>(list_size));
  // Not captured on purpose: the comparison is the last op added to the
  // builder before Build() is called.
  xla::Lt(i, bound);
  return builder->Build().ConsumeValueOrDie();
}
// Builds the body computation of the outer while loop.
//
// Loop state tuple (shaped like `outer_loop_shape`):
//   0: input data, whose leading entries are overwritten with the unique
//      values found so far,
//   1: output indices, 2: iteration counter i, 3: number of unique values.
// One iteration runs the inner loop to look for input[i] among the unique
// values found so far, records the resulting position in indices[i], and
// increments the unique count when the search reports "not found".
xla::XlaComputation BuildOuterLoopBody(
    XlaOpKernelContext* ctx, xla::Shape outer_loop_shape,
    xla::Shape single_element_shape, const xla::XlaComputation& inner_cond,
    const xla::XlaComputation& inner_body) {
  std::unique_ptr<xla::XlaBuilder> body_builder =
      ctx->builder()->CreateSubBuilder("outer_loop_body");
  xla::XlaBuilder* const b = body_builder.get();

  const auto state = xla::Parameter(b, 0, outer_loop_shape, "param");
  const auto data = xla::GetTupleElement(state, 0);
  const auto positions = xla::GetTupleElement(state, 1);
  const auto counter = xla::GetTupleElement(state, 2);
  const auto unique_count = xla::GetTupleElement(state, 3);
  const auto zero = xla::Zero(b, xla::S32);

  // input[i], both as a length-1 slice (for the write-back below) and
  // reshaped to `single_element_shape` (for the comparison in the inner
  // loop).
  const auto current = xla::DynamicSlice(data, {counter}, {1});
  const auto current_scalar = xla::Reshape(single_element_shape, current);

  // The search starts at j = 0. The match index is initialised to the unique
  // count, which doubles as the "not found" marker.
  const auto search = xla::While(
      inner_cond, inner_body,
      xla::Tuple(b, {data, current_scalar, zero, unique_count, unique_count}));
  const auto found_at = xla::GetTupleElement(search, 4);

  const auto one = xla::One(b, xla::S32);
  const auto is_new = xla::Eq(found_at, unique_count);
  const auto grown_count = xla::Add(unique_count, one);
  const auto next_unique_count =
      xla::Select(is_new, grown_count, unique_count);

  // Write input[i] at the position the search returned, and record that
  // position as indices[i].
  const auto next_data = xla::DynamicUpdateSlice(data, current, {found_at});
  const auto found_at_r1 = xla::Reshape(found_at, {1});
  const auto next_positions =
      xla::DynamicUpdateSlice(positions, found_at_r1, {counter});
  const auto next_counter = xla::Add(counter, one);

  // The tuple must have the same layout as the loop parameter.
  xla::Tuple(b, {next_data, next_positions, next_counter, next_unique_count});
  return body_builder->Build().ConsumeValueOrDie();
}
// Computes the data output of unique (the distinct values of `input`) with
// three sorts instead of the two-level while loop.
//
// `input` is expected to be rank-1 with shape `input_shape`; only dimension 0
// is read. The returned op has the same static shape as `input`, with
// dimension 0 marked dynamic and set to the number of distinct values.
// If only the data output is used (meaning the indices output is ignored,
// which is common), DCE will only keep this fast path.
xla::XlaOp DataOutputFastPath(XlaOpKernelContext* ctx, xla::XlaOp input,
                              const xla::Shape& input_shape) {
  int64 input_count = input_shape.dimensions(0);
  if (input_count == 0) {
    // Nothing to deduplicate. The code below would slice [0, -1) and write a
    // length-1 update into a length-0 operand, both of which are invalid.
    return input;
  }
  auto iota_shape = input_shape;
  iota_shape.set_element_type(xla::S32);
  // iota carries each element's original position through the first sort.
  xla::XlaOp iota = xla::Iota(ctx->builder(), iota_shape, 0);
  std::vector<xla::XlaOp> to_sort = {input, iota};
  std::vector<xla::PrimitiveType> types_to_sort = {input_shape.element_type(),
                                                   xla::S32};
  xla::XlaOp sorted = xla::Sort(
      to_sort, xla::CreateScalarLtComputation(types_to_sort, ctx->builder()),
      /*dimension=*/0,
      /*is_stable=*/true);
  xla::XlaOp sorted_data = xla::GetTupleElement(sorted, 0);
  xla::XlaOp sorted_iota = xla::GetTupleElement(sorted, 1);
  // Calculate unique_indices, unique_indices[i] is true when sorted_data[i]
  // != sorted_data[i - 1]. unique_indices[0] is always true.
  // We do this by shifting sorted_data by 1 position and then compare it with
  // itself.
  // A[0:n-1]
  auto shifted = xla::SliceInDim(sorted_data, 0, input_count - 1,
                                 /*stride=*/1, /*dimno=*/0);
  // Placeholder for position 0; the comparison result at that position is
  // overwritten with `true` below.
  auto zero =
      xla::Zeros(ctx->builder(),
                 xla::ShapeUtil::MakeShape(input_shape.element_type(), {1}));
  // shifted = concat(0, A[0:n-1])
  shifted = xla::ConcatInDim(ctx->builder(), {zero, shifted}, 0);
  xla::XlaOp unique_indices = xla::Ne(shifted, sorted_data);
  xla::XlaOp true_r1 = xla::One(ctx->builder(), xla::PRED);
  true_r1 = xla::Reshape(true_r1, {1});
  // First element is always unique(true).
  unique_indices = xla::DynamicUpdateSlice(
      unique_indices, true_r1, {xla::Zero(ctx->builder(), xla::S32)});
  // Converted to S32 so the flags can be used as a sort key and summed.
  unique_indices = xla::ConvertElementType(unique_indices, xla::S32);
  // Do a reverse sort using iota as key, this moves `unique_indices` to its
  // original positions in input space.
  std::vector<xla::XlaOp> to_sort_reverse = {sorted_iota, unique_indices};
  std::vector<xla::PrimitiveType> types_to_sort_reverse = {xla::S32,
                                                           xla::S32};
  xla::XlaOp sort_reverse = xla::Sort(
      to_sort_reverse,
      xla::CreateScalarLtComputation(types_to_sort_reverse, ctx->builder()),
      /*dimension=*/0,
      /*is_stable=*/true);
  // Now unique_indices are scattered back to original positions.
  unique_indices = xla::GetTupleElement(sort_reverse, 1);
  // Do a final sort to move unique items together. The flag is the key and
  // the comparison is greater-than, so flagged elements come first.
  std::vector<xla::XlaOp> to_sort_final = {unique_indices, input};
  std::vector<xla::PrimitiveType> types_to_sort_final = {
      xla::S32, input_shape.element_type()};
  xla::XlaOp final_sort = xla::Sort(
      to_sort_final,
      xla::CreateScalarGtComputation(types_to_sort_final, ctx->builder()),
      /*dimension=*/0,
      /*is_stable=*/true);
  xla::XlaOp output = xla::GetTupleElement(final_sort, 1);
  // The number of set flags is the number of distinct values; it becomes the
  // dynamic size of dimension 0 of the output.
  xla::XlaOp unique_count = xla::ReduceAll(
      unique_indices, xla::Zero(ctx->builder(), xla::S32),
      xla::CreateScalarAddComputation(xla::S32, ctx->builder()));
  xla::XlaOp output_dynamic = xla::SetDimensionSize(output, unique_count, 0);
  return output_dynamic;
}
// Emits the XLA computation for the op.
//
// Output 0 (the unique values) comes from DataOutputFastPath. Output 1 (for
// each input element, the index computed by the two-level while loop) comes
// from the slow path. Reports InvalidArgument through `ctx` when the input is
// not rank-1.
void Compile(XlaOpKernelContext* ctx) override {
  xla::XlaOp input = ctx->Input(0);
  xla::StatusOr<xla::Shape> input_shape_or = ctx->builder()->GetShape(input);
  OP_REQUIRES_OK(ctx, input_shape_or.status());
  auto input_shape = input_shape_or.ValueOrDie();
  // Validate the rank before anything reads dimension 0 of the shape; both
  // the fast path and the slow path below assume a rank-1 input.
  OP_REQUIRES(ctx, input_shape.rank() == 1,
              xla::InvalidArgument("Input to UniqueOp must be rank-1: %s",
                                   input_shape.ToString()));
  int64 list_size = input_shape.dimensions()[0];
  if (list_size == 0) {
    // Both paths below take length-1 slices of the input, which is invalid
    // for a zero-length operand. An empty input has no unique values and no
    // indices, so emit empty outputs directly.
    ctx->SetOutput(0, input);
    ctx->SetOutput(1, xla::Broadcast(xla::Zero(ctx->builder(), xla::S32),
                                     {list_size}));
    return;
  }
  ctx->SetOutput(0, DataOutputFastPath(ctx, input, input_shape));
  // Slow path to generate indices.
  xla::Shape single_index_shape = xla::ShapeUtil::MakeScalarShape(xla::S32);
  xla::Shape single_element_shape =
      xla::ShapeUtil::MakeScalarShape(input_shape.element_type());
  auto indices_shape =
      xla::ShapeUtil::ChangeElementType(input_shape, xla::S32);
  // Outer loop state: {input, output indices, i, number of unique values}.
  auto outer_loop_shape = xla::ShapeUtil::MakeTupleShape(
      {input_shape, indices_shape, single_index_shape, single_index_shape});
  // Inner loop state: {input, target value, j, number of unique values,
  // index of the match}.
  auto inner_loop_shape = xla::ShapeUtil::MakeTupleShape(
      {input_shape, single_element_shape, single_index_shape,
       single_index_shape, single_index_shape});
  xla::XlaComputation inner_loop_cond =
      BuildInnerLoopCond(ctx, inner_loop_shape);
  xla::XlaComputation inner_loop_body =
      BuildInnerLoopBody(ctx, inner_loop_shape, single_element_shape);
  xla::XlaComputation outer_loop_cond =
      BuildOuterLoopCond(ctx, outer_loop_shape, list_size);
  xla::XlaComputation outer_loop_body =
      BuildOuterLoopBody(ctx, outer_loop_shape, single_element_shape,
                         inner_loop_cond, inner_loop_body);
  auto zero = xla::Zero(ctx->builder(), xla::S32);
  auto init_indices = xla::Broadcast(zero, {list_size});
  // Start with i = 0 and no unique values found.
  auto init = xla::Tuple(ctx->builder(), {input, init_indices, zero, zero});
  auto outer_while = xla::While(outer_loop_cond, outer_loop_body, init);
  auto output_indices = xla::GetTupleElement(outer_while, 1);
  ctx->SetOutput(1, output_indices);
}
};
// Registers UniqueOp as the XLA kernel for the "Unique" op on the
// DEVICE_TPU_XLA_JIT device.
REGISTER_XLA_OP(Name("Unique").Device(DEVICE_TPU_XLA_JIT), UniqueOp);
} // namespace
} // namespace tensorflow