[XLA] Add TopK rewriter pass

This pass pattern-matches sort HLOs that implement a TopK operation and rewrites them into a TopK custom call. This will be useful for CPU.

PiperOrigin-RevId: 324976268
Change-Id: I56224ad39e1cb2960bde9a366a7b47deffa9955f
This commit is contained in:
Benjamin Kramer 2020-08-05 01:38:50 -07:00 committed by TensorFlower Gardener
parent b89e12c5a3
commit 5f6c13cb10
4 changed files with 415 additions and 0 deletions

View File

@ -4985,3 +4985,34 @@ cc_library(
"//tensorflow/stream_executor/lib",
],
)
# Library target for the TopK rewriter pass, built from topk_rewriter.cc and
# topk_rewriter.h.
cc_library(
name = "topk_rewriter",
srcs = ["topk_rewriter.cc"],
hdrs = ["topk_rewriter.h"],
deps = [
":hlo",
":hlo_casting_utils",
":hlo_pass",
":pattern_matcher",
"//tensorflow/compiler/xla:shape_util",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/types:optional",
],
)
# Test target built from topk_rewriter_test.cc; depends on the
# ":topk_rewriter" library it exercises.
tf_cc_test(
name = "topk_rewriter_test",
srcs = ["topk_rewriter_test.cc"],
deps = [
":hlo",
":hlo_dce",
":hlo_matchers",
":topk_rewriter",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:test_macros_cpu",
"//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:test",
],
)

View File

@ -0,0 +1,187 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/topk_rewriter.h"
#include "absl/algorithm/container.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
#include "tensorflow/compiler/xla/shape_util.h"
namespace xla {
// Returns true when the root of `comp` is a GT comparison between parameter 0
// and parameter 1 after each side has been remapped to an S32 key of the form
//
//   s32 = bitcast-convert(x)
//   key = select(s32 < 0, bitcast-convert(INT32_MAX - bitcast-convert<u32>(x)),
//                s32)
//
// where x is either an F32 parameter, or a BF16 parameter converted to F32.
static bool IsNanSafeGt(HloComputation* comp) {
  namespace m = match;

  // Builds the select(...) key pattern on top of an arbitrary F32-valued
  // operand pattern; shared by the F32 and BF16 variants below.
  auto match_key = [](auto f32_value) {
    auto as_s32 = m::BitcastConvert(f32_value)
                      .WithShape(m::Shape().WithElementType(S32));
    auto as_u32 = m::BitcastConvert(f32_value)
                      .WithShape(m::Shape().WithElementType(U32));
    auto flipped =
        m::BitcastConvert(
            m::Subtract(m::ConstantScalar(std::numeric_limits<int32>::max()),
                        as_u32))
            .WithShape(m::Shape().WithElementType(S32));
    return m::Select(m::Lt(as_s32, m::ConstantScalar(0)), flipped, as_s32);
  };

  // Key pattern whose input is the F32 parameter `parameter_number`.
  auto match_bitcast_f32 = [&match_key](int64 parameter_number) {
    return match_key(m::Parameter(parameter_number)
                         .WithShape(m::Shape().WithElementType(F32)));
  };

  // Key pattern whose input is the BF16 parameter `parameter_number`,
  // converted to F32 first.
  auto match_bitcast_bf16 = [&match_key](int64 parameter_number) {
    return match_key(
        m::Convert(m::Parameter(parameter_number)
                       .WithShape(m::Shape().WithElementType(BF16)))
            .WithShape(m::Shape().WithElementType(F32)));
  };

  HloInstruction* root = comp->root_instruction();
  if (Match(root, m::Gt(match_bitcast_f32(0), match_bitcast_f32(1)))) {
    return true;
  }
  return Match(root, m::Gt(match_bitcast_bf16(0), match_bitcast_bf16(1)));
}
// Scans every computation in `module` for a two-operand sort whose first
// operand is a rank-2 F32/BF16 array and whose second operand is a rank-2 S32
// iota along the sort dimension, whose comparator satisfies IsNanSafeGt, and
// whose outputs are only consumed through get-tuple-element -> slice chains
// that keep the first k entries of the sort dimension (and all of the other
// dimension). For each such pattern accepted by the profitability predicate,
// the uses of the slices are redirected to the outputs of a "TopK" custom
// call, wrapped in transposes when the sort dimension is 0.
//
// Returns true if at least one pattern was rewritten. The original sort,
// get-tuple-element and slice instructions are not removed by this function.
StatusOr<bool> TopkRewriter::Run(HloModule* module) {
  bool changed = false;
  for (HloComputation* comp : module->computations()) {
    for (HloInstruction* inst : comp->MakeInstructionPostOrder()) {
      HloSortInstruction* sort = DynCast<HloSortInstruction>(inst);
      if (sort == nullptr || sort->operand_count() != 2) {
        continue;
      }
      HloInstruction* data = sort->mutable_operand(0);
      HloIotaInstruction* iota =
          DynCast<HloIotaInstruction>(sort->mutable_operand(1));
      const PrimitiveType element_type = data->shape().element_type();
      if (data->shape().rank() != 2 ||
          (element_type != F32 && element_type != BF16)) {
        continue;
      }
      if (iota == nullptr || iota->shape().rank() != 2 ||
          iota->shape().element_type() != S32 ||
          iota->opcode() != HloOpcode::kIota ||
          iota->iota_dimension() != sort->sort_dimension()) {
        continue;
      }
      if (!IsNanSafeGt(sort->to_apply())) {
        continue;
      }
      const int64 sort_dim = sort->sort_dimension();
      // The input is rank 2, so the batch dimension is simply the other one.
      const int64 batch_dim = sort_dim == 1 ? 0 : 1;

      // Every user of the sort must be a get-tuple-element feeding exactly one
      // slice, and all of those slices must agree on k.
      bool supported = true;
      absl::optional<int64> k;
      for (HloInstruction* gte : sort->users()) {
        if (gte->opcode() != HloOpcode::kGetTupleElement ||
            gte->user_count() != 1) {
          supported = false;
          break;
        }
        const HloInstruction* slice = gte->users()[0];
        if (slice->opcode() != HloOpcode::kSlice) {
          // Non-slice user means we are not doing a TopK
          supported = false;
          break;
        }
        // The predicates take int64 so that 64-bit slice bounds are compared
        // without being narrowed first.
        if (absl::c_any_of(slice->slice_starts(),
                           [](int64 x) { return x != 0; }) ||
            absl::c_any_of(slice->slice_strides(),
                           [](int64 x) { return x != 1; })) {
          // Strided slices and slices with a non-zero start aren't supported.
          supported = false;
          break;
        }
        if (slice->slice_limits(batch_dim) !=
            slice->operand(0)->shape().dimensions(batch_dim)) {
          // Slicing along the batch dimension isn't supported.
          supported = false;
          break;
        }
        if (k == absl::nullopt) {
          k = slice->slice_limits(sort_dim);
        } else if (k != slice->slice_limits(sort_dim)) {
          // Different k for the different operands isn't supported.
          supported = false;
          break;
        }
      }
      // `k` stays unset when the sort has no users at all.
      if (k == absl::nullopt || !supported) {
        continue;
      }

      // Profitability check.
      if (!is_profitable_to_convert_(sort, *k)) {
        continue;
      }

      const int64 batch_size = sort->operand(0)->shape().dimensions(batch_dim);
      const int64 input_size = sort->operand(0)->shape().dimensions(sort_dim);
      HloInstruction* input = sort->mutable_operand(0);
      if (sort_dim == 0) {
        // Transpose so the custom call receives a {batch_size, input_size}
        // operand, matching the shape used when sort_dim == 1.
        input = comp->AddInstruction(HloInstruction::CreateTranspose(
            ShapeUtil::MakeShape(element_type, {batch_size, input_size}), input,
            {1, 0}));
      }

      Shape topk_shape = ShapeUtil::MakeTupleShape(
          {ShapeUtil::MakeShape(element_type, {batch_size, k.value()}),
           ShapeUtil::MakeShape(S32, {batch_size, k.value()})});
      HloInstruction* topk = comp->AddInstruction(
          HloInstruction::CreateCustomCall(topk_shape, {input}, "TopK"));
      HloInstruction* value_gte =
          comp->AddInstruction(HloInstruction::CreateGetTupleElement(
              topk->shape().tuple_shapes(0), topk, 0));
      HloInstruction* index_gte =
          comp->AddInstruction(HloInstruction::CreateGetTupleElement(
              topk->shape().tuple_shapes(1), topk, 1));

      if (sort_dim == 0) {
        // Transpose the results back to the {k, batch_size} shape of the
        // slices being replaced.
        value_gte = comp->AddInstruction(HloInstruction::CreateTranspose(
            ShapeUtil::MakeShape(element_type, {k.value(), batch_size}),
            value_gte, {1, 0}));
        index_gte = comp->AddInstruction(HloInstruction::CreateTranspose(
            ShapeUtil::MakeShape(S32, {k.value(), batch_size}), index_gte,
            {1, 0}));
      }

      // Redirect the users of each slice to the matching custom-call output.
      for (HloInstruction* gte : sort->users()) {
        for (HloInstruction* slice : gte->users()) {
          if (gte->tuple_index() == 0) {
            TF_RETURN_IF_ERROR(slice->ReplaceAllUsesWith(value_gte));
          } else if (gte->tuple_index() == 1) {
            TF_RETURN_IF_ERROR(slice->ReplaceAllUsesWith(index_gte));
          } else {
            LOG(FATAL) << "Sort with more than 2 output isn't supported in "
                          "topk rewriter";
          }
        }
      }
      changed = true;
    }
  }
  return changed;
}
} // namespace xla

View File

@ -0,0 +1,44 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_TOPK_REWRITER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_TOPK_REWRITER_H_
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
namespace xla {
// HLO module pass that looks for groups of HLO instructions which together
// perform a TopK operation and substitutes a TopK CustomCall for them. The
// substitution only happens when the CustomCall supports the values involved
// and the CustomCall implementation is the more efficient choice.
class TopkRewriter : public HloModulePass {
 public:
  // `is_profitable_to_convert` decides, for a candidate sort instruction and
  // an accompanying int64 value, whether the conversion should go ahead.
  explicit TopkRewriter(
      std::function<bool(const HloSortInstruction*, int64)>
          is_profitable_to_convert)
      : is_profitable_to_convert_(std::move(is_profitable_to_convert)) {}

  // Name under which this pass is identified.
  absl::string_view name() const override { return "topk-rewriter"; }

  // Runs the pass over `module`; defined in the corresponding .cc file.
  StatusOr<bool> Run(HloModule* module) override;

 private:
  // Callback answering whether turning a given sort instruction into a custom
  // call is worthwhile.
  std::function<bool(const HloSortInstruction*, int64)>
      is_profitable_to_convert_;
};
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_TOPK_REWRITER_H_

View File

@ -0,0 +1,153 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/topk_rewriter.h"
#include "tensorflow/compiler/xla/service/hlo_dce.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace op = xla::testing::opcode_matchers;
namespace xla {
namespace {
using TopkRewriterTest = HloTestBase;
// Checks the rewrite when the sort dimension is 1: an f32[8,1234567] input is
// sorted together with an iota along dimension 1 and the first 5 entries of
// both sort outputs are sliced out.
TEST_F(TopkRewriterTest, Rewrite) {
// HLO under test; the literal below is parsed as-is.
const char* const hlo_string = R"(
HloModule module
%compare {
%p.1.lhs.8 = s32[] parameter(2)
%p.1.rhs.9 = s32[] parameter(3)
%p.0.lhs.6 = f32[] parameter(0)
%bitcast-convert.11 = s32[] bitcast-convert(%p.0.lhs.6)
%constant.15 = s32[] constant(0)
%compare.16 = pred[] compare(%bitcast-convert.11, %constant.15), direction=LT
%constant.10 = u32[] constant(2147483647)
%bitcast-convert.12 = u32[] bitcast-convert(%p.0.lhs.6)
%subtract.13 = u32[] subtract(%constant.10, %bitcast-convert.12)
%bitcast-convert.14 = s32[] bitcast-convert(%subtract.13)
%select.17 = s32[] select(%compare.16, %bitcast-convert.14,
%bitcast-convert.11)
%p.0.rhs.7 = f32[] parameter(1)
%bitcast-convert.19 = s32[] bitcast-convert(%p.0.rhs.7)
%constant.23 = s32[] constant(0)
%compare.24 = pred[] compare(%bitcast-convert.19, %constant.23), direction=LT
%constant.18 = u32[] constant(2147483647)
%bitcast-convert.20 = u32[] bitcast-convert(%p.0.rhs.7)
%subtract.21 = u32[] subtract(%constant.18, %bitcast-convert.20)
%bitcast-convert.22 = s32[] bitcast-convert(%subtract.21)
%select.25 = s32[] select(%compare.24, %bitcast-convert.22,
%bitcast-convert.19)
ROOT %compare.26 = pred[] compare(%select.17, %select.25), direction=GT
}
ENTRY cluster {
%arg_tuple.1 = f32[8,1234567] parameter(0)
%iota.4 = s32[8,1234567] iota(), iota_dimension=1
%sort.27 = (f32[8,1234567], s32[8,1234567]) sort(%arg_tuple.1, %iota.4),
dimensions={1}, is_stable=true, to_apply=%compare
%get-tuple-element.28 = f32[8,1234567] get-tuple-element(%sort.27), index=0
%slice.29 = f32[8,5] slice(%get-tuple-element.28), slice={[0:8], [0:5]}
%get-tuple-element.30 = s32[8,1234567] get-tuple-element(%sort.27), index=1
%slice.31 = s32[8,5] slice(%get-tuple-element.30), slice={[0:8], [0:5]}
ROOT %tuple.32 = (f32[8,5], s32[8,5]) tuple(%slice.29, %slice.31)
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));
// The profitability predicate always says yes, so only the pattern match
// decides whether the rewrite happens.
TopkRewriter rewriter([](const HloSortInstruction*, int64) { return true; });
TF_ASSERT_OK_AND_ASSIGN(bool changed, rewriter.Run(module.get()));
// Run dead-code elimination after the rewriter, before inspecting the module.
TF_ASSERT_OK(HloDCE().Run(module.get()).status());
EXPECT_TRUE(changed);
// Both tuple elements should now come straight from a custom call on
// parameter 0.
EXPECT_THAT(
module->entry_computation()->root_instruction(),
op::Tuple(op::GetTupleElement(op::CustomCall(op::Parameter(0)), 0),
op::GetTupleElement(op::CustomCall(op::Parameter(0)), 1)));
// root -> get-tuple-element -> custom call.
const HloInstruction* cc =
module->entry_computation()->root_instruction()->operand(0)->operand(0);
EXPECT_THAT(cc->custom_call_target(), "TopK");
}
// Checks the rewrite when the sort dimension is 0: an f32[1234567,8] input is
// sorted together with an iota along dimension 0 and the first 5 entries of
// both sort outputs are sliced out. The expected result wraps the custom call
// in transposes on both its input and its outputs.
TEST_F(TopkRewriterTest, RewriteTranspose) {
// HLO under test; the literal below is parsed as-is.
const char* const hlo_string = R"(
HloModule module
%compare {
%p.1.lhs.8 = s32[] parameter(2)
%p.1.rhs.9 = s32[] parameter(3)
%p.0.lhs.6 = f32[] parameter(0)
%bitcast-convert.11 = s32[] bitcast-convert(%p.0.lhs.6)
%constant.15 = s32[] constant(0)
%compare.16 = pred[] compare(%bitcast-convert.11, %constant.15), direction=LT
%constant.10 = u32[] constant(2147483647)
%bitcast-convert.12 = u32[] bitcast-convert(%p.0.lhs.6)
%subtract.13 = u32[] subtract(%constant.10, %bitcast-convert.12)
%bitcast-convert.14 = s32[] bitcast-convert(%subtract.13)
%select.17 = s32[] select(%compare.16, %bitcast-convert.14,
%bitcast-convert.11)
%p.0.rhs.7 = f32[] parameter(1)
%bitcast-convert.19 = s32[] bitcast-convert(%p.0.rhs.7)
%constant.23 = s32[] constant(0)
%compare.24 = pred[] compare(%bitcast-convert.19, %constant.23), direction=LT
%constant.18 = u32[] constant(2147483647)
%bitcast-convert.20 = u32[] bitcast-convert(%p.0.rhs.7)
%subtract.21 = u32[] subtract(%constant.18, %bitcast-convert.20)
%bitcast-convert.22 = s32[] bitcast-convert(%subtract.21)
%select.25 = s32[] select(%compare.24, %bitcast-convert.22,
%bitcast-convert.19)
ROOT %compare.26 = pred[] compare(%select.17, %select.25), direction=GT
}
ENTRY cluster {
%arg_tuple.1 = f32[1234567,8] parameter(0)
%iota.4 = s32[1234567,8] iota(), iota_dimension=0
%sort.27 = (f32[1234567,8], s32[1234567,8]) sort(%arg_tuple.1, %iota.4),
dimensions={0}, is_stable=true, to_apply=%compare
%get-tuple-element.28 = f32[1234567,8] get-tuple-element(%sort.27), index=0
%slice.29 = f32[5,8] slice(%get-tuple-element.28), slice={[0:5], [0:8]}
%get-tuple-element.30 = s32[1234567,8] get-tuple-element(%sort.27), index=1
%slice.31 = s32[5,8] slice(%get-tuple-element.30), slice={[0:5], [0:8]}
ROOT %tuple.32 = (f32[5,8], s32[5,8]) tuple(%slice.29, %slice.31)
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));
// The profitability predicate always says yes, so only the pattern match
// decides whether the rewrite happens.
TopkRewriter rewriter([](const HloSortInstruction*, int64) { return true; });
TF_ASSERT_OK_AND_ASSIGN(bool changed, rewriter.Run(module.get()));
// Run dead-code elimination after the rewriter, before inspecting the module.
TF_ASSERT_OK(HloDCE().Run(module.get()).status());
EXPECT_TRUE(changed);
// Logs the entry computation after the rewrite and DCE.
LOG(INFO) << module->entry_computation()->ToString();
// Each tuple element should be a transpose of a get-tuple-element of a custom
// call whose operand is a transpose of parameter 0.
EXPECT_THAT(
module->entry_computation()->root_instruction(),
op::Tuple(op::Transpose(op::GetTupleElement(
op::CustomCall(op::Transpose(op::Parameter(0))), 0)),
op::Transpose(op::GetTupleElement(
op::CustomCall(op::Transpose(op::Parameter(0))), 1))));
// root -> transpose -> get-tuple-element -> custom call.
const HloInstruction* cc = module->entry_computation()
->root_instruction()
->operand(0)
->operand(0)
->operand(0);
EXPECT_THAT(cc->custom_call_target(), "TopK");
}
} // namespace
} // namespace xla