diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index 4d15bc432a2..49431b19a69 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -4985,3 +4985,34 @@ cc_library(
         "//tensorflow/stream_executor/lib",
     ],
 )
+
+cc_library(
+    name = "topk_rewriter",
+    srcs = ["topk_rewriter.cc"],
+    hdrs = ["topk_rewriter.h"],
+    deps = [
+        ":hlo",
+        ":hlo_casting_utils",
+        ":hlo_pass",
+        ":pattern_matcher",
+        "//tensorflow/compiler/xla:shape_util",
+        "@com_google_absl//absl/algorithm:container",
+        "@com_google_absl//absl/types:optional",
+    ],
+)
+
+tf_cc_test(
+    name = "topk_rewriter_test",
+    srcs = ["topk_rewriter_test.cc"],
+    deps = [
+        ":hlo",
+        ":hlo_dce",
+        ":hlo_matchers",
+        ":topk_rewriter",
+        "//tensorflow/compiler/xla/tests:hlo_test_base",
+        "//tensorflow/compiler/xla/tests:test_macros_cpu",
+        "//tensorflow/compiler/xla/tests:test_utils",
+        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+        "//tensorflow/core:test",
+    ],
+)
diff --git a/tensorflow/compiler/xla/service/topk_rewriter.cc b/tensorflow/compiler/xla/service/topk_rewriter.cc
new file mode 100644
index 00000000000..ae843760a8d
--- /dev/null
+++ b/tensorflow/compiler/xla/service/topk_rewriter.cc
@@ -0,0 +1,205 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/topk_rewriter.h"
+
+#include "absl/algorithm/container.h"
+#include "absl/types/optional.h"
+#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/pattern_matcher.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+
+namespace xla {
+
+// Returns true if the root of `comp` is a greater-than comparison between the
+// integer keys derived from parameters 0 and 1. A key is the parameter's bits
+// read as S32, except that a negative value is replaced by
+// INT32_MAX - (the same bits read as U32). The parameters may be F32, or BF16
+// converted to F32 before the bitcasts.
+static bool IsNanSafeGt(HloComputation* comp) {
+  namespace m = match;
+  // Matches the key expression for an F32 parameter.
+  auto match_bitcast_f32 = [](int64 parameter_number) {
+    auto param = m::Parameter(parameter_number)
+                     .WithShape(m::Shape().WithElementType(F32));
+    auto param_s32 =
+        m::BitcastConvert(param).WithShape(m::Shape().WithElementType(S32));
+    auto param_u32 =
+        m::BitcastConvert(param).WithShape(m::Shape().WithElementType(U32));
+    return m::Select(
+        m::Lt(param_s32, m::ConstantScalar(0)),
+        m::BitcastConvert(
+            m::Subtract(m::ConstantScalar(std::numeric_limits<int32>::max()),
+                        param_u32))
+            .WithShape(m::Shape().WithElementType(S32)),
+        param_s32);
+  };
+  // Matches the key expression for a BF16 parameter widened to F32.
+  auto match_bitcast_bf16 = [](int64 parameter_number) {
+    auto param = m::Convert(m::Parameter(parameter_number)
+                                .WithShape(m::Shape().WithElementType(BF16)))
+                     .WithShape(m::Shape().WithElementType(F32));
+    auto param_s32 =
+        m::BitcastConvert(param).WithShape(m::Shape().WithElementType(S32));
+    auto param_u32 =
+        m::BitcastConvert(param).WithShape(m::Shape().WithElementType(U32));
+    return m::Select(
+        m::Lt(param_s32, m::ConstantScalar(0)),
+        m::BitcastConvert(
+            m::Subtract(m::ConstantScalar(std::numeric_limits<int32>::max()),
+                        param_u32))
+            .WithShape(m::Shape().WithElementType(S32)),
+        param_s32);
+  };
+  return Match(comp->root_instruction(),
+               m::Gt(match_bitcast_f32(0), match_bitcast_f32(1))) ||
+         Match(comp->root_instruction(),
+               m::Gt(match_bitcast_bf16(0), match_bitcast_bf16(1)));
+}
+
+// Looks in every computation of `module` for a sort of (rank-2 F32/BF16 data,
+// S32 iota along the sort dimension) whose comparator passes IsNanSafeGt and
+// whose outputs are only consumed through slices keeping the first k elements
+// of the sort dimension. Each such pattern that the profitability predicate
+// accepts has its slices replaced by the outputs of a "TopK" custom call.
+// Returns true if at least one sort was rewritten.
+StatusOr<bool> TopkRewriter::Run(HloModule* module) {
+  bool changed = false;
+  for (HloComputation* comp : module->computations()) {
+    for (HloInstruction* inst : comp->MakeInstructionPostOrder()) {
+      HloSortInstruction* sort = DynCast<HloSortInstruction>(inst);
+      if (sort == nullptr || sort->operand_count() != 2) {
+        continue;
+      }
+      HloInstruction* data = sort->mutable_operand(0);
+      HloIotaInstruction* iota =
+          DynCast<HloIotaInstruction>(sort->mutable_operand(1));
+      const PrimitiveType element_type = data->shape().element_type();
+      if (data->shape().rank() != 2 ||
+          (element_type != F32 && element_type != BF16)) {
+        continue;
+      }
+      if (iota == nullptr || iota->shape().rank() != 2 ||
+          iota->shape().element_type() != S32 ||
+          iota->opcode() != HloOpcode::kIota ||
+          iota->iota_dimension() != sort->sort_dimension()) {
+        continue;
+      }
+      if (!IsNanSafeGt(sort->to_apply())) {
+        continue;
+      }
+      const int64 sort_dim = sort->sort_dimension();
+      const int64 batch_dim = sort_dim == 1 ? 0 : 1;
+
+      bool supported = true;
+      absl::optional<int64> k;
+      // Every user must be a get-tuple-element read by exactly one slice, and
+      // all slices must agree on k.
+      for (HloInstruction* gte : sort->users()) {
+        if (gte->opcode() != HloOpcode::kGetTupleElement ||
+            gte->user_count() != 1) {
+          supported = false;
+          break;
+        }
+        const HloInstruction* slice = gte->users()[0];
+        if (slice->opcode() != HloOpcode::kSlice) {
+          // Non-slice user means we are not doing a TopK
+          supported = false;
+          break;
+        }
+        if (absl::c_any_of(slice->slice_starts(),
+                           [](int64 x) { return x != 0; }) ||
+            absl::c_any_of(slice->slice_strides(),
+                           [](int64 x) { return x != 1; })) {
+          // Strided slice or slicing at the beginning isn't supported.
+          supported = false;
+          break;
+        }
+        if (slice->slice_limits(batch_dim) !=
+            slice->operand(0)->shape().dimensions(batch_dim)) {
+          // Slicing along the batch dimension isn't supported.
+          supported = false;
+          break;
+        }
+        if (k == absl::nullopt) {
+          k = slice->slice_limits(sort_dim);
+        } else if (k != slice->slice_limits(sort_dim)) {
+          // Different k for the different operands isn't supported.
+          supported = false;
+          break;
+        }
+      }
+      if (k == absl::nullopt || !supported) {
+        continue;
+      }
+
+      // Profitability check.
+      if (!is_profitable_to_convert_(sort, *k)) {
+        continue;
+      }
+
+      const int64 batch_size = sort->operand(0)->shape().dimensions(batch_dim);
+      const int64 input_size = sort->operand(0)->shape().dimensions(sort_dim);
+      HloInstruction* input = sort->mutable_operand(0);
+      // The custom call consumes [batch, input] data; transpose if needed.
+      if (sort_dim == 0) {
+        input = comp->AddInstruction(HloInstruction::CreateTranspose(
+            ShapeUtil::MakeShape(element_type, {batch_size, input_size}), input,
+            {1, 0}));
+      }
+
+      Shape topk_shape = ShapeUtil::MakeTupleShape(
+          {ShapeUtil::MakeShape(element_type, {batch_size, k.value()}),
+           ShapeUtil::MakeShape(S32, {batch_size, k.value()})});
+      HloInstruction* topk = comp->AddInstruction(
+          HloInstruction::CreateCustomCall(topk_shape, {input}, "TopK"));
+      HloInstruction* value_gte =
+          comp->AddInstruction(HloInstruction::CreateGetTupleElement(
+              topk->shape().tuple_shapes(0), topk, 0));
+      HloInstruction* index_gte =
+          comp->AddInstruction(HloInstruction::CreateGetTupleElement(
+              topk->shape().tuple_shapes(1), topk, 1));
+
+      // Restore the [k, batch] layout that the original slices produced.
+      if (sort_dim == 0) {
+        value_gte = comp->AddInstruction(HloInstruction::CreateTranspose(
+            ShapeUtil::MakeShape(element_type, {k.value(), batch_size}),
+            value_gte, {1, 0}));
+        index_gte = comp->AddInstruction(HloInstruction::CreateTranspose(
+            ShapeUtil::MakeShape(S32, {k.value(), batch_size}), index_gte,
+            {1, 0}));
+      }
+
+      // Point the users of every slice at the matching custom-call output.
+      for (HloInstruction* gte : sort->users()) {
+        for (HloInstruction* slice : gte->users()) {
+          if (gte->tuple_index() == 0) {
+            TF_RETURN_IF_ERROR(slice->ReplaceAllUsesWith(value_gte));
+          } else if (gte->tuple_index() == 1) {
+            TF_RETURN_IF_ERROR(slice->ReplaceAllUsesWith(index_gte));
+          } else {
+            LOG(FATAL) << "Sort with more than 2 output isn't supported in "
+                          "topk rewriter";
+          }
+        }
+      }
+      changed = true;
+    }
+  }
+  return changed;
+}
+
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/service/topk_rewriter.h b/tensorflow/compiler/xla/service/topk_rewriter.h
new file mode 100644
index 00000000000..68f8a8145e2
--- /dev/null
+++ b/tensorflow/compiler/xla/service/topk_rewriter.h
@@ -0,0 +1,47 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_TOPK_REWRITER_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_TOPK_REWRITER_H_
+
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
+#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
+
+namespace xla {
+// This pass pattern-matches soups of HLOs executing a TopK operation and
+// replaces them with a TopK CustomCall when the given values are supported by
+// the CustomCall and it is more efficient to use that implementation.
+class TopkRewriter : public HloModulePass {
+ public:
+  // `is_profitable_to_convert` is called with a candidate sort and the k of
+  // the slices taken from it; the sort is rewritten only if it returns true.
+  explicit TopkRewriter(std::function<bool(const HloSortInstruction*, int64)>
+                            is_profitable_to_convert)
+      : is_profitable_to_convert_(std::move(is_profitable_to_convert)) {}
+
+  absl::string_view name() const override { return "topk-rewriter"; }
+
+  // Returns true if any sort in `module` was rewritten.
+  StatusOr<bool> Run(HloModule* module) override;
+
+ private:
+  // Predicate that returns true if a sort instruction is profitable to be
+  // converted into a custom call.
+  std::function<bool(const HloSortInstruction*, int64)>
+      is_profitable_to_convert_;
+};
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_TOPK_REWRITER_H_
diff --git a/tensorflow/compiler/xla/service/topk_rewriter_test.cc b/tensorflow/compiler/xla/service/topk_rewriter_test.cc
new file mode 100644
index 00000000000..e440da5b163
--- /dev/null
+++ b/tensorflow/compiler/xla/service/topk_rewriter_test.cc
@@ -0,0 +1,155 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/topk_rewriter.h"
+
+#include "tensorflow/compiler/xla/service/hlo_dce.h"
+#include "tensorflow/compiler/xla/service/hlo_matchers.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/tests/test_utils.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace op = xla::testing::opcode_matchers;
+
+namespace xla {
+namespace {
+
+using TopkRewriterTest = HloTestBase;
+
+// Sorting along dimension 1 and slicing the first 5 entries becomes a TopK.
+TEST_F(TopkRewriterTest, Rewrite) {
+  const char* const hlo_string = R"(
+HloModule module
+
+%compare {
+  %p.1.lhs.8 = s32[] parameter(2)
+  %p.1.rhs.9 = s32[] parameter(3)
+  %p.0.lhs.6 = f32[] parameter(0)
+  %bitcast-convert.11 = s32[] bitcast-convert(%p.0.lhs.6)
+  %constant.15 = s32[] constant(0)
+  %compare.16 = pred[] compare(%bitcast-convert.11, %constant.15), direction=LT
+  %constant.10 = u32[] constant(2147483647)
+  %bitcast-convert.12 = u32[] bitcast-convert(%p.0.lhs.6)
+  %subtract.13 = u32[] subtract(%constant.10, %bitcast-convert.12)
+  %bitcast-convert.14 = s32[] bitcast-convert(%subtract.13)
+  %select.17 = s32[] select(%compare.16, %bitcast-convert.14,
+                            %bitcast-convert.11)
+  %p.0.rhs.7 = f32[] parameter(1)
+  %bitcast-convert.19 = s32[] bitcast-convert(%p.0.rhs.7)
+  %constant.23 = s32[] constant(0)
+  %compare.24 = pred[] compare(%bitcast-convert.19, %constant.23), direction=LT
+  %constant.18 = u32[] constant(2147483647)
+  %bitcast-convert.20 = u32[] bitcast-convert(%p.0.rhs.7)
+  %subtract.21 = u32[] subtract(%constant.18, %bitcast-convert.20)
+  %bitcast-convert.22 = s32[] bitcast-convert(%subtract.21)
+  %select.25 = s32[] select(%compare.24, %bitcast-convert.22,
+                            %bitcast-convert.19)
+  ROOT %compare.26 = pred[] compare(%select.17, %select.25), direction=GT
+}
+
+ENTRY cluster {
+  %arg_tuple.1 = f32[8,1234567] parameter(0)
+  %iota.4 = s32[8,1234567] iota(), iota_dimension=1
+  %sort.27 = (f32[8,1234567], s32[8,1234567]) sort(%arg_tuple.1, %iota.4),
+    dimensions={1}, is_stable=true, to_apply=%compare
+  %get-tuple-element.28 = f32[8,1234567] get-tuple-element(%sort.27), index=0
+  %slice.29 = f32[8,5] slice(%get-tuple-element.28), slice={[0:8], [0:5]}
+  %get-tuple-element.30 = s32[8,1234567] get-tuple-element(%sort.27), index=1
+  %slice.31 = s32[8,5] slice(%get-tuple-element.30), slice={[0:8], [0:5]}
+  ROOT %tuple.32 = (f32[8,5], s32[8,5]) tuple(%slice.29, %slice.31)
+})";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  TopkRewriter rewriter([](const HloSortInstruction*, int64) { return true; });
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, rewriter.Run(module.get()));
+  TF_ASSERT_OK(HloDCE().Run(module.get()).status());
+  EXPECT_TRUE(changed);
+  EXPECT_THAT(
+      module->entry_computation()->root_instruction(),
+      op::Tuple(op::GetTupleElement(op::CustomCall(op::Parameter(0)), 0),
+                op::GetTupleElement(op::CustomCall(op::Parameter(0)), 1)));
+  const HloInstruction* cc =
+      module->entry_computation()->root_instruction()->operand(0)->operand(0);
+  EXPECT_THAT(cc->custom_call_target(), "TopK");
+}
+
+// Sorting along dimension 0 additionally wraps the TopK in transposes.
+TEST_F(TopkRewriterTest, RewriteTranspose) {
+  const char* const hlo_string = R"(
+HloModule module
+
+%compare {
+  %p.1.lhs.8 = s32[] parameter(2)
+  %p.1.rhs.9 = s32[] parameter(3)
+  %p.0.lhs.6 = f32[] parameter(0)
+  %bitcast-convert.11 = s32[] bitcast-convert(%p.0.lhs.6)
+  %constant.15 = s32[] constant(0)
+  %compare.16 = pred[] compare(%bitcast-convert.11, %constant.15), direction=LT
+  %constant.10 = u32[] constant(2147483647)
+  %bitcast-convert.12 = u32[] bitcast-convert(%p.0.lhs.6)
+  %subtract.13 = u32[] subtract(%constant.10, %bitcast-convert.12)
+  %bitcast-convert.14 = s32[] bitcast-convert(%subtract.13)
+  %select.17 = s32[] select(%compare.16, %bitcast-convert.14,
+                            %bitcast-convert.11)
+  %p.0.rhs.7 = f32[] parameter(1)
+  %bitcast-convert.19 = s32[] bitcast-convert(%p.0.rhs.7)
+  %constant.23 = s32[] constant(0)
+  %compare.24 = pred[] compare(%bitcast-convert.19, %constant.23), direction=LT
+  %constant.18 = u32[] constant(2147483647)
+  %bitcast-convert.20 = u32[] bitcast-convert(%p.0.rhs.7)
+  %subtract.21 = u32[] subtract(%constant.18, %bitcast-convert.20)
+  %bitcast-convert.22 = s32[] bitcast-convert(%subtract.21)
+  %select.25 = s32[] select(%compare.24, %bitcast-convert.22,
+                            %bitcast-convert.19)
+  ROOT %compare.26 = pred[] compare(%select.17, %select.25), direction=GT
+}
+
+ENTRY cluster {
+  %arg_tuple.1 = f32[1234567,8] parameter(0)
+  %iota.4 = s32[1234567,8] iota(), iota_dimension=0
+  %sort.27 = (f32[1234567,8], s32[1234567,8]) sort(%arg_tuple.1, %iota.4),
+    dimensions={0}, is_stable=true, to_apply=%compare
+  %get-tuple-element.28 = f32[1234567,8] get-tuple-element(%sort.27), index=0
+  %slice.29 = f32[5,8] slice(%get-tuple-element.28), slice={[0:5], [0:8]}
+  %get-tuple-element.30 = s32[1234567,8] get-tuple-element(%sort.27), index=1
+  %slice.31 = s32[5,8] slice(%get-tuple-element.30), slice={[0:5], [0:8]}
+  ROOT %tuple.32 = (f32[5,8], s32[5,8]) tuple(%slice.29, %slice.31)
+})";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  TopkRewriter rewriter([](const HloSortInstruction*, int64) { return true; });
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, rewriter.Run(module.get()));
+  TF_ASSERT_OK(HloDCE().Run(module.get()).status());
+  EXPECT_TRUE(changed);
+  LOG(INFO) << module->entry_computation()->ToString();
+  EXPECT_THAT(
+      module->entry_computation()->root_instruction(),
+      op::Tuple(op::Transpose(op::GetTupleElement(
+                    op::CustomCall(op::Transpose(op::Parameter(0))), 0)),
+                op::Transpose(op::GetTupleElement(
+                    op::CustomCall(op::Transpose(op::Parameter(0))), 1))));
+  const HloInstruction* cc = module->entry_computation()
+                                 ->root_instruction()
+                                 ->operand(0)
+                                 ->operand(0)
+                                 ->operand(0);
+  EXPECT_THAT(cc->custom_call_target(), "TopK");
+}
+
+}  // namespace
+}  // namespace xla