[XLA:CPU] Add a runtime function for F32 TopK and use TopkRewriter to target it

This just delegates the hard work to std::partial_sort.

PiperOrigin-RevId: 324979038
Change-Id: I16a7c4d840948f3744f4f920bac29d4e6d7333b3
This commit is contained in:
Benjamin Kramer 2020-08-05 02:05:57 -07:00 committed by TensorFlower Gardener
parent f576f29e3d
commit 926642ea7a
12 changed files with 273 additions and 33 deletions

View File

@ -129,42 +129,35 @@ class XlaSortOpTest(xla_test.XLATestCase):
def testTopKZeros(self):
"""Tests that positive and negative zeros sort correctly."""
# Only bfloat16 is implemented.
bfloat16 = dtypes.bfloat16.as_numpy_dtype
if bfloat16 not in self.numeric_types:
return
with self.session() as sess:
p = array_ops.placeholder(dtypes.bfloat16)
with self.test_scope():
topk = nn_ops.top_k(p, k=4)
results = sess.run(
topk,
{p: np.array([0., -0., 0., 3., -0., -4., 0., -0.], dtype=bfloat16)})
self.assertAllEqual(
np.array([3., 0., 0., 0.], dtype=bfloat16), results[0])
self.assertEqual(list([3, 0, 2, 6]), list(results[1]))
supported_types = set([dtypes.bfloat16.as_numpy_dtype, np.float32])
for dtype in supported_types.intersection(self.numeric_types):
with self.session() as sess:
p = array_ops.placeholder(dtype)
with self.test_scope():
topk = nn_ops.top_k(p, k=4)
results = sess.run(
topk,
{p: np.array([0., -0., 0., 3., -0., -4., 0., -0.], dtype=dtype)})
self.assertAllEqual(np.array([3., 0., 0., 0.], dtype=dtype), results[0])
self.assertEqual(list([3, 0, 2, 6]), list(results[1]))
def testTopKInfinities(self):
"""Tests that positive and negative infinity sort correctly."""
# Only bfloat16 is implemented.
bfloat16 = dtypes.bfloat16.as_numpy_dtype
if bfloat16 not in self.numeric_types:
return
with self.session() as sess:
p = array_ops.placeholder(dtypes.bfloat16)
with self.test_scope():
topk = nn_ops.top_k(p, k=6)
results = sess.run(topk, {
p: np.array(
[1, 2, float("inf"), -float("inf"), -1, -2], dtype=bfloat16)
})
self.assertAllEqual(
np.array(
[float("inf"), 2.0, 1.0, -1.0, -2.0, -float("inf")],
dtype=bfloat16), results[0])
self.assertEqual(list([2, 1, 0, 4, 5, 3]), list(results[1]))
supported_types = set([dtypes.bfloat16.as_numpy_dtype, np.float32])
for dtype in supported_types.intersection(self.numeric_types):
with self.session() as sess:
p = array_ops.placeholder(dtype)
with self.test_scope():
topk = nn_ops.top_k(p, k=6)
results = sess.run(topk, {
p:
np.array([1, 2, float("inf"), -float("inf"), -1, -2],
dtype=dtype)
})
self.assertAllEqual(
np.array([float("inf"), 2.0, 1.0, -1.0, -2.0, -float("inf")],
dtype=dtype), results[0])
self.assertEqual(list([2, 1, 0, 4, 5, 3]), list(results[1]))
def testInTopK(self):
supported_types = set([np.int32, np.int64])

View File

@ -49,6 +49,7 @@ filegroup(
"runtime_single_threaded_conv2d.cc",
"runtime_single_threaded_fft.cc",
"runtime_single_threaded_matmul.cc",
"runtime_topk.cc",
],
visibility = [":friends"],
)
@ -64,6 +65,7 @@ filegroup(
"runtime_single_threaded_conv2d.h",
"runtime_single_threaded_fft.h",
"runtime_single_threaded_matmul.h",
"runtime_topk.h",
],
visibility = [":friends"],
)
@ -134,6 +136,7 @@ cc_library(
"//tensorflow/compiler/xla/service:copy_insertion",
"//tensorflow/compiler/xla/service:hlo_casting_utils",
"//tensorflow/compiler/xla/service:dump",
"//tensorflow/compiler/xla/service:topk_rewriter",
"//tensorflow/compiler/xla/service:map_inliner",
"//tensorflow/compiler/xla/service:rng_bit_generator_expander",
"//tensorflow/compiler/xla/service:tree_reduction_rewriter",
@ -230,6 +233,7 @@ cc_library(
":runtime_fft",
":runtime_fork_join",
":runtime_key_value_sort",
":runtime_topk",
":runtime_matmul",
":runtime_matmul_mkl",
":runtime_single_threaded_conv2d",
@ -759,6 +763,19 @@ cc_library(
],
)
cc_library(
name = "runtime_topk",
srcs = ["runtime_topk.cc"],
hdrs = ["runtime_topk.h"],
copts = runtime_copts(),
visibility = ["//visibility:public"],
deps = [
"//tensorflow/core/platform:dynamic_annotations",
"//tensorflow/core/platform:macros",
"//tensorflow/core/platform:types",
],
)
cc_library(
name = "runtime_fork_join",
srcs = ["runtime_fork_join.cc"],

View File

@ -104,6 +104,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/slice_sinker.h"
#include "tensorflow/compiler/xla/service/slow_operation_alarm.h"
#include "tensorflow/compiler/xla/service/sort_simplifier.h"
#include "tensorflow/compiler/xla/service/topk_rewriter.h"
#include "tensorflow/compiler/xla/service/transpose_folding.h"
#include "tensorflow/compiler/xla/service/tree_reduction_rewriter.h"
#include "tensorflow/compiler/xla/service/triangular_solve_expander.h"
@ -320,6 +321,9 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn(
pass.AddPass<HloConstantFolding>();
pass.AddPass<ConditionalSimplifier>();
}
pipeline.AddPass<TopkRewriter>([](const HloSortInstruction* sort, int64) {
return sort->operand(0)->shape().element_type() == F32;
});
pipeline.AddPass<IndexedArrayAnalysisPrinterPass>();
pipeline.AddPass<TransposeFolding>(
[&](const HloInstruction& dot,

View File

@ -117,6 +117,7 @@ extern const char* const kParallelForkJoinSymbolName =
"__xla_cpu_runtime_ParallelForkJoin";
extern const char* const kKeyValueSortSymbolName =
"__xla_cpu_runtime_KeyValueSort";
extern const char* const kTopKF32SymbolName = "__xla_cpu_runtime_TopKF32";
extern const char* const kTracingStartSymbolName =
"__xla_cpu_runtime_TracingStart";
extern const char* const kTracingEndSymbolName = "__xla_cpu_runtime_TracingEnd";

View File

@ -72,6 +72,7 @@ extern const char* const kAcquireOutfeedBufferForPopulationSymbolName;
extern const char* const kReleaseOutfeedBufferAfterPopulationSymbolName;
extern const char* const kParallelForkJoinSymbolName;
extern const char* const kKeyValueSortSymbolName;
extern const char* const kTopKF32SymbolName;
extern const char* const kAllReduceSymbolName;
extern const char* const kCollectivePermuteSymbolName;
extern const char* const kReplicaIdSymbolName;

View File

@ -2387,6 +2387,41 @@ Status IrEmitter::HandlePadToStatic(HloInstruction* hlo) {
return Status::OK();
}
Status IrEmitter::HandleTopK(HloInstruction* hlo) {
TF_RETURN_IF_ERROR(EmitTargetAddressForOp(hlo));
const HloInstruction* input = hlo->operand(0);
int64 k = hlo->shape().tuple_shapes(0).dimensions(1);
TF_RET_CHECK(input->shape().element_type() == F32);
TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major(
hlo->shape().tuple_shapes(0).layout()));
TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major(
hlo->shape().tuple_shapes(1).layout()));
TF_RET_CHECK(
LayoutUtil::IsMonotonicWithDim0Major(hlo->operand(0)->shape().layout()));
TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice values_slice,
assignment_.GetUniqueSlice(hlo->operand(0), {}));
TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice out_values_slice,
assignment_.GetUniqueSlice(hlo, {0}));
TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice out_indices_slice,
assignment_.GetUniqueSlice(hlo, {1}));
llvm::Value* values_ptr =
EmitBufferPointer(values_slice, hlo->operand(0)->shape());
llvm::Value* out_values_ptr =
EmitBufferPointer(out_values_slice, hlo->shape().tuple_shapes(0));
llvm::Value* out_indices_ptr =
EmitBufferPointer(out_indices_slice, hlo->shape().tuple_shapes(1));
EmitCallToFunc(runtime::kTopKF32SymbolName,
{b_.getInt64(input->shape().dimensions(0)),
b_.getInt64(input->shape().dimensions(1)), b_.getInt64(k),
values_ptr, out_values_ptr, out_indices_ptr},
b_.getVoidTy());
llvm_ir::EmitTuple(GetIrArrayFor(hlo), {out_values_ptr, out_indices_ptr},
&b_);
return Status::OK();
}
Status IrEmitter::HandleCustomCall(HloInstruction* custom_call) {
if (custom_call->custom_call_target() == "PadToStatic") {
return HandlePadToStatic(custom_call);
@ -2394,6 +2429,9 @@ Status IrEmitter::HandleCustomCall(HloInstruction* custom_call) {
if (custom_call->custom_call_target() == "SliceToDynamic") {
return HandleSliceToDynamic(custom_call);
}
if (custom_call->custom_call_target() == "TopK") {
return HandleTopK(custom_call);
}
absl::Span<HloInstruction* const> operands(custom_call->operands());
llvm::Type* i8_ptr_type = b_.getInt8PtrTy();
llvm::AllocaInst* operands_alloca =

View File

@ -190,6 +190,7 @@ class IrEmitter : public DfsHloVisitorWithDefault,
private:
Status HandleSliceToDynamic(HloInstruction* hlo);
Status HandlePadToStatic(HloInstruction* hlo);
Status HandleTopK(HloInstruction* hlo);
Status HandleAllReduceSingleReplica(HloInstruction* crs);
Status HandleAllReduceMultipleReplica(HloInstruction* crs);

View File

@ -0,0 +1,76 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/cpu/runtime_topk.h"
#include <algorithm>
#include <memory>
#include <numeric>
#include <vector>
#include "tensorflow/core/platform/dynamic_annotations.h"
#include "tensorflow/core/platform/macros.h"
template <typename T>
static void TopK(tensorflow::int64 batch_size, tensorflow::int64 input_size,
tensorflow::int64 k, const T* values, T* out_values,
tensorflow::int32* out_indices) {
// 'values' is managed by the JIT code, so msan can't tell they are
// initialized.
TF_ANNOTATE_MEMORY_IS_INITIALIZED(values,
input_size * batch_size * sizeof(T));
std::vector<tensorflow::int32> temp_indices(input_size);
for (tensorflow::int64 batch = 0; batch != batch_size; ++batch) {
std::iota(temp_indices.begin(), temp_indices.end(), 0);
const T* values_batch = values + batch * input_size;
auto convert_to_int = [](T value) {
tensorflow::uint32 x;
std::memcpy(&x, &value, sizeof(x));
return static_cast<tensorflow::int32>(x) < 0
? std::numeric_limits<tensorflow::int32>::max() - x
: x;
};
auto kth_element = temp_indices.begin() + k;
std::partial_sort(temp_indices.begin(), kth_element, temp_indices.end(),
[&](size_t i1, size_t i2) {
// Do the comparison in integers to enforce a total
// order of -NaN < -Inf < -0 < +0 < +Inf < +NaN.
tensorflow::int32 v1 = convert_to_int(values_batch[i1]);
tensorflow::int32 v2 = convert_to_int(values_batch[i2]);
if (v1 == v2) {
return i1 < i2; // Stabilize sorting.
}
return v1 > v2;
});
T* out_values_batch = out_values + batch * k;
tensorflow::int32* out_indices_batch = out_indices + batch * k;
std::copy(temp_indices.begin(), kth_element, out_indices_batch);
for (tensorflow::int64 i = 0; i < k; i++) {
out_values_batch[i] = values_batch[temp_indices[i]];
}
}
}
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_TopKF32(
tensorflow::int64 batch_size, tensorflow::int64 input_size,
tensorflow::int64 k, const float* values, float* out_values,
tensorflow::int32* out_indices) {
TopK(batch_size, input_size, k, values, out_values, out_indices);
}

View File

@ -0,0 +1,32 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_TOPK_H
#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_TOPK_H
#include "tensorflow/core/platform/types.h"
extern "C" {
// Calculates `batch_size` topk operations with `input_size` inputs each. The
// outputs are written to `out_values` and `out_indices`.
extern void __xla_cpu_runtime_TopKF32(tensorflow::int64 batch_size,
tensorflow::int64 input_size,
tensorflow::int64 k, const float* values,
float* out_values,
tensorflow::int32* out_indices);
}
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_TOPK_H

View File

@ -44,6 +44,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_conv2d.h"
#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_fft.h"
#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h"
#include "tensorflow/compiler/xla/service/cpu/runtime_topk.h"
#include "tensorflow/compiler/xla/service/cpu/windows_compatibility.h"
#include "tensorflow/compiler/xla/service/custom_call_target_registry.h"
#include "tensorflow/compiler/xla/types.h"
@ -270,6 +271,7 @@ bool RegisterKnownJITSymbols() {
REGISTER_CPU_RUNTIME_SYMBOL(ReleaseInfeedBufferAfterDequeue);
REGISTER_CPU_RUNTIME_SYMBOL(ReleaseOutfeedBufferAfterPopulation);
REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSort);
REGISTER_CPU_RUNTIME_SYMBOL(TopKF32);
REGISTER_CPU_RUNTIME_SYMBOL(TracingStart);
REGISTER_CPU_RUNTIME_SYMBOL(TracingEnd);

View File

@ -253,6 +253,22 @@ tf_cc_test(
],
)
tf_cc_test(
name = "cpu_topk_test",
srcs = ["cpu_topk_test.cc"],
deps = [
":cpu_codegen_test",
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client/lib:sorting",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service/cpu:cpu_compiler",
"//tensorflow/compiler/xla/service/cpu:test_header_helper",
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
)
tf_cc_test(
name = "cpu_vectorization_test",
srcs = ["cpu_vectorization_test.cc"],

View File

@ -0,0 +1,59 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>
#include "tensorflow/compiler/xla/client/lib/sorting.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h"
#include "tensorflow/compiler/xla/service/cpu/test_target_triple_helper.h"
#include "tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h"
namespace xla {
namespace cpu {
namespace {
using CpuTopKTest = CpuCodegenTest;
TEST_F(CpuTopKTest, CallRuntime) {
XlaBuilder builder(TestName());
XlaOp input =
Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {5, 100}), "input");
TopK(input, 10);
TF_ASSERT_OK_AND_ASSIGN(XlaComputation xla_computation, builder.Build());
TF_ASSERT_OK_AND_ASSIGN(ProgramShape program_shape,
xla_computation.GetProgramShape());
HloModuleConfig config(program_shape);
TF_ASSERT_OK_AND_ASSIGN(
auto module, HloModule::CreateFromProto(xla_computation.proto(), config));
constexpr char filecheck_pattern[] = R"(
CHECK: call void @__xla_cpu_runtime_TopKF32(i64 5, i64 100, i64 10,
)";
CpuAotCompilationOptions options{
/*triple=*/kTargetTripleForHost, /*cpu_name=*/kTargetCpuForHost,
/*features=*/"",
/*entry_point_name=*/"entry",
/*relocation_model=*/CpuAotCompilationOptions::RelocationModel::Static};
CompileAheadOfTimeAndVerifyIr(std::move(module), options, filecheck_pattern,
/*match_optimized_ir=*/true);
}
} // namespace
} // namespace cpu
} // namespace xla