[XLA] Add TopK rewriter pass
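This pass pattern-matches sort HLOs that implement a TopK (a sort of values paired with an iota, followed by slices of the leading k elements) and rewrites them into a "TopK" custom call. This will be useful for CPU. PiperOrigin-RevId: 324976268 Change-Id: I56224ad39e1cb2960bde9a366a7b47deffa9955f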
This commit is contained in:
parent
b89e12c5a3
commit
5f6c13cb10
@ -4985,3 +4985,34 @@ cc_library(
|
||||
"//tensorflow/stream_executor/lib",
|
||||
],
|
||||
)
|
||||
|
||||
# Library target for the TopK rewriter HLO pass (topk_rewriter.cc / topk_rewriter.h).
cc_library(
|
||||
name = "topk_rewriter",
|
||||
srcs = ["topk_rewriter.cc"],
|
||||
hdrs = ["topk_rewriter.h"],
|
||||
# Dependencies backing the includes used by topk_rewriter.cc and topk_rewriter.h.
deps = [
|
||||
":hlo",
|
||||
":hlo_casting_utils",
|
||||
":hlo_pass",
|
||||
":pattern_matcher",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"@com_google_absl//absl/algorithm:container",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
],
|
||||
)
|
||||
|
||||
# Unit test target for the TopK rewriter pass (topk_rewriter_test.cc).
tf_cc_test(
|
||||
name = "topk_rewriter_test",
|
||||
srcs = ["topk_rewriter_test.cc"],
|
||||
# The pass under test plus the HLO test harness; :hlo_dce is run by the
# tests after the rewriter.
deps = [
|
||||
":hlo",
|
||||
":hlo_dce",
|
||||
":hlo_matchers",
|
||||
":topk_rewriter",
|
||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||
"//tensorflow/compiler/xla/tests:test_macros_cpu",
|
||||
"//tensorflow/compiler/xla/tests:test_utils",
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||
"//tensorflow/core:test",
|
||||
],
|
||||
)
|
||||
|
187
tensorflow/compiler/xla/service/topk_rewriter.cc
Normal file
187
tensorflow/compiler/xla/service/topk_rewriter.cc
Normal file
@ -0,0 +1,187 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/service/topk_rewriter.h"
|
||||
|
||||
#include "absl/algorithm/container.h"
|
||||
#include "absl/types/optional.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
// Returns true if the root instruction of `comp` is a "greater than"
// comparison between parameter 0 and parameter 1 in which each side has first
// been turned into an S32 key of the form
//
//   s32  = bitcast-convert<S32>(x)
//   u32  = bitcast-convert<U32>(x)
//   key  = select(s32 < 0, bitcast-convert<S32>(INT32_MAX - u32), s32)
//
// where `x` is either an F32 parameter or a BF16 parameter converted to F32.
// Both sides of the comparison must use the same flavour (F32 or BF16).
static bool IsNanSafeGt(HloComputation* comp) {
  namespace m = match;

  // Builds the pattern for the S32 key described above on top of a pattern
  // `f32_value` that matches the F32 value being compared.
  auto ordered_key = [](auto f32_value) {
    auto as_s32 = m::BitcastConvert(f32_value).WithShape(
        m::Shape().WithElementType(S32));
    auto as_u32 = m::BitcastConvert(f32_value).WithShape(
        m::Shape().WithElementType(U32));
    auto flipped =
        m::BitcastConvert(
            m::Subtract(m::ConstantScalar(std::numeric_limits<int32>::max()),
                        as_u32))
            .WithShape(m::Shape().WithElementType(S32));
    return m::Select(m::Lt(as_s32, m::ConstantScalar(0)), flipped, as_s32);
  };

  // Key pattern for a parameter that is already F32.
  auto f32_key = [&ordered_key](int64 parameter_number) {
    return ordered_key(m::Parameter(parameter_number)
                           .WithShape(m::Shape().WithElementType(F32)));
  };

  // Key pattern for a BF16 parameter that is converted to F32 first.
  auto bf16_key = [&ordered_key](int64 parameter_number) {
    return ordered_key(
        m::Convert(m::Parameter(parameter_number)
                       .WithShape(m::Shape().WithElementType(BF16)))
            .WithShape(m::Shape().WithElementType(F32)));
  };

  HloInstruction* root = comp->root_instruction();
  if (Match(root, m::Gt(f32_key(0), f32_key(1)))) {
    return true;
  }
  return Match(root, m::Gt(bf16_key(0), bf16_key(1)));
}
|
||||
|
||||
StatusOr<bool> TopkRewriter::Run(HloModule* module) {
|
||||
bool changed = false;
|
||||
for (HloComputation* comp : module->computations()) {
|
||||
for (HloInstruction* inst : comp->MakeInstructionPostOrder()) {
|
||||
HloSortInstruction* sort = DynCast<HloSortInstruction>(inst);
|
||||
if (sort == nullptr || sort->operand_count() != 2) {
|
||||
continue;
|
||||
}
|
||||
HloInstruction* data = sort->mutable_operand(0);
|
||||
HloIotaInstruction* iota =
|
||||
DynCast<HloIotaInstruction>(sort->mutable_operand(1));
|
||||
const PrimitiveType element_type = data->shape().element_type();
|
||||
if (data->shape().rank() != 2 ||
|
||||
(element_type != F32 && element_type != BF16)) {
|
||||
continue;
|
||||
}
|
||||
if (iota == nullptr || iota->shape().rank() != 2 ||
|
||||
iota->shape().element_type() != S32 ||
|
||||
iota->opcode() != HloOpcode::kIota ||
|
||||
iota->iota_dimension() != sort->sort_dimension()) {
|
||||
continue;
|
||||
}
|
||||
if (!IsNanSafeGt(sort->to_apply())) {
|
||||
continue;
|
||||
}
|
||||
const int64 sort_dim = sort->sort_dimension();
|
||||
const int64 batch_dim = sort_dim == 1 ? 0 : 1;
|
||||
|
||||
bool supported = true;
|
||||
absl::optional<int64> k;
|
||||
for (HloInstruction* gte : sort->users()) {
|
||||
if (gte->opcode() != HloOpcode::kGetTupleElement ||
|
||||
gte->user_count() != 1) {
|
||||
supported = false;
|
||||
break;
|
||||
}
|
||||
const HloInstruction* slice = gte->users()[0];
|
||||
if (slice->opcode() != HloOpcode::kSlice) {
|
||||
// Non-slice user means we are not doing a TopK
|
||||
supported = false;
|
||||
break;
|
||||
}
|
||||
if (absl::c_any_of(slice->slice_starts(),
|
||||
[](int x) { return x != 0; }) ||
|
||||
absl::c_any_of(slice->slice_strides(),
|
||||
[](int x) { return x != 1; })) {
|
||||
// Strided slice or slicing at the beginning isn't supported.
|
||||
supported = false;
|
||||
break;
|
||||
}
|
||||
if (slice->slice_limits(batch_dim) !=
|
||||
slice->operand(0)->shape().dimensions(batch_dim)) {
|
||||
// Slicing along the batch dimension isn't supported.
|
||||
supported = false;
|
||||
break;
|
||||
}
|
||||
if (k == absl::nullopt) {
|
||||
k = slice->slice_limits(sort_dim);
|
||||
} else if (k != slice->slice_limits(sort_dim)) {
|
||||
// Different k for the different operands isn't supported.
|
||||
supported = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (k == absl::nullopt || !supported) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Profitability check.
|
||||
if (!is_profitable_to_convert_(sort, *k)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const int64 batch_size = sort->operand(0)->shape().dimensions(batch_dim);
|
||||
const int64 input_size = sort->operand(0)->shape().dimensions(sort_dim);
|
||||
HloInstruction* input = sort->mutable_operand(0);
|
||||
if (sort_dim == 0) {
|
||||
input = comp->AddInstruction(HloInstruction::CreateTranspose(
|
||||
ShapeUtil::MakeShape(element_type, {batch_size, input_size}), input,
|
||||
{1, 0}));
|
||||
}
|
||||
|
||||
Shape topk_shape = ShapeUtil::MakeTupleShape(
|
||||
{ShapeUtil::MakeShape(element_type, {batch_size, k.value()}),
|
||||
ShapeUtil::MakeShape(S32, {batch_size, k.value()})});
|
||||
HloInstruction* topk = comp->AddInstruction(
|
||||
HloInstruction::CreateCustomCall(topk_shape, {input}, "TopK"));
|
||||
HloInstruction* value_gte =
|
||||
comp->AddInstruction(HloInstruction::CreateGetTupleElement(
|
||||
topk->shape().tuple_shapes(0), topk, 0));
|
||||
HloInstruction* index_gte =
|
||||
comp->AddInstruction(HloInstruction::CreateGetTupleElement(
|
||||
topk->shape().tuple_shapes(1), topk, 1));
|
||||
|
||||
if (sort_dim == 0) {
|
||||
value_gte = comp->AddInstruction(HloInstruction::CreateTranspose(
|
||||
ShapeUtil::MakeShape(element_type, {k.value(), batch_size}),
|
||||
value_gte, {1, 0}));
|
||||
index_gte = comp->AddInstruction(HloInstruction::CreateTranspose(
|
||||
ShapeUtil::MakeShape(S32, {k.value(), batch_size}), index_gte,
|
||||
{1, 0}));
|
||||
}
|
||||
|
||||
for (HloInstruction* gte : sort->users()) {
|
||||
for (HloInstruction* slice : gte->users()) {
|
||||
if (gte->tuple_index() == 0) {
|
||||
TF_RETURN_IF_ERROR(slice->ReplaceAllUsesWith(value_gte));
|
||||
} else if (gte->tuple_index() == 1) {
|
||||
TF_RETURN_IF_ERROR(slice->ReplaceAllUsesWith(index_gte));
|
||||
} else {
|
||||
LOG(FATAL) << "Sort with more than 2 output isn't supported in "
|
||||
"topk rewriter";
|
||||
}
|
||||
}
|
||||
}
|
||||
changed = true;
|
||||
}
|
||||
}
|
||||
return changed;
|
||||
}
|
||||
|
||||
} // namespace xla
|
44
tensorflow/compiler/xla/service/topk_rewriter.h
Normal file
44
tensorflow/compiler/xla/service/topk_rewriter.h
Normal file
@ -0,0 +1,44 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_TOPK_REWRITER_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_SERVICE_TOPK_REWRITER_H_
|
||||
|
||||
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
|
||||
|
||||
namespace xla {
|
||||
// This pass pattern-matches soups of HLOs executing a TopK operation and
|
||||
// replaces them with a TopK CustomCall when the given values are supported by
|
||||
// the CustomCall and it is more efficient to use that implementation.
|
||||
class TopkRewriter : public HloModulePass {
|
||||
public:
|
||||
// `is_profitable_to_convert` is called with a candidate sort instruction and
// the `k` of the TopK it implements; the sort is rewritten only when the
// predicate returns true.
explicit TopkRewriter(std::function<bool(const HloSortInstruction*, int64)>
|
||||
is_profitable_to_convert)
|
||||
: is_profitable_to_convert_(std::move(is_profitable_to_convert)) {}
|
||||
|
||||
// Name under which this pass is identified.
absl::string_view name() const override { return "topk-rewriter"; }
|
||||
|
||||
// Runs the rewrite over `module`; the result is true when at least one sort
// was replaced by a TopK custom call.
StatusOr<bool> Run(HloModule* module) override;
|
||||
|
||||
private:
|
||||
// Predicate that returns true if a sort instruction is profitable to be
|
||||
// converted into a custom call.
|
||||
std::function<bool(const HloSortInstruction*, int64)>
|
||||
is_profitable_to_convert_;
|
||||
};
|
||||
} // namespace xla
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_TOPK_REWRITER_H_
|
153
tensorflow/compiler/xla/service/topk_rewriter_test.cc
Normal file
153
tensorflow/compiler/xla/service/topk_rewriter_test.cc
Normal file
@ -0,0 +1,153 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/service/topk_rewriter.h"
|
||||
|
||||
#include "tensorflow/compiler/xla/service/hlo_dce.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
||||
#include "tensorflow/compiler/xla/tests/test_macros.h"
|
||||
#include "tensorflow/compiler/xla/tests/test_utils.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace op = xla::testing::opcode_matchers;
|
||||
|
||||
namespace xla {
|
||||
namespace {
|
||||
|
||||
using TopkRewriterTest = HloTestBase;
|
||||
|
||||
// A sort along dimension 1 of an f32[8,1234567] input paired with an iota,
// whose two outputs are sliced to their first 5 elements, must be rewritten
// into a "TopK" custom call that consumes the parameter directly.
TEST_F(TopkRewriterTest, Rewrite) {
  const char* const hlo_string = R"(
HloModule module

%compare {
  %p.1.lhs.8 = s32[] parameter(2)
  %p.1.rhs.9 = s32[] parameter(3)
  %p.0.lhs.6 = f32[] parameter(0)
  %bitcast-convert.11 = s32[] bitcast-convert(%p.0.lhs.6)
  %constant.15 = s32[] constant(0)
  %compare.16 = pred[] compare(%bitcast-convert.11, %constant.15), direction=LT
  %constant.10 = u32[] constant(2147483647)
  %bitcast-convert.12 = u32[] bitcast-convert(%p.0.lhs.6)
  %subtract.13 = u32[] subtract(%constant.10, %bitcast-convert.12)
  %bitcast-convert.14 = s32[] bitcast-convert(%subtract.13)
  %select.17 = s32[] select(%compare.16, %bitcast-convert.14,
                            %bitcast-convert.11)
  %p.0.rhs.7 = f32[] parameter(1)
  %bitcast-convert.19 = s32[] bitcast-convert(%p.0.rhs.7)
  %constant.23 = s32[] constant(0)
  %compare.24 = pred[] compare(%bitcast-convert.19, %constant.23), direction=LT
  %constant.18 = u32[] constant(2147483647)
  %bitcast-convert.20 = u32[] bitcast-convert(%p.0.rhs.7)
  %subtract.21 = u32[] subtract(%constant.18, %bitcast-convert.20)
  %bitcast-convert.22 = s32[] bitcast-convert(%subtract.21)
  %select.25 = s32[] select(%compare.24, %bitcast-convert.22,
                            %bitcast-convert.19)
  ROOT %compare.26 = pred[] compare(%select.17, %select.25), direction=GT
}

ENTRY cluster {
  %arg_tuple.1 = f32[8,1234567] parameter(0)
  %iota.4 = s32[8,1234567] iota(), iota_dimension=1
  %sort.27 = (f32[8,1234567], s32[8,1234567]) sort(%arg_tuple.1, %iota.4),
    dimensions={1}, is_stable=true, to_apply=%compare
  %get-tuple-element.28 = f32[8,1234567] get-tuple-element(%sort.27), index=0
  %slice.29 = f32[8,5] slice(%get-tuple-element.28), slice={[0:8], [0:5]}
  %get-tuple-element.30 = s32[8,1234567] get-tuple-element(%sort.27), index=1
  %slice.31 = s32[8,5] slice(%get-tuple-element.30), slice={[0:8], [0:5]}
  ROOT %tuple.32 = (f32[8,5], s32[8,5]) tuple(%slice.29, %slice.31)
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  // Always report the conversion as profitable so the rewrite is exercised.
  TopkRewriter rewriter([](const HloSortInstruction*, int64) { return true; });
  TF_ASSERT_OK_AND_ASSIGN(bool changed, rewriter.Run(module.get()));
  // Remove the now-unused sort/slice instructions.
  TF_ASSERT_OK(HloDCE().Run(module.get()).status());
  EXPECT_TRUE(changed);
  // Fatal on mismatch: the operand chain below is only valid when the root
  // has the expected structure.
  ASSERT_THAT(
      module->entry_computation()->root_instruction(),
      op::Tuple(op::GetTupleElement(op::CustomCall(op::Parameter(0)), 0),
                op::GetTupleElement(op::CustomCall(op::Parameter(0)), 1)));
  const HloInstruction* cc =
      module->entry_computation()->root_instruction()->operand(0)->operand(0);
  EXPECT_EQ(cc->custom_call_target(), "TopK");
}
|
||||
|
||||
// Same as the Rewrite test but sorting along dimension 0 of an
// f32[1234567,8] input: the custom call must be fed a transpose of the
// parameter and both of its outputs must be transposed back.
TEST_F(TopkRewriterTest, RewriteTranspose) {
  const char* const hlo_string = R"(
HloModule module

%compare {
  %p.1.lhs.8 = s32[] parameter(2)
  %p.1.rhs.9 = s32[] parameter(3)
  %p.0.lhs.6 = f32[] parameter(0)
  %bitcast-convert.11 = s32[] bitcast-convert(%p.0.lhs.6)
  %constant.15 = s32[] constant(0)
  %compare.16 = pred[] compare(%bitcast-convert.11, %constant.15), direction=LT
  %constant.10 = u32[] constant(2147483647)
  %bitcast-convert.12 = u32[] bitcast-convert(%p.0.lhs.6)
  %subtract.13 = u32[] subtract(%constant.10, %bitcast-convert.12)
  %bitcast-convert.14 = s32[] bitcast-convert(%subtract.13)
  %select.17 = s32[] select(%compare.16, %bitcast-convert.14,
                            %bitcast-convert.11)
  %p.0.rhs.7 = f32[] parameter(1)
  %bitcast-convert.19 = s32[] bitcast-convert(%p.0.rhs.7)
  %constant.23 = s32[] constant(0)
  %compare.24 = pred[] compare(%bitcast-convert.19, %constant.23), direction=LT
  %constant.18 = u32[] constant(2147483647)
  %bitcast-convert.20 = u32[] bitcast-convert(%p.0.rhs.7)
  %subtract.21 = u32[] subtract(%constant.18, %bitcast-convert.20)
  %bitcast-convert.22 = s32[] bitcast-convert(%subtract.21)
  %select.25 = s32[] select(%compare.24, %bitcast-convert.22,
                            %bitcast-convert.19)
  ROOT %compare.26 = pred[] compare(%select.17, %select.25), direction=GT
}

ENTRY cluster {
  %arg_tuple.1 = f32[1234567,8] parameter(0)
  %iota.4 = s32[1234567,8] iota(), iota_dimension=0
  %sort.27 = (f32[1234567,8], s32[1234567,8]) sort(%arg_tuple.1, %iota.4),
    dimensions={0}, is_stable=true, to_apply=%compare
  %get-tuple-element.28 = f32[1234567,8] get-tuple-element(%sort.27), index=0
  %slice.29 = f32[5,8] slice(%get-tuple-element.28), slice={[0:5], [0:8]}
  %get-tuple-element.30 = s32[1234567,8] get-tuple-element(%sort.27), index=1
  %slice.31 = s32[5,8] slice(%get-tuple-element.30), slice={[0:5], [0:8]}
  ROOT %tuple.32 = (f32[5,8], s32[5,8]) tuple(%slice.29, %slice.31)
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  // Always report the conversion as profitable so the rewrite is exercised.
  TopkRewriter rewriter([](const HloSortInstruction*, int64) { return true; });
  TF_ASSERT_OK_AND_ASSIGN(bool changed, rewriter.Run(module.get()));
  // Remove the now-unused sort/slice instructions.
  TF_ASSERT_OK(HloDCE().Run(module.get()).status());
  EXPECT_TRUE(changed);
  LOG(INFO) << module->entry_computation()->ToString();
  // Fatal on mismatch: the operand chain below is only valid when the root
  // has the expected structure.
  ASSERT_THAT(
      module->entry_computation()->root_instruction(),
      op::Tuple(op::Transpose(op::GetTupleElement(
                    op::CustomCall(op::Transpose(op::Parameter(0))), 0)),
                op::Transpose(op::GetTupleElement(
                    op::CustomCall(op::Transpose(op::Parameter(0))), 1))));
  // root -> transpose -> get-tuple-element -> custom-call
  const HloInstruction* cc = module->entry_computation()
                                 ->root_instruction()
                                 ->operand(0)
                                 ->operand(0)
                                 ->operand(0);
  EXPECT_EQ(cc->custom_call_target(), "TopK");
}
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
Loading…
x
Reference in New Issue
Block a user