[XLA:CPU] Add a runtime function for F32 TopK and use TopkRewriter to target it
This just delegates the hard work to std::partial_sort. PiperOrigin-RevId: 324979038 Change-Id: I16a7c4d840948f3744f4f920bac29d4e6d7333b3
This commit is contained in:
parent
f576f29e3d
commit
926642ea7a
@ -129,42 +129,35 @@ class XlaSortOpTest(xla_test.XLATestCase):
|
||||
|
||||
def testTopKZeros(self):
|
||||
"""Tests that positive and negative zeros sort correctly."""
|
||||
# Only bfloat16 is implemented.
|
||||
bfloat16 = dtypes.bfloat16.as_numpy_dtype
|
||||
if bfloat16 not in self.numeric_types:
|
||||
return
|
||||
|
||||
with self.session() as sess:
|
||||
p = array_ops.placeholder(dtypes.bfloat16)
|
||||
with self.test_scope():
|
||||
topk = nn_ops.top_k(p, k=4)
|
||||
results = sess.run(
|
||||
topk,
|
||||
{p: np.array([0., -0., 0., 3., -0., -4., 0., -0.], dtype=bfloat16)})
|
||||
self.assertAllEqual(
|
||||
np.array([3., 0., 0., 0.], dtype=bfloat16), results[0])
|
||||
self.assertEqual(list([3, 0, 2, 6]), list(results[1]))
|
||||
supported_types = set([dtypes.bfloat16.as_numpy_dtype, np.float32])
|
||||
for dtype in supported_types.intersection(self.numeric_types):
|
||||
with self.session() as sess:
|
||||
p = array_ops.placeholder(dtype)
|
||||
with self.test_scope():
|
||||
topk = nn_ops.top_k(p, k=4)
|
||||
results = sess.run(
|
||||
topk,
|
||||
{p: np.array([0., -0., 0., 3., -0., -4., 0., -0.], dtype=dtype)})
|
||||
self.assertAllEqual(np.array([3., 0., 0., 0.], dtype=dtype), results[0])
|
||||
self.assertEqual(list([3, 0, 2, 6]), list(results[1]))
|
||||
|
||||
def testTopKInfinities(self):
|
||||
"""Tests that positive and negative infinity sort correctly."""
|
||||
# Only bfloat16 is implemented.
|
||||
bfloat16 = dtypes.bfloat16.as_numpy_dtype
|
||||
if bfloat16 not in self.numeric_types:
|
||||
return
|
||||
|
||||
with self.session() as sess:
|
||||
p = array_ops.placeholder(dtypes.bfloat16)
|
||||
with self.test_scope():
|
||||
topk = nn_ops.top_k(p, k=6)
|
||||
results = sess.run(topk, {
|
||||
p: np.array(
|
||||
[1, 2, float("inf"), -float("inf"), -1, -2], dtype=bfloat16)
|
||||
})
|
||||
self.assertAllEqual(
|
||||
np.array(
|
||||
[float("inf"), 2.0, 1.0, -1.0, -2.0, -float("inf")],
|
||||
dtype=bfloat16), results[0])
|
||||
self.assertEqual(list([2, 1, 0, 4, 5, 3]), list(results[1]))
|
||||
supported_types = set([dtypes.bfloat16.as_numpy_dtype, np.float32])
|
||||
for dtype in supported_types.intersection(self.numeric_types):
|
||||
with self.session() as sess:
|
||||
p = array_ops.placeholder(dtype)
|
||||
with self.test_scope():
|
||||
topk = nn_ops.top_k(p, k=6)
|
||||
results = sess.run(topk, {
|
||||
p:
|
||||
np.array([1, 2, float("inf"), -float("inf"), -1, -2],
|
||||
dtype=dtype)
|
||||
})
|
||||
self.assertAllEqual(
|
||||
np.array([float("inf"), 2.0, 1.0, -1.0, -2.0, -float("inf")],
|
||||
dtype=dtype), results[0])
|
||||
self.assertEqual(list([2, 1, 0, 4, 5, 3]), list(results[1]))
|
||||
|
||||
def testInTopK(self):
|
||||
supported_types = set([np.int32, np.int64])
|
||||
|
@ -49,6 +49,7 @@ filegroup(
|
||||
"runtime_single_threaded_conv2d.cc",
|
||||
"runtime_single_threaded_fft.cc",
|
||||
"runtime_single_threaded_matmul.cc",
|
||||
"runtime_topk.cc",
|
||||
],
|
||||
visibility = [":friends"],
|
||||
)
|
||||
@ -64,6 +65,7 @@ filegroup(
|
||||
"runtime_single_threaded_conv2d.h",
|
||||
"runtime_single_threaded_fft.h",
|
||||
"runtime_single_threaded_matmul.h",
|
||||
"runtime_topk.h",
|
||||
],
|
||||
visibility = [":friends"],
|
||||
)
|
||||
@ -134,6 +136,7 @@ cc_library(
|
||||
"//tensorflow/compiler/xla/service:copy_insertion",
|
||||
"//tensorflow/compiler/xla/service:hlo_casting_utils",
|
||||
"//tensorflow/compiler/xla/service:dump",
|
||||
"//tensorflow/compiler/xla/service:topk_rewriter",
|
||||
"//tensorflow/compiler/xla/service:map_inliner",
|
||||
"//tensorflow/compiler/xla/service:rng_bit_generator_expander",
|
||||
"//tensorflow/compiler/xla/service:tree_reduction_rewriter",
|
||||
@ -230,6 +233,7 @@ cc_library(
|
||||
":runtime_fft",
|
||||
":runtime_fork_join",
|
||||
":runtime_key_value_sort",
|
||||
":runtime_topk",
|
||||
":runtime_matmul",
|
||||
":runtime_matmul_mkl",
|
||||
":runtime_single_threaded_conv2d",
|
||||
@ -759,6 +763,19 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "runtime_topk",
|
||||
srcs = ["runtime_topk.cc"],
|
||||
hdrs = ["runtime_topk.h"],
|
||||
copts = runtime_copts(),
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//tensorflow/core/platform:dynamic_annotations",
|
||||
"//tensorflow/core/platform:macros",
|
||||
"//tensorflow/core/platform:types",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "runtime_fork_join",
|
||||
srcs = ["runtime_fork_join.cc"],
|
||||
|
@ -104,6 +104,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/slice_sinker.h"
|
||||
#include "tensorflow/compiler/xla/service/slow_operation_alarm.h"
|
||||
#include "tensorflow/compiler/xla/service/sort_simplifier.h"
|
||||
#include "tensorflow/compiler/xla/service/topk_rewriter.h"
|
||||
#include "tensorflow/compiler/xla/service/transpose_folding.h"
|
||||
#include "tensorflow/compiler/xla/service/tree_reduction_rewriter.h"
|
||||
#include "tensorflow/compiler/xla/service/triangular_solve_expander.h"
|
||||
@ -320,6 +321,9 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn(
|
||||
pass.AddPass<HloConstantFolding>();
|
||||
pass.AddPass<ConditionalSimplifier>();
|
||||
}
|
||||
pipeline.AddPass<TopkRewriter>([](const HloSortInstruction* sort, int64) {
|
||||
return sort->operand(0)->shape().element_type() == F32;
|
||||
});
|
||||
pipeline.AddPass<IndexedArrayAnalysisPrinterPass>();
|
||||
pipeline.AddPass<TransposeFolding>(
|
||||
[&](const HloInstruction& dot,
|
||||
|
@ -117,6 +117,7 @@ extern const char* const kParallelForkJoinSymbolName =
|
||||
"__xla_cpu_runtime_ParallelForkJoin";
|
||||
extern const char* const kKeyValueSortSymbolName =
|
||||
"__xla_cpu_runtime_KeyValueSort";
|
||||
extern const char* const kTopKF32SymbolName = "__xla_cpu_runtime_TopKF32";
|
||||
extern const char* const kTracingStartSymbolName =
|
||||
"__xla_cpu_runtime_TracingStart";
|
||||
extern const char* const kTracingEndSymbolName = "__xla_cpu_runtime_TracingEnd";
|
||||
|
@ -72,6 +72,7 @@ extern const char* const kAcquireOutfeedBufferForPopulationSymbolName;
|
||||
extern const char* const kReleaseOutfeedBufferAfterPopulationSymbolName;
|
||||
extern const char* const kParallelForkJoinSymbolName;
|
||||
extern const char* const kKeyValueSortSymbolName;
|
||||
extern const char* const kTopKF32SymbolName;
|
||||
extern const char* const kAllReduceSymbolName;
|
||||
extern const char* const kCollectivePermuteSymbolName;
|
||||
extern const char* const kReplicaIdSymbolName;
|
||||
|
@ -2387,6 +2387,41 @@ Status IrEmitter::HandlePadToStatic(HloInstruction* hlo) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status IrEmitter::HandleTopK(HloInstruction* hlo) {
|
||||
TF_RETURN_IF_ERROR(EmitTargetAddressForOp(hlo));
|
||||
const HloInstruction* input = hlo->operand(0);
|
||||
int64 k = hlo->shape().tuple_shapes(0).dimensions(1);
|
||||
TF_RET_CHECK(input->shape().element_type() == F32);
|
||||
TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major(
|
||||
hlo->shape().tuple_shapes(0).layout()));
|
||||
TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major(
|
||||
hlo->shape().tuple_shapes(1).layout()));
|
||||
TF_RET_CHECK(
|
||||
LayoutUtil::IsMonotonicWithDim0Major(hlo->operand(0)->shape().layout()));
|
||||
|
||||
TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice values_slice,
|
||||
assignment_.GetUniqueSlice(hlo->operand(0), {}));
|
||||
TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice out_values_slice,
|
||||
assignment_.GetUniqueSlice(hlo, {0}));
|
||||
TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice out_indices_slice,
|
||||
assignment_.GetUniqueSlice(hlo, {1}));
|
||||
llvm::Value* values_ptr =
|
||||
EmitBufferPointer(values_slice, hlo->operand(0)->shape());
|
||||
llvm::Value* out_values_ptr =
|
||||
EmitBufferPointer(out_values_slice, hlo->shape().tuple_shapes(0));
|
||||
llvm::Value* out_indices_ptr =
|
||||
EmitBufferPointer(out_indices_slice, hlo->shape().tuple_shapes(1));
|
||||
EmitCallToFunc(runtime::kTopKF32SymbolName,
|
||||
{b_.getInt64(input->shape().dimensions(0)),
|
||||
b_.getInt64(input->shape().dimensions(1)), b_.getInt64(k),
|
||||
values_ptr, out_values_ptr, out_indices_ptr},
|
||||
b_.getVoidTy());
|
||||
|
||||
llvm_ir::EmitTuple(GetIrArrayFor(hlo), {out_values_ptr, out_indices_ptr},
|
||||
&b_);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status IrEmitter::HandleCustomCall(HloInstruction* custom_call) {
|
||||
if (custom_call->custom_call_target() == "PadToStatic") {
|
||||
return HandlePadToStatic(custom_call);
|
||||
@ -2394,6 +2429,9 @@ Status IrEmitter::HandleCustomCall(HloInstruction* custom_call) {
|
||||
if (custom_call->custom_call_target() == "SliceToDynamic") {
|
||||
return HandleSliceToDynamic(custom_call);
|
||||
}
|
||||
if (custom_call->custom_call_target() == "TopK") {
|
||||
return HandleTopK(custom_call);
|
||||
}
|
||||
absl::Span<HloInstruction* const> operands(custom_call->operands());
|
||||
llvm::Type* i8_ptr_type = b_.getInt8PtrTy();
|
||||
llvm::AllocaInst* operands_alloca =
|
||||
|
@ -190,6 +190,7 @@ class IrEmitter : public DfsHloVisitorWithDefault,
|
||||
private:
|
||||
Status HandleSliceToDynamic(HloInstruction* hlo);
|
||||
Status HandlePadToStatic(HloInstruction* hlo);
|
||||
Status HandleTopK(HloInstruction* hlo);
|
||||
Status HandleAllReduceSingleReplica(HloInstruction* crs);
|
||||
Status HandleAllReduceMultipleReplica(HloInstruction* crs);
|
||||
|
||||
|
76
tensorflow/compiler/xla/service/cpu/runtime_topk.cc
Normal file
76
tensorflow/compiler/xla/service/cpu/runtime_topk.cc
Normal file
@ -0,0 +1,76 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/service/cpu/runtime_topk.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <numeric>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/platform/dynamic_annotations.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
|
||||
template <typename T>
|
||||
static void TopK(tensorflow::int64 batch_size, tensorflow::int64 input_size,
|
||||
tensorflow::int64 k, const T* values, T* out_values,
|
||||
tensorflow::int32* out_indices) {
|
||||
// 'values' is managed by the JIT code, so msan can't tell they are
|
||||
// initialized.
|
||||
TF_ANNOTATE_MEMORY_IS_INITIALIZED(values,
|
||||
input_size * batch_size * sizeof(T));
|
||||
|
||||
std::vector<tensorflow::int32> temp_indices(input_size);
|
||||
for (tensorflow::int64 batch = 0; batch != batch_size; ++batch) {
|
||||
std::iota(temp_indices.begin(), temp_indices.end(), 0);
|
||||
|
||||
const T* values_batch = values + batch * input_size;
|
||||
|
||||
auto convert_to_int = [](T value) {
|
||||
tensorflow::uint32 x;
|
||||
std::memcpy(&x, &value, sizeof(x));
|
||||
return static_cast<tensorflow::int32>(x) < 0
|
||||
? std::numeric_limits<tensorflow::int32>::max() - x
|
||||
: x;
|
||||
};
|
||||
|
||||
auto kth_element = temp_indices.begin() + k;
|
||||
std::partial_sort(temp_indices.begin(), kth_element, temp_indices.end(),
|
||||
[&](size_t i1, size_t i2) {
|
||||
// Do the comparison in integers to enforce a total
|
||||
// order of -NaN < -Inf < -0 < +0 < +Inf < +NaN.
|
||||
tensorflow::int32 v1 = convert_to_int(values_batch[i1]);
|
||||
tensorflow::int32 v2 = convert_to_int(values_batch[i2]);
|
||||
if (v1 == v2) {
|
||||
return i1 < i2; // Stabilize sorting.
|
||||
}
|
||||
return v1 > v2;
|
||||
});
|
||||
|
||||
T* out_values_batch = out_values + batch * k;
|
||||
tensorflow::int32* out_indices_batch = out_indices + batch * k;
|
||||
std::copy(temp_indices.begin(), kth_element, out_indices_batch);
|
||||
for (tensorflow::int64 i = 0; i < k; i++) {
|
||||
out_values_batch[i] = values_batch[temp_indices[i]];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_TopKF32(
|
||||
tensorflow::int64 batch_size, tensorflow::int64 input_size,
|
||||
tensorflow::int64 k, const float* values, float* out_values,
|
||||
tensorflow::int32* out_indices) {
|
||||
TopK(batch_size, input_size, k, values, out_values, out_indices);
|
||||
}
|
32
tensorflow/compiler/xla/service/cpu/runtime_topk.h
Normal file
32
tensorflow/compiler/xla/service/cpu/runtime_topk.h
Normal file
@ -0,0 +1,32 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_TOPK_H
|
||||
#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_TOPK_H
|
||||
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
extern "C" {
|
||||
|
||||
// Calculates `batch_size` topk operations with `input_size` inputs each. The
|
||||
// outputs are written to `out_values` and `out_indices`.
|
||||
extern void __xla_cpu_runtime_TopKF32(tensorflow::int64 batch_size,
|
||||
tensorflow::int64 input_size,
|
||||
tensorflow::int64 k, const float* values,
|
||||
float* out_values,
|
||||
tensorflow::int32* out_indices);
|
||||
}
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_TOPK_H
|
@ -44,6 +44,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_conv2d.h"
|
||||
#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_fft.h"
|
||||
#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h"
|
||||
#include "tensorflow/compiler/xla/service/cpu/runtime_topk.h"
|
||||
#include "tensorflow/compiler/xla/service/cpu/windows_compatibility.h"
|
||||
#include "tensorflow/compiler/xla/service/custom_call_target_registry.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
@ -270,6 +271,7 @@ bool RegisterKnownJITSymbols() {
|
||||
REGISTER_CPU_RUNTIME_SYMBOL(ReleaseInfeedBufferAfterDequeue);
|
||||
REGISTER_CPU_RUNTIME_SYMBOL(ReleaseOutfeedBufferAfterPopulation);
|
||||
REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSort);
|
||||
REGISTER_CPU_RUNTIME_SYMBOL(TopKF32);
|
||||
REGISTER_CPU_RUNTIME_SYMBOL(TracingStart);
|
||||
REGISTER_CPU_RUNTIME_SYMBOL(TracingEnd);
|
||||
|
||||
|
@ -253,6 +253,22 @@ tf_cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "cpu_topk_test",
|
||||
srcs = ["cpu_topk_test.cc"],
|
||||
deps = [
|
||||
":cpu_codegen_test",
|
||||
"//tensorflow/compiler/xla/client:xla_builder",
|
||||
"//tensorflow/compiler/xla/client/lib:sorting",
|
||||
"//tensorflow/compiler/xla/service:hlo",
|
||||
"//tensorflow/compiler/xla/service/cpu:cpu_compiler",
|
||||
"//tensorflow/compiler/xla/service/cpu:test_header_helper",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "cpu_vectorization_test",
|
||||
srcs = ["cpu_vectorization_test.cc"],
|
||||
|
59
tensorflow/compiler/xla/service/cpu/tests/cpu_topk_test.cc
Normal file
59
tensorflow/compiler/xla/service/cpu/tests/cpu_topk_test.cc
Normal file
@ -0,0 +1,59 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/compiler/xla/client/lib/sorting.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
||||
#include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h"
|
||||
#include "tensorflow/compiler/xla/service/cpu/test_target_triple_helper.h"
|
||||
#include "tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h"
|
||||
|
||||
namespace xla {
|
||||
namespace cpu {
|
||||
namespace {
|
||||
|
||||
using CpuTopKTest = CpuCodegenTest;
|
||||
|
||||
TEST_F(CpuTopKTest, CallRuntime) {
|
||||
XlaBuilder builder(TestName());
|
||||
XlaOp input =
|
||||
Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {5, 100}), "input");
|
||||
TopK(input, 10);
|
||||
TF_ASSERT_OK_AND_ASSIGN(XlaComputation xla_computation, builder.Build());
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(ProgramShape program_shape,
|
||||
xla_computation.GetProgramShape());
|
||||
HloModuleConfig config(program_shape);
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
auto module, HloModule::CreateFromProto(xla_computation.proto(), config));
|
||||
|
||||
constexpr char filecheck_pattern[] = R"(
|
||||
CHECK: call void @__xla_cpu_runtime_TopKF32(i64 5, i64 100, i64 10,
|
||||
)";
|
||||
|
||||
CpuAotCompilationOptions options{
|
||||
/*triple=*/kTargetTripleForHost, /*cpu_name=*/kTargetCpuForHost,
|
||||
/*features=*/"",
|
||||
/*entry_point_name=*/"entry",
|
||||
/*relocation_model=*/CpuAotCompilationOptions::RelocationModel::Static};
|
||||
|
||||
CompileAheadOfTimeAndVerifyIr(std::move(module), options, filecheck_pattern,
|
||||
/*match_optimized_ir=*/true);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace cpu
|
||||
} // namespace xla
|
Loading…
Reference in New Issue
Block a user