From 926642ea7a1a8095cddf3b063c5d19d4478eaeda Mon Sep 17 00:00:00 2001 From: Benjamin Kramer Date: Wed, 5 Aug 2020 02:05:57 -0700 Subject: [PATCH] [XLA:CPU] Add a runtime function for F32 TopK and use TopkRewriter to target it This just delegates the hard work to std::partial_sort. PiperOrigin-RevId: 324979038 Change-Id: I16a7c4d840948f3744f4f920bac29d4e6d7333b3 --- tensorflow/compiler/tests/sort_ops_test.py | 59 +++++++------- tensorflow/compiler/xla/service/cpu/BUILD | 17 +++++ .../compiler/xla/service/cpu/cpu_compiler.cc | 4 + .../compiler/xla/service/cpu/cpu_runtime.cc | 1 + .../compiler/xla/service/cpu/cpu_runtime.h | 1 + .../compiler/xla/service/cpu/ir_emitter.cc | 38 ++++++++++ .../compiler/xla/service/cpu/ir_emitter.h | 1 + .../compiler/xla/service/cpu/runtime_topk.cc | 76 +++++++++++++++++++ .../compiler/xla/service/cpu/runtime_topk.h | 32 ++++++++ .../xla/service/cpu/simple_orc_jit.cc | 2 + .../compiler/xla/service/cpu/tests/BUILD | 16 ++++ .../xla/service/cpu/tests/cpu_topk_test.cc | 59 ++++++++++++++ 12 files changed, 273 insertions(+), 33 deletions(-) create mode 100644 tensorflow/compiler/xla/service/cpu/runtime_topk.cc create mode 100644 tensorflow/compiler/xla/service/cpu/runtime_topk.h create mode 100644 tensorflow/compiler/xla/service/cpu/tests/cpu_topk_test.cc diff --git a/tensorflow/compiler/tests/sort_ops_test.py b/tensorflow/compiler/tests/sort_ops_test.py index d50fdec7c63..838718aa1e3 100644 --- a/tensorflow/compiler/tests/sort_ops_test.py +++ b/tensorflow/compiler/tests/sort_ops_test.py @@ -129,42 +129,35 @@ class XlaSortOpTest(xla_test.XLATestCase): def testTopKZeros(self): """Tests that positive and negative zeros sort correctly.""" - # Only bfloat16 is implemented. - bfloat16 = dtypes.bfloat16.as_numpy_dtype - if bfloat16 not in self.numeric_types: - return - - with self.session() as sess: - p = array_ops.placeholder(dtypes.bfloat16) - with self.test_scope(): - topk = nn_ops.top_k(p, k=4) - results = sess.run( - topk, - {p: np.array([0., -0., 0., 3., -0., -4., 0., -0.], dtype=bfloat16)}) - self.assertAllEqual( - np.array([3., 0., 0., 0.], dtype=bfloat16), results[0]) - self.assertEqual(list([3, 0, 2, 6]), list(results[1])) + supported_types = set([dtypes.bfloat16.as_numpy_dtype, np.float32]) + for dtype in supported_types.intersection(self.numeric_types): + with self.session() as sess: + p = array_ops.placeholder(dtype) + with self.test_scope(): + topk = nn_ops.top_k(p, k=4) + results = sess.run( + topk, + {p: np.array([0., -0., 0., 3., -0., -4., 0., -0.], dtype=dtype)}) + self.assertAllEqual(np.array([3., 0., 0., 0.], dtype=dtype), results[0]) + self.assertEqual(list([3, 0, 2, 6]), list(results[1])) def testTopKInfinities(self): """Tests that positive and negative infinity sort correctly.""" - # Only bfloat16 is implemented. - bfloat16 = dtypes.bfloat16.as_numpy_dtype - if bfloat16 not in self.numeric_types: - return - - with self.session() as sess: - p = array_ops.placeholder(dtypes.bfloat16) - with self.test_scope(): - topk = nn_ops.top_k(p, k=6) - results = sess.run(topk, { - p: np.array( - [1, 2, float("inf"), -float("inf"), -1, -2], dtype=bfloat16) - }) - self.assertAllEqual( - np.array( - [float("inf"), 2.0, 1.0, -1.0, -2.0, -float("inf")], - dtype=bfloat16), results[0]) - self.assertEqual(list([2, 1, 0, 4, 5, 3]), list(results[1])) + supported_types = set([dtypes.bfloat16.as_numpy_dtype, np.float32]) + for dtype in supported_types.intersection(self.numeric_types): + with self.session() as sess: + p = array_ops.placeholder(dtype) + with self.test_scope(): + topk = nn_ops.top_k(p, k=6) + results = sess.run(topk, { + p: + np.array([1, 2, float("inf"), -float("inf"), -1, -2], + dtype=dtype) + }) + self.assertAllEqual( + np.array([float("inf"), 2.0, 1.0, -1.0, -2.0, -float("inf")], + dtype=dtype), results[0]) + self.assertEqual(list([2, 1, 0, 4, 5, 3]), list(results[1])) def testInTopK(self): supported_types = set([np.int32, np.int64]) diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index 782d08296f0..6eaf43902fe 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -49,6 +49,7 @@ filegroup( "runtime_single_threaded_conv2d.cc", "runtime_single_threaded_fft.cc", "runtime_single_threaded_matmul.cc", + "runtime_topk.cc", ], visibility = [":friends"], ) @@ -64,6 +65,7 @@ filegroup( "runtime_single_threaded_conv2d.h", "runtime_single_threaded_fft.h", "runtime_single_threaded_matmul.h", + "runtime_topk.h", ], visibility = [":friends"], ) @@ -134,6 +136,7 @@ cc_library( "//tensorflow/compiler/xla/service:copy_insertion", "//tensorflow/compiler/xla/service:hlo_casting_utils", "//tensorflow/compiler/xla/service:dump", + "//tensorflow/compiler/xla/service:topk_rewriter", "//tensorflow/compiler/xla/service:map_inliner", "//tensorflow/compiler/xla/service:rng_bit_generator_expander", "//tensorflow/compiler/xla/service:tree_reduction_rewriter", @@ -230,6 +233,7 @@ cc_library( ":runtime_fft", ":runtime_fork_join", ":runtime_key_value_sort", + ":runtime_topk", ":runtime_matmul", ":runtime_matmul_mkl", ":runtime_single_threaded_conv2d", @@ -759,6 +763,19 @@ cc_library( ], ) +cc_library( + name = "runtime_topk", + srcs = ["runtime_topk.cc"], + hdrs = ["runtime_topk.h"], + copts = runtime_copts(), + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/core/platform:dynamic_annotations", + "//tensorflow/core/platform:macros", + "//tensorflow/core/platform:types", + ], +) + cc_library( name = "runtime_fork_join", srcs = ["runtime_fork_join.cc"], diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index 04d703fdd59..0826d7b8ce1 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -104,6 +104,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/slice_sinker.h" #include "tensorflow/compiler/xla/service/slow_operation_alarm.h" #include "tensorflow/compiler/xla/service/sort_simplifier.h" +#include "tensorflow/compiler/xla/service/topk_rewriter.h" #include "tensorflow/compiler/xla/service/transpose_folding.h" #include "tensorflow/compiler/xla/service/tree_reduction_rewriter.h" #include "tensorflow/compiler/xla/service/triangular_solve_expander.h" @@ -320,6 +321,9 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn( pass.AddPass(); pass.AddPass(); } + pipeline.AddPass([](const HloSortInstruction* sort, int64) { + return sort->operand(0)->shape().element_type() == F32; + }); pipeline.AddPass(); pipeline.AddPass( [&](const HloInstruction& dot, diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc index 2231ecfa1e8..5bee6049a5e 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc @@ -117,6 +117,7 @@ extern const char* const kParallelForkJoinSymbolName = "__xla_cpu_runtime_ParallelForkJoin"; extern const char* const kKeyValueSortSymbolName = "__xla_cpu_runtime_KeyValueSort"; +extern const char* const kTopKF32SymbolName = "__xla_cpu_runtime_TopKF32"; extern const char* const kTracingStartSymbolName = "__xla_cpu_runtime_TracingStart"; extern const char* const kTracingEndSymbolName = "__xla_cpu_runtime_TracingEnd"; diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.h b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h index ee75b97e4dc..eb24e0bc334 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h @@ -72,6 +72,7 @@ extern const char* const kAcquireOutfeedBufferForPopulationSymbolName; extern const char* const kReleaseOutfeedBufferAfterPopulationSymbolName; extern const char* const kParallelForkJoinSymbolName; extern const char* const kKeyValueSortSymbolName; +extern const char* const kTopKF32SymbolName; extern const char* const kAllReduceSymbolName; extern const char* const kCollectivePermuteSymbolName; extern const char* const kReplicaIdSymbolName; diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index 278e6479e48..2688a7898af 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -2387,6 +2387,41 @@ Status IrEmitter::HandlePadToStatic(HloInstruction* hlo) { return Status::OK(); } +Status IrEmitter::HandleTopK(HloInstruction* hlo) { + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(hlo)); + const HloInstruction* input = hlo->operand(0); + int64 k = hlo->shape().tuple_shapes(0).dimensions(1); + TF_RET_CHECK(input->shape().element_type() == F32); + TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major( + hlo->shape().tuple_shapes(0).layout())); + TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major( + hlo->shape().tuple_shapes(1).layout())); + TF_RET_CHECK( + LayoutUtil::IsMonotonicWithDim0Major(hlo->operand(0)->shape().layout())); + + TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice values_slice, + assignment_.GetUniqueSlice(hlo->operand(0), {})); + TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice out_values_slice, + assignment_.GetUniqueSlice(hlo, {0})); + TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice out_indices_slice, + assignment_.GetUniqueSlice(hlo, {1})); + llvm::Value* values_ptr = + EmitBufferPointer(values_slice, hlo->operand(0)->shape()); + llvm::Value* out_values_ptr = + EmitBufferPointer(out_values_slice, hlo->shape().tuple_shapes(0)); + llvm::Value* out_indices_ptr = + EmitBufferPointer(out_indices_slice, hlo->shape().tuple_shapes(1)); + EmitCallToFunc(runtime::kTopKF32SymbolName, + {b_.getInt64(input->shape().dimensions(0)), + b_.getInt64(input->shape().dimensions(1)), b_.getInt64(k), + values_ptr, out_values_ptr, out_indices_ptr}, + b_.getVoidTy()); + + llvm_ir::EmitTuple(GetIrArrayFor(hlo), {out_values_ptr, out_indices_ptr}, + &b_); + return Status::OK(); +} + Status IrEmitter::HandleCustomCall(HloInstruction* custom_call) { if (custom_call->custom_call_target() == "PadToStatic") { return HandlePadToStatic(custom_call); @@ -2394,6 +2429,9 @@ Status IrEmitter::HandleCustomCall(HloInstruction* custom_call) { if (custom_call->custom_call_target() == "SliceToDynamic") { return HandleSliceToDynamic(custom_call); } + if (custom_call->custom_call_target() == "TopK") { + return HandleTopK(custom_call); + } absl::Span operands(custom_call->operands()); llvm::Type* i8_ptr_type = b_.getInt8PtrTy(); llvm::AllocaInst* operands_alloca = diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h index 3955deefbea..f136e3470e5 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h @@ -190,6 +190,7 @@ class IrEmitter : public DfsHloVisitorWithDefault, private: Status HandleSliceToDynamic(HloInstruction* hlo); Status HandlePadToStatic(HloInstruction* hlo); + Status HandleTopK(HloInstruction* hlo); Status HandleAllReduceSingleReplica(HloInstruction* crs); Status HandleAllReduceMultipleReplica(HloInstruction* crs); diff --git a/tensorflow/compiler/xla/service/cpu/runtime_topk.cc b/tensorflow/compiler/xla/service/cpu/runtime_topk.cc new file mode 100644 index 00000000000..5174a3329fb --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/runtime_topk.cc @@ -0,0 +1,76 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/runtime_topk.h" + +#include +#include +#include +#include + +#include "tensorflow/core/platform/dynamic_annotations.h" +#include "tensorflow/core/platform/macros.h" + +template +static void TopK(tensorflow::int64 batch_size, tensorflow::int64 input_size, + tensorflow::int64 k, const T* values, T* out_values, + tensorflow::int32* out_indices) { + // 'values' is managed by the JIT code, so msan can't tell they are + // initialized. + TF_ANNOTATE_MEMORY_IS_INITIALIZED(values, + input_size * batch_size * sizeof(T)); + + std::vector temp_indices(input_size); + for (tensorflow::int64 batch = 0; batch != batch_size; ++batch) { + std::iota(temp_indices.begin(), temp_indices.end(), 0); + + const T* values_batch = values + batch * input_size; + + auto convert_to_int = [](T value) { + tensorflow::uint32 x; + std::memcpy(&x, &value, sizeof(x)); + return static_cast(x) < 0 + ? std::numeric_limits::max() - x + : x; + }; + + auto kth_element = temp_indices.begin() + k; + std::partial_sort(temp_indices.begin(), kth_element, temp_indices.end(), + [&](size_t i1, size_t i2) { + // Do the comparison in integers to enforce a total + // order of -NaN < -Inf < -0 < +0 < +Inf < +NaN. + tensorflow::int32 v1 = convert_to_int(values_batch[i1]); + tensorflow::int32 v2 = convert_to_int(values_batch[i2]); + if (v1 == v2) { + return i1 < i2; // Stabilize sorting. + } + return v1 > v2; + }); + + T* out_values_batch = out_values + batch * k; + tensorflow::int32* out_indices_batch = out_indices + batch * k; + std::copy(temp_indices.begin(), kth_element, out_indices_batch); + for (tensorflow::int64 i = 0; i < k; i++) { + out_values_batch[i] = values_batch[temp_indices[i]]; + } + } +} + +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_TopKF32( + tensorflow::int64 batch_size, tensorflow::int64 input_size, + tensorflow::int64 k, const float* values, float* out_values, + tensorflow::int32* out_indices) { + TopK(batch_size, input_size, k, values, out_values, out_indices); +} diff --git a/tensorflow/compiler/xla/service/cpu/runtime_topk.h b/tensorflow/compiler/xla/service/cpu/runtime_topk.h new file mode 100644 index 00000000000..de69c0603e3 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/runtime_topk.h @@ -0,0 +1,32 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_TOPK_H +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_TOPK_H + +#include "tensorflow/core/platform/types.h" + +extern "C" { + +// Calculates `batch_size` topk operations with `input_size` inputs each. The +// outputs are written to `out_values` and `out_indices`. +extern void __xla_cpu_runtime_TopKF32(tensorflow::int64 batch_size, + tensorflow::int64 input_size, + tensorflow::int64 k, const float* values, + float* out_values, + tensorflow::int32* out_indices); +} + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_TOPK_H diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc index 631c6985b03..28508bde4cd 100644 --- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc +++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc @@ -44,6 +44,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_conv2d.h" #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_fft.h" #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h" +#include "tensorflow/compiler/xla/service/cpu/runtime_topk.h" #include "tensorflow/compiler/xla/service/cpu/windows_compatibility.h" #include "tensorflow/compiler/xla/service/custom_call_target_registry.h" #include "tensorflow/compiler/xla/types.h" @@ -270,6 +271,7 @@ bool RegisterKnownJITSymbols() { REGISTER_CPU_RUNTIME_SYMBOL(ReleaseInfeedBufferAfterDequeue); REGISTER_CPU_RUNTIME_SYMBOL(ReleaseOutfeedBufferAfterPopulation); REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSort); + REGISTER_CPU_RUNTIME_SYMBOL(TopKF32); REGISTER_CPU_RUNTIME_SYMBOL(TracingStart); REGISTER_CPU_RUNTIME_SYMBOL(TracingEnd); diff --git a/tensorflow/compiler/xla/service/cpu/tests/BUILD b/tensorflow/compiler/xla/service/cpu/tests/BUILD index d7c50dce3ca..527071d5f31 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/BUILD +++ b/tensorflow/compiler/xla/service/cpu/tests/BUILD @@ -253,6 +253,22 @@ tf_cc_test( ], ) +tf_cc_test( + name = "cpu_topk_test", + srcs = ["cpu_topk_test.cc"], + deps = [ + ":cpu_codegen_test", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client/lib:sorting", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service/cpu:cpu_compiler", + "//tensorflow/compiler/xla/service/cpu:test_header_helper", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + tf_cc_test( name = "cpu_vectorization_test", srcs = ["cpu_vectorization_test.cc"], diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_topk_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_topk_test.cc new file mode 100644 index 00000000000..a4c74cfb8a2 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_topk_test.cc @@ -0,0 +1,59 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/compiler/xla/client/lib/sorting.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h" +#include "tensorflow/compiler/xla/service/cpu/test_target_triple_helper.h" +#include "tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h" + +namespace xla { +namespace cpu { +namespace { + +using CpuTopKTest = CpuCodegenTest; + +TEST_F(CpuTopKTest, CallRuntime) { + XlaBuilder builder(TestName()); + XlaOp input = + Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {5, 100}), "input"); + TopK(input, 10); + TF_ASSERT_OK_AND_ASSIGN(XlaComputation xla_computation, builder.Build()); + + TF_ASSERT_OK_AND_ASSIGN(ProgramShape program_shape, + xla_computation.GetProgramShape()); + HloModuleConfig config(program_shape); + TF_ASSERT_OK_AND_ASSIGN( + auto module, HloModule::CreateFromProto(xla_computation.proto(), config)); + + constexpr char filecheck_pattern[] = R"( + CHECK: call void @__xla_cpu_runtime_TopKF32(i64 5, i64 100, i64 10, + )"; + + CpuAotCompilationOptions options{ + /*triple=*/kTargetTripleForHost, /*cpu_name=*/kTargetCpuForHost, + /*features=*/"", + /*entry_point_name=*/"entry", + /*relocation_model=*/CpuAotCompilationOptions::RelocationModel::Static}; + + CompileAheadOfTimeAndVerifyIr(std::move(module), options, filecheck_pattern, + /*match_optimized_ir=*/true); +} + +} // namespace +} // namespace cpu +} // namespace xla