Benjamin Kramer 926642ea7a [XLA:CPU] Add a runtime function for F32 TopK and use TopkRewriter to target it
This just delegates the hard work to std::partial_sort.

PiperOrigin-RevId: 324979038
Change-Id: I16a7c4d840948f3744f4f920bac29d4e6d7333b3
2020-08-05 02:20:45 -07:00

77 lines
3.1 KiB
C++

/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/cpu/runtime_topk.h"

#include <algorithm>
#include <cstring>
#include <limits>
#include <memory>
#include <numeric>
#include <vector>

#include "tensorflow/core/platform/dynamic_annotations.h"
#include "tensorflow/core/platform/macros.h"
// Computes a per-row top-k over a row-major [batch_size, input_size] array.
//
// For each batch row, the k largest elements of `values` are written in
// descending order to `out_values`. Their positions within the row are written
// to `out_indices`. Both outputs are row-major [batch_size, k].
//
// Elements are ranked by their bit patterns rather than with operator<. This
// gives the total order
//   -NaN < -Inf < ... < -0 < +0 < ... < +Inf < +NaN
// so NaNs and signed zeros are ranked deterministically. Elements with
// identical bit patterns keep their original relative order (lower index
// first).
//
// Preconditions (not checked here): 0 <= k <= input_size, input_size fits in
// int32, `values` holds batch_size * input_size elements, and each output
// buffer holds batch_size * k elements.
template <typename T>
static void TopK(tensorflow::int64 batch_size, tensorflow::int64 input_size,
                 tensorflow::int64 k, const T* values, T* out_values,
                 tensorflow::int32* out_indices) {
  // The sort key below is built from exactly 32 bits of each element. Any
  // other element width would read too few bytes (or overrun) silently.
  static_assert(sizeof(T) == sizeof(tensorflow::int32),
                "TopK requires a 32-bit element type");

  // 'values' is managed by the JIT code, so msan can't tell they are
  // initialized.
  TF_ANNOTATE_MEMORY_IS_INITIALIZED(values,
                                    input_size * batch_size * sizeof(T));

  // Maps an element to an int32 whose signed order matches the total order
  // documented above. Values with the sign bit clear already order correctly
  // as integers. For values with the sign bit set, the 31 magnitude bits are
  // inverted, so larger magnitudes produce smaller keys. Staying in int32
  // throughout avoids the out-of-range unsigned->signed conversion, which is
  // implementation-defined before C++20.
  const auto convert_to_int = [](T value) -> tensorflow::int32 {
    tensorflow::int32 bits;
    std::memcpy(&bits, &value, sizeof(bits));
    return bits < 0 ? (bits ^ std::numeric_limits<tensorflow::int32>::max())
                    : bits;
  };

  // Scratch permutation, reused across batches to avoid reallocating.
  std::vector<tensorflow::int32> temp_indices(input_size);
  for (tensorflow::int64 batch = 0; batch != batch_size; ++batch) {
    std::iota(temp_indices.begin(), temp_indices.end(), 0);
    const T* values_batch = values + batch * input_size;

    // Only the first k positions need to end up sorted.
    auto kth_element = temp_indices.begin() + k;
    std::partial_sort(
        temp_indices.begin(), kth_element, temp_indices.end(),
        [&](tensorflow::int32 i1, tensorflow::int32 i2) {
          const tensorflow::int32 v1 = convert_to_int(values_batch[i1]);
          const tensorflow::int32 v2 = convert_to_int(values_batch[i2]);
          if (v1 == v2) {
            return i1 < i2;  // Stabilize sorting.
          }
          return v1 > v2;  // Descending: largest first.
        });

    T* out_values_batch = out_values + batch * k;
    tensorflow::int32* out_indices_batch = out_indices + batch * k;
    std::copy(temp_indices.begin(), kth_element, out_indices_batch);
    for (tensorflow::int64 i = 0; i < k; ++i) {
      out_values_batch[i] = values_batch[temp_indices[i]];
    }
  }
}
// Runtime entry point for an F32 TopK.
//
// All arguments are passed straight through to the float instantiation of
// TopK. For each of the batch_size rows of `values` (input_size elements
// each), the k largest entries are written to `out_values` and their in-row
// positions to `out_indices`. Both outputs are laid out as batch_size rows of
// k entries.
//
// MemorySanitizer is disabled for this function. NOTE(review): presumably this
// is because `values` is produced by JIT code that msan cannot see, as noted
// in TopK — confirm.
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_TopKF32(
    tensorflow::int64 batch_size, tensorflow::int64 input_size,
    tensorflow::int64 k, const float* values, float* out_values,
    tensorflow::int32* out_indices) {
  TopK<float>(batch_size, input_size, k, values, out_values, out_indices);
}