Pass a compare function to the KeyValueSort runtime function.

This is in preparation for supporting calls to an HloComputation for comparisons.

PiperOrigin-RevId: 229924424
Author: Adrian Kuegel, 2019-01-18 07:13:43 -08:00 (committed by TensorFlower Gardener)
Parent: c861dc1dcf
Commit: faad607b60
6 changed files with 139 additions and 380 deletions
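
For reference, the single generic entry point introduced by this change (declared in runtime_key_value_sort.h below); it replaces the per-element-type symbols such as __xla_cpu_runtime_KeyValueSortF32 and takes the comparator as a plain function pointer:

extern "C" void __xla_cpu_runtime_KeyValueSort(
    tensorflow::int64 a, tensorflow::int64 b, tensorflow::int64 c,
    char** values, tensorflow::int32 values_count,
    tensorflow::int32* values_primitive_type_size_in_bytes,
    bool (*less_than)(char*, char*));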

tensorflow/compiler/xla/service/cpu/cpu_runtime.cc

@@ -84,31 +84,8 @@ extern const char* const kReleaseOutfeedBufferAfterPopulationSymbolName =
"__xla_cpu_runtime_ReleaseOutfeedBufferAfterPopulation";
extern const char* const kParallelForkJoinSymbolName =
"__xla_cpu_runtime_ParallelForkJoin";
extern const char* const kKeyValueSortPREDSymbolName =
"__xla_cpu_runtime_KeyValueSortPRED";
extern const char* const kKeyValueSortS8SymbolName =
"__xla_cpu_runtime_KeyValueSortS8";
extern const char* const kKeyValueSortU8SymbolName =
"__xla_cpu_runtime_KeyValueSortU8";
extern const char* const kKeyValueSortS16SymbolName =
"__xla_cpu_runtime_KeyValueSortS16";
extern const char* const kKeyValueSortU16SymbolName =
"__xla_cpu_runtime_KeyValueSortU16";
extern const char* const kKeyValueSortF16SymbolName =
"__xla_cpu_runtime_KeyValueSortF16";
extern const char* const kKeyValueSortS32SymbolName =
"__xla_cpu_runtime_KeyValueSortS32";
extern const char* const kKeyValueSortU32SymbolName =
"__xla_cpu_runtime_KeyValueSortU32";
extern const char* const kKeyValueSortF32SymbolName =
"__xla_cpu_runtime_KeyValueSortF32";
extern const char* const kKeyValueSortS64SymbolName =
"__xla_cpu_runtime_KeyValueSortS64";
extern const char* const kKeyValueSortU64SymbolName =
"__xla_cpu_runtime_KeyValueSortU64";
extern const char* const kKeyValueSortF64SymbolName =
"__xla_cpu_runtime_KeyValueSortF64";
extern const char* const kKeyValueSortSymbolName =
"__xla_cpu_runtime_KeyValueSort";
extern const char* const kXlaCpuRuntimeSymbolNamePrefix = "__xla_cpu_runtime_";
} // namespace runtime
} // namespace cpu

tensorflow/compiler/xla/service/cpu/cpu_runtime.h

@@ -64,18 +64,7 @@ extern const char* const kReleaseInfeedBufferAfterDequeueSymbolName;
extern const char* const kAcquireOutfeedBufferForPopulationSymbolName;
extern const char* const kReleaseOutfeedBufferAfterPopulationSymbolName;
extern const char* const kParallelForkJoinSymbolName;
extern const char* const kKeyValueSortPREDSymbolName;
extern const char* const kKeyValueSortS8SymbolName;
extern const char* const kKeyValueSortU8SymbolName;
extern const char* const kKeyValueSortS16SymbolName;
extern const char* const kKeyValueSortU16SymbolName;
extern const char* const kKeyValueSortF16SymbolName;
extern const char* const kKeyValueSortS32SymbolName;
extern const char* const kKeyValueSortU32SymbolName;
extern const char* const kKeyValueSortF32SymbolName;
extern const char* const kKeyValueSortS64SymbolName;
extern const char* const kKeyValueSortU64SymbolName;
extern const char* const kKeyValueSortF64SymbolName;
extern const char* const kKeyValueSortSymbolName;
// All symbol names for XLA CPU runtime functions need to start with this
// prefix.

tensorflow/compiler/xla/service/cpu/ir_emitter.cc

@@ -495,6 +495,26 @@ Status IrEmitter::HandleSort(HloInstruction* hlo) {
const HloSortInstruction* sort = Cast<HloSortInstruction>(hlo);
TF_RETURN_IF_ERROR(EmitTargetAddressForOp(sort));
Shape keys_shape = sort->keys()->shape();
PrimitiveType keys_type = keys_shape.element_type();
switch (keys_type) {
case PRED:
case S8:
case U8:
case S16:
case U16:
case F16:
case S32:
case U32:
case F32:
case S64:
case U64:
case F64:
break;
default:
return Unimplemented(
"Element type %s not supported in the Sort op on CPU.",
PrimitiveType_Name(keys_type));
}
std::vector<llvm::Value*> destination_addresses(sort->operand_count());
for (int64 i = 0; i < sort->operand_count(); ++i) {
ShapeIndex shape_index =
@@ -542,105 +562,101 @@ Status IrEmitter::HandleSort(HloInstruction* hlo) {
lower_dimensions *= normalized_keys_shape.dimensions(i);
}
PrimitiveType keys_type = keys_shape.element_type();
const char* fn_name = nullptr;
llvm::Type* keys_native_type = nullptr;
switch (keys_type) {
case PRED:
fn_name = runtime::kKeyValueSortPREDSymbolName;
keys_native_type = b_.getInt8PtrTy();
break;
case S8:
fn_name = runtime::kKeyValueSortS8SymbolName;
keys_native_type = b_.getInt8PtrTy();
break;
case U8:
fn_name = runtime::kKeyValueSortU8SymbolName;
keys_native_type = b_.getInt8PtrTy();
break;
case S16:
fn_name = runtime::kKeyValueSortS16SymbolName;
keys_native_type = b_.getInt16Ty()->getPointerTo();
break;
case U16:
fn_name = runtime::kKeyValueSortU16SymbolName;
keys_native_type = b_.getInt16Ty()->getPointerTo();
break;
case F16:
fn_name = runtime::kKeyValueSortF16SymbolName;
keys_native_type = b_.getHalfTy()->getPointerTo();
break;
case S32:
fn_name = runtime::kKeyValueSortS32SymbolName;
keys_native_type = b_.getInt32Ty()->getPointerTo();
break;
case U32:
fn_name = runtime::kKeyValueSortU32SymbolName;
keys_native_type = b_.getInt32Ty()->getPointerTo();
break;
case F32:
fn_name = runtime::kKeyValueSortF32SymbolName;
keys_native_type = b_.getFloatTy()->getPointerTo();
break;
case S64:
fn_name = runtime::kKeyValueSortS64SymbolName;
keys_native_type = b_.getInt64Ty()->getPointerTo();
break;
case U64:
fn_name = runtime::kKeyValueSortU64SymbolName;
keys_native_type = b_.getInt64Ty()->getPointerTo();
break;
case F64:
fn_name = runtime::kKeyValueSortF64SymbolName;
keys_native_type = b_.getDoubleTy()->getPointerTo();
break;
default:
return Unimplemented(
"Element type %s not supported in the Sort op on CPU.",
PrimitiveType_Name(keys_type));
llvm::FunctionType* less_than_type = llvm::FunctionType::get(
b_.getInt1Ty(), {b_.getInt8PtrTy(), b_.getInt8PtrTy()},
/*isVarArg=*/false);
auto less_than_function = llvm_ir::CreateFunction(
less_than_type, llvm::GlobalValue::InternalLinkage,
/*enable_fast_math=*/false,
/*optimize_for_size=*/true, absl::StrCat(IrName(sort), "_comparator"),
module_);
// Emit the code for the less_than function.
{
llvm::IRBuilder<>::InsertPointGuard guard(b_);
auto* entry_bb =
llvm::BasicBlock::Create(b_.getContext(), "entry", less_than_function);
b_.SetInsertPoint(entry_bb);
auto keys_ir_type = llvm_ir::PrimitiveTypeToIrType(keys_type, module_);
CHECK_EQ(less_than_function->arg_size(), 2);
llvm::Value* keys_lhs_ptr = less_than_function->arg_begin();
keys_lhs_ptr = PointerCast(keys_lhs_ptr, keys_ir_type->getPointerTo());
llvm::Value* keys_rhs_ptr = less_than_function->arg_begin() + 1;
keys_rhs_ptr = PointerCast(keys_rhs_ptr, keys_ir_type->getPointerTo());
// TODO(b/122298745): Replace the custom compare logic with a call to the
// computation specified for the Sort op.
llvm::Value* keys_lhs = Load(keys_ir_type, keys_lhs_ptr);
llvm::Value* keys_rhs = Load(keys_ir_type, keys_rhs_ptr);
bool is_signed_comparison = true;
if (primitive_util::IsFloatingPointType(keys_type)) {
// We would like a total order of floating point numbers so that the
// sort has a predictable behavior in the presence of NaNs. Rather
// than using floating point comparison, we use the following trick:
// If f is a float, and
// x = bit_cast<int32>(f);
// y = x < 0 ? 0x7FFFFFFF - x : x;
// then y is ordered as an int32 such that finite values have the
// obvious order, -0 is ordered before 0, and -NaN and NaN appear at
// the beginning and end of the ordering.
auto k = b_.getInt(llvm::APInt::getSignedMaxValue(
keys_lhs->getType()->getPrimitiveSizeInBits()));
auto comparison_type = k->getType();
auto zero = llvm::ConstantInt::get(comparison_type, 0);
auto maybe_flip = [&](llvm::Value* v) {
return b_.CreateSelect(b_.CreateICmp(llvm::ICmpInst::ICMP_SLT, v, zero),
b_.CreateSub(k, v), v);
};
keys_lhs = b_.CreateBitCast(keys_lhs, comparison_type);
keys_rhs = b_.CreateBitCast(keys_rhs, comparison_type);
keys_lhs = maybe_flip(keys_lhs);
keys_rhs = maybe_flip(keys_rhs);
} else if (!primitive_util::IsSignedIntegralType(keys_type)) {
is_signed_comparison = false;
}
llvm::Value* result =
b_.CreateICmp(is_signed_comparison ? llvm::ICmpInst::ICMP_SLT
: llvm::ICmpInst::ICMP_ULT,
keys_lhs, keys_rhs);
llvm::ReturnInst::Create(b_.getContext(),
/*retVal=*/result, entry_bb);
}
llvm::FunctionType* key_value_sort_type = llvm::FunctionType::get(
b_.getVoidTy(),
{keys_native_type, b_.getInt64Ty(), b_.getInt64Ty(), b_.getInt64Ty(),
{b_.getInt64Ty(), b_.getInt64Ty(), b_.getInt64Ty(),
b_.getInt8PtrTy()->getPointerTo(), b_.getInt32Ty(),
b_.getInt32Ty()->getPointerTo()},
b_.getInt32Ty()->getPointerTo(), less_than_function->getType()},
/*isVarArg=*/false);
auto* key_value_sort_func = llvm::cast<llvm::Function>(
module_->getOrInsertFunction(fn_name, key_value_sort_type));
auto* key_value_sort_func =
llvm::cast<llvm::Function>(module_->getOrInsertFunction(
runtime::kKeyValueSortSymbolName, key_value_sort_type));
key_value_sort_func->setCallingConv(llvm::CallingConv::C);
key_value_sort_func->setDoesNotThrow();
llvm::Value* values;
llvm::Value* sizes;
if (sort->values_count() == 0) {
values = llvm::Constant::getNullValue(b_.getInt8PtrTy()->getPointerTo());
sizes = llvm::Constant::getNullValue(b_.getInt32Ty()->getPointerTo());
} else {
values = llvm_ir::EmitAllocaAtFunctionEntryWithCount(
b_.getInt8PtrTy(), b_.getInt32(sort->values_count()),
"cc_values_alloca", &b_);
sizes = llvm_ir::EmitAllocaAtFunctionEntryWithCount(
b_.getInt32Ty(), b_.getInt32(sort->values_count()), "cc_sizes_alloca",
&b_);
for (int64 i = 0; i < sort->values_count(); ++i) {
llvm::Value* value_as_i8ptr =
PointerCast(destination_addresses[i + 1], b_.getInt8PtrTy());
llvm::Value* slot_in_values_alloca =
ConstInBoundsGEP1_32(b_.getInt8PtrTy(), values, i);
Store(value_as_i8ptr, slot_in_values_alloca);
llvm::Value* slot_in_sizes_alloca =
ConstInBoundsGEP1_32(b_.getInt32Ty(), sizes, i);
llvm::Value* size = b_.getInt32(ShapeUtil::ByteSizeOfPrimitiveType(
sort->operand(i + 1)->shape().element_type()));
Store(size, slot_in_sizes_alloca);
}
llvm::Value* values = llvm_ir::EmitAllocaAtFunctionEntryWithCount(
b_.getInt8PtrTy(), b_.getInt32(sort->operand_count()), "cc_values_alloca",
&b_);
llvm::Value* sizes = llvm_ir::EmitAllocaAtFunctionEntryWithCount(
b_.getInt32Ty(), b_.getInt32(sort->operand_count()), "cc_sizes_alloca",
&b_);
for (int64 i = 0; i < sort->operand_count(); ++i) {
llvm::Value* value_as_i8ptr =
PointerCast(destination_addresses[i], b_.getInt8PtrTy());
llvm::Value* slot_in_values_alloca =
ConstInBoundsGEP1_32(b_.getInt8PtrTy(), values, i);
Store(value_as_i8ptr, slot_in_values_alloca);
llvm::Value* slot_in_sizes_alloca =
ConstInBoundsGEP1_32(b_.getInt32Ty(), sizes, i);
llvm::Value* size = b_.getInt32(ShapeUtil::ByteSizeOfPrimitiveType(
sort->operand(i)->shape().element_type()));
Store(size, slot_in_sizes_alloca);
}
Call(key_value_sort_func,
{PointerCast(destination_addresses[0], keys_native_type),
b_.getInt64(higher_dimensions), b_.getInt64(sort_dimension_elements),
{b_.getInt64(higher_dimensions), b_.getInt64(sort_dimension_elements),
b_.getInt64(lower_dimensions), values,
b_.getInt32(sort->values_count()), sizes});
b_.getInt32(sort->operand_count()), sizes, less_than_function});
if (sort->values_count() > 0) {
llvm_ir::EmitTuple(GetIrArrayFor(sort), destination_addresses, &b_,

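The comparator generated above orders floating-point keys with an integer bit trick rather than a floating-point compare, so the sort imposes a total order even in the presence of NaNs. A minimal standalone C++ sketch of the same trick, for illustration only (TotalOrderKey and FloatTotalOrderLessThan are hypothetical names, not XLA code):

#include <cstdint>
#include <cstring>

// Maps a float to an int32 key whose signed ordering is a total order over
// floats: -NaN first, then -Inf, negative finites, -0, +0, positive finites,
// +Inf, and +NaN last.
std::int32_t TotalOrderKey(float f) {
  std::uint32_t bits;
  std::memcpy(&bits, &f, sizeof bits);  // x = bit_cast<int32>(f)
  // For negative floats (sign bit set), reverse the order by subtracting from
  // 0x7FFFFFFF; unsigned arithmetic wraps, matching the wrapping LLVM 'sub'
  // emitted by the IR above.
  if (bits & 0x80000000u) bits = 0x7FFFFFFFu - bits;
  std::int32_t key;
  std::memcpy(&key, &bits, sizeof key);
  return key;
}

bool FloatTotalOrderLessThan(float lhs, float rhs) {
  return TotalOrderKey(lhs) < TotalOrderKey(rhs);
}
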
tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.cc

@@ -15,12 +15,10 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h"
#include <algorithm>
#include <cmath>
#include <cstring>
#include <limits>
#include <memory>
#include <numeric>
#include <string>
#include <utility>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/platform/dynamic_annotations.h"
@@ -28,80 +26,14 @@ limitations under the License.
#include "tensorflow/core/platform/types.h"
namespace {
using tensorflow::int16;
using tensorflow::int32;
using tensorflow::int64;
using tensorflow::int8;
using tensorflow::uint16;
using tensorflow::uint32;
using tensorflow::uint64;
using tensorflow::uint8;
} // namespace
template <typename KeyType>
void KeyValueSort(std::pair<KeyType, int64>* row_to_sort, int64 num_elements) {
std::sort(row_to_sort, row_to_sort + num_elements);
}
// We would like a total order of floating point numbers so that the
// sort has a predictable behavior in the presence of NaNs. Rather
// than using floating point comparison, we use the following trick:
// If f is a float, and
// x = bit_cast<int32>(f);
// y = x < 0 ? 0x7FFFFFFF - x : x;
// then y is ordered as an int32 such that finite values have the
// obvious order, -0 is ordered before 0, and -NaN and NaN appear at
// the beginning and end of the ordering.
template <typename CastType, typename UnsignedCastType, typename KeyType>
CastType Convert(KeyType value) {
CastType casted_value;
memcpy(&casted_value, &value, sizeof(CastType));
if (casted_value < 0) {
return static_cast<UnsignedCastType>(std::numeric_limits<CastType>::max()) -
casted_value;
}
return casted_value;
}
template <typename CastType, typename UnsignedCastType, typename KeyType>
bool LessThan(KeyType lhs, KeyType rhs) {
return Convert<CastType, UnsignedCastType>(lhs) <
Convert<CastType, UnsignedCastType>(rhs);
}
template <>
void KeyValueSort(std::pair<double, int64>* row_to_sort, int64 num_elements) {
std::stable_sort(row_to_sort, row_to_sort + num_elements,
[](const std::pair<double, int64>& lhs,
const std::pair<double, int64>& rhs) -> bool {
return LessThan<int64, uint64>(lhs.first, rhs.first);
});
}
template <>
void KeyValueSort(std::pair<float, int64>* row_to_sort, int64 num_elements) {
std::stable_sort(row_to_sort, row_to_sort + num_elements,
[](const std::pair<float, int64>& lhs,
const std::pair<float, int64>& rhs) -> bool {
return LessThan<int32, uint32>(lhs.first, rhs.first);
});
}
template <>
void KeyValueSort(std::pair<Eigen::half, int64>* row_to_sort,
int64 num_elements) {
std::stable_sort(row_to_sort, row_to_sort + num_elements,
[](const std::pair<Eigen::half, int64>& lhs,
const std::pair<Eigen::half, int64>& rhs) -> bool {
return LessThan<int32, uint32>(
Eigen::half_impl::half_to_float(lhs.first),
Eigen::half_impl::half_to_float(rhs.first));
});
}
template <typename KeyType>
void KeyValueSortImpl(KeyType* keys, int64 a, int64 b, int64 c, char** values,
int32 values_count,
int32* values_primitive_type_size_in_bytes) {
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSort(
int64 a, int64 b, int64 c, char** values, int32 values_count,
int32* values_primitive_type_size_in_bytes,
bool (*less_than)(char*, char*)) {
// 'values' and 'values_primitive_type_size_in_bytes' are managed by the JIT
// code, so msan can't tell they are initialized.
TF_ANNOTATE_MEMORY_IS_INITIALIZED(values, values_count * sizeof(char*));
@@ -121,8 +53,8 @@ void KeyValueSortImpl(KeyType* keys, int64 a, int64 b, int64 c, char** values,
int64 num_iteration_elements = a * c;
int64 sort_dimension_offset = c;
std::unique_ptr<std::pair<KeyType, int64>[]> row_to_sort(
new std::pair<KeyType, int64>[sort_dimension_elements]);
std::unique_ptr<int64[]> indices(new int64[sort_dimension_elements]);
std::iota(indices.get(), indices.get() + sort_dimension_elements, 0);
std::unique_ptr<std::string[]> reordered_values(
new std::string[sort_dimension_elements]);
for (int64 index = 0; index < num_iteration_elements; ++index) {
@@ -135,24 +67,22 @@ void KeyValueSortImpl(KeyType* keys, int64 a, int64 b, int64 c, char** values,
int64 base_offset =
index % sort_dimension_offset +
(index - index % sort_dimension_offset) * sort_dimension_elements;
// TODO(b/26783907): We could define a custom iterator class that references
// all arrays. Then we could avoid the intermediate copy. However this
// would become more complicated, and it is not clear if the benefit is high
// enough.
for (int64 i = 0; i < sort_dimension_elements; ++i) {
row_to_sort[i] =
std::make_pair(keys[base_offset + i * sort_dimension_offset], i);
}
KeyValueSort(row_to_sort.get(), sort_dimension_elements);
for (int64 i = 0; i < sort_dimension_elements; ++i) {
keys[base_offset + i * sort_dimension_offset] = row_to_sort[i].first;
}
std::stable_sort(
indices.get(), indices.get() + sort_dimension_elements,
[&](int64 a, int64 b) {
int64 memory_index_lhs = (base_offset + a * sort_dimension_offset) *
values_primitive_type_size_in_bytes[0];
int64 memory_index_rhs = (base_offset + b * sort_dimension_offset) *
values_primitive_type_size_in_bytes[0];
return less_than(values[0] + memory_index_lhs,
values[0] + memory_index_rhs);
});
// Reorder the values according to the order defined by the keys.
// Reorder the values according to the order defined by 'indices'.
for (int32 idx = 0; idx < values_count; ++idx) {
for (int64 i = 0; i < sort_dimension_elements; ++i) {
int64 memory_index =
(base_offset + row_to_sort[i].second * sort_dimension_offset) *
(base_offset + indices[i] * sort_dimension_offset) *
values_primitive_type_size_in_bytes[idx];
reordered_values[i] =
@@ -168,88 +98,3 @@ void KeyValueSortImpl(KeyType* keys, int64 a, int64 b, int64 c, char** values,
}
}
}
} // namespace
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortPRED(
bool* keys, int64 a, int64 b, int64 c, char** values, int32 values_count,
int32* values_primitive_type_size_in_bytes) {
KeyValueSortImpl(keys, a, b, c, values, values_count,
values_primitive_type_size_in_bytes);
}
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortS8(
int8* keys, int64 a, int64 b, int64 c, char** values, int32 values_count,
int32* values_primitive_type_size_in_bytes) {
KeyValueSortImpl(keys, a, b, c, values, values_count,
values_primitive_type_size_in_bytes);
}
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortU8(
uint8* keys, int64 a, int64 b, int64 c, char** values, int32 values_count,
int32* values_primitive_type_size_in_bytes) {
KeyValueSortImpl(keys, a, b, c, values, values_count,
values_primitive_type_size_in_bytes);
}
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortS16(
int16* keys, int64 a, int64 b, int64 c, char** values, int32 values_count,
int32* values_primitive_type_size_in_bytes) {
KeyValueSortImpl(keys, a, b, c, values, values_count,
values_primitive_type_size_in_bytes);
}
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortU16(
uint16* keys, int64 a, int64 b, int64 c, char** values, int32 values_count,
int32* values_primitive_type_size_in_bytes) {
KeyValueSortImpl(keys, a, b, c, values, values_count,
values_primitive_type_size_in_bytes);
}
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortF16(
Eigen::half* keys, int64 a, int64 b, int64 c, char** values,
int32 values_count, int32* values_primitive_type_size_in_bytes) {
KeyValueSortImpl(keys, a, b, c, values, values_count,
values_primitive_type_size_in_bytes);
}
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortS32(
int32* keys, int64 a, int64 b, int64 c, char** values, int32 values_count,
int32* values_primitive_type_size_in_bytes) {
KeyValueSortImpl(keys, a, b, c, values, values_count,
values_primitive_type_size_in_bytes);
}
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortU32(
uint32* keys, int64 a, int64 b, int64 c, char** values, int32 values_count,
int32* values_primitive_type_size_in_bytes) {
KeyValueSortImpl(keys, a, b, c, values, values_count,
values_primitive_type_size_in_bytes);
}
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortF32(
float* keys, int64 a, int64 b, int64 c, char** values, int32 values_count,
int32* values_primitive_type_size_in_bytes) {
KeyValueSortImpl(keys, a, b, c, values, values_count,
values_primitive_type_size_in_bytes);
}
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortS64(
int64* keys, int64 a, int64 b, int64 c, char** values, int32 values_count,
int32* values_primitive_type_size_in_bytes) {
KeyValueSortImpl(keys, a, b, c, values, values_count,
values_primitive_type_size_in_bytes);
}
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortU64(
uint64* keys, int64 a, int64 b, int64 c, char** values, int32 values_count,
int32* values_primitive_type_size_in_bytes) {
KeyValueSortImpl(keys, a, b, c, values, values_count,
values_primitive_type_size_in_bytes);
}
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortF64(
double* keys, int64 a, int64 b, int64 c, char** values, int32 values_count,
int32* values_primitive_type_size_in_bytes) {
KeyValueSortImpl(keys, a, b, c, values, values_count,
values_primitive_type_size_in_bytes);
}

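The rewritten runtime no longer sorts typed (key, index) pairs; it stable-sorts an index permutation using the caller-supplied comparator on the first operand and then applies that permutation to every operand. A simplified standalone sketch of the per-row scheme, for illustration only (SortRow and elem_sizes are hypothetical names, and the real runtime additionally handles the strided [a, b, c] layout via base_offset and sort_dimension_offset):

#include <algorithm>
#include <cstdint>
#include <cstring>
#include <numeric>
#include <vector>

// Sorts one contiguous row of 'n' elements: orders indices with 'less_than'
// applied to the first operand, then reorders all 'values_count' operands.
void SortRow(char** values, std::int64_t n, const std::int32_t* elem_sizes,
             std::int32_t values_count, bool (*less_than)(char*, char*)) {
  std::vector<std::int64_t> indices(n);
  std::iota(indices.begin(), indices.end(), 0);
  std::stable_sort(indices.begin(), indices.end(),
                   [&](std::int64_t a, std::int64_t b) {
                     return less_than(values[0] + a * elem_sizes[0],
                                      values[0] + b * elem_sizes[0]);
                   });
  for (std::int32_t op = 0; op < values_count; ++op) {
    std::vector<char> reordered(static_cast<std::size_t>(n) * elem_sizes[op]);
    for (std::int64_t i = 0; i < n; ++i) {
      std::memcpy(reordered.data() + i * elem_sizes[op],
                  values[op] + indices[i] * elem_sizes[op], elem_sizes[op]);
    }
    std::memcpy(values[op], reordered.data(), reordered.size());
  }
}
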
tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h

@@ -21,76 +21,19 @@ limitations under the License.
extern "C" {
// 'keys' represents a 3-dimensional shape with dimensions [a, b, c]. The 'b'
// dimension of 'keys' is sorted into ascending order. If 'values_count' is <=
// 0, 'values' and 'values_primitive_type_size_in_bytes' can be nullptr.
// If 'values_count' > 0, they contain exactly 'values_count' many elements.
// Each element of 'values' also represents a 3-dimensional shape with
// dimensions [a, b, c], and the size of the primitive type of the i-th shape
// has exactly 'values_primitive_type_size_in_bytes[i]' bytes. The elements in
// each 'values' shape are reordered in such a way that if the element at index
// 'i' in 'keys' was moved to index 'j', the element at index 'i' in a 'values'
// shape is also moved to index 'j' (which means that the same elements
// correspond to each other as before).
extern void __xla_cpu_runtime_KeyValueSortPRED(
bool* keys, tensorflow::int64 a, tensorflow::int64 b, tensorflow::int64 c,
// Each entry in 'values' represents a 3-dimensional shape with dimensions
// [a, b, c]. The 'b' dimension of the first shape is sorted into ascending
// order according to the results of comparisons using the provided 'less_than'
// function. 'values_count' must be > 0 and specifies the number of entries in
// 'values' and 'values_primitive_type_size_in_bytes'. The size of the primitive
// type of the i-th shape has exactly 'values_primitive_type_size_in_bytes[i]'
// bytes. The elements in each 'values' shape are reordered in the same way
// according to the comparisons using the first shape.
extern void __xla_cpu_runtime_KeyValueSort(
tensorflow::int64 a, tensorflow::int64 b, tensorflow::int64 c,
char** values, tensorflow::int32 values_count,
tensorflow::int32* values_primitive_type_size_in_bytes);
extern void __xla_cpu_runtime_KeyValueSortS8(
tensorflow::int8* keys, tensorflow::int64 a, tensorflow::int64 b,
tensorflow::int64 c, char** values, tensorflow::int32 values_count,
tensorflow::int32* values_primitive_type_size_in_bytes);
extern void __xla_cpu_runtime_KeyValueSortU8(
tensorflow::uint8* keys, tensorflow::int64 a, tensorflow::int64 b,
tensorflow::int64 c, char** values, tensorflow::int32 values_count,
tensorflow::int32* values_primitive_type_size_in_bytes);
extern void __xla_cpu_runtime_KeyValueSortS16(
tensorflow::int16* keys, tensorflow::int64 a, tensorflow::int64 b,
tensorflow::int64 c, char** values, tensorflow::int32 values_count,
tensorflow::int32* values_primitive_type_size_in_bytes);
extern void __xla_cpu_runtime_KeyValueSortU16(
tensorflow::uint16* keys, tensorflow::int64 a, tensorflow::int64 b,
tensorflow::int64 c, char** values, tensorflow::int32 values_count,
tensorflow::int32* values_primitive_type_size_in_bytes);
extern void __xla_cpu_runtime_KeyValueSortF16(
Eigen::half* keys, tensorflow::int64 a, tensorflow::int64 b,
tensorflow::int64 c, char** values, tensorflow::int32 values_count,
tensorflow::int32* values_primitive_type_size_in_bytes);
extern void __xla_cpu_runtime_KeyValueSortS32(
tensorflow::int32* keys, tensorflow::int64 a, tensorflow::int64 b,
tensorflow::int64 c, char** values, tensorflow::int32 values_count,
tensorflow::int32* values_primitive_type_size_in_bytes);
extern void __xla_cpu_runtime_KeyValueSortU32(
tensorflow::uint32* keys, tensorflow::int64 a, tensorflow::int64 b,
tensorflow::int64 c, char** values, tensorflow::int32 values_count,
tensorflow::int32* values_primitive_type_size_in_bytes);
extern void __xla_cpu_runtime_KeyValueSortF32(
float* keys, tensorflow::int64 a, tensorflow::int64 b, tensorflow::int64 c,
char** values, tensorflow::int32 values_count,
tensorflow::int32* values_primitive_type_size_in_bytes);
extern void __xla_cpu_runtime_KeyValueSortS64(
tensorflow::int64* keys, tensorflow::int64 a, tensorflow::int64 b,
tensorflow::int64 c, char** values, tensorflow::int32 values_count,
tensorflow::int32* values_primitive_type_size_in_bytes);
extern void __xla_cpu_runtime_KeyValueSortU64(
tensorflow::uint64* keys, tensorflow::int64 a, tensorflow::int64 b,
tensorflow::int64 c, char** values, tensorflow::int32 values_count,
tensorflow::int32* values_primitive_type_size_in_bytes);
extern void __xla_cpu_runtime_KeyValueSortF64(
double* keys, tensorflow::int64 a, tensorflow::int64 b, tensorflow::int64 c,
char** values, tensorflow::int32 values_count,
tensorflow::int32* values_primitive_type_size_in_bytes);
tensorflow::int32* values_primitive_type_size_in_bytes,
bool (*less_than)(char*, char*));
}
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_KEY_VALUE_SORT_H_

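A hypothetical caller-side sketch of the declaration above, assuming the XLA CPU runtime is linked in; the fixed-width re-declaration, FloatAscending, and SortKeysAndValues are illustrative and not part of this change:

#include <cstdint>
#include <cstring>

// Illustrative re-declaration using fixed-width types that match
// tensorflow::int64 / tensorflow::int32 in size on typical platforms.
extern "C" void __xla_cpu_runtime_KeyValueSort(
    std::int64_t a, std::int64_t b, std::int64_t c, char** values,
    std::int32_t values_count,
    std::int32_t* values_primitive_type_size_in_bytes,
    bool (*less_than)(char*, char*));

// Plain ascending comparator on float bytes; the IR emitter's generated
// comparator additionally imposes a total order over NaNs.
static bool FloatAscending(char* lhs, char* rhs) {
  float l, r;
  std::memcpy(&l, lhs, sizeof l);
  std::memcpy(&r, rhs, sizeof r);
  return l < r;
}

// Sorts the 'b' dimension of an [a, b, c]-shaped float buffer in place and
// reorders a parallel int32 payload buffer of the same shape accordingly.
void SortKeysAndValues(float* keys, std::int32_t* payload, std::int64_t a,
                       std::int64_t b, std::int64_t c) {
  char* operands[] = {reinterpret_cast<char*>(keys),
                      reinterpret_cast<char*>(payload)};
  std::int32_t sizes[] = {static_cast<std::int32_t>(sizeof(float)),
                          static_cast<std::int32_t>(sizeof(std::int32_t))};
  __xla_cpu_runtime_KeyValueSort(a, b, c, operands, /*values_count=*/2, sizes,
                                 &FloatAscending);
}
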
tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc

@@ -240,18 +240,7 @@ bool RegisterKnownJITSymbols() {
REGISTER_CPU_RUNTIME_SYMBOL(ParallelForkJoin);
REGISTER_CPU_RUNTIME_SYMBOL(ReleaseInfeedBufferAfterDequeue);
REGISTER_CPU_RUNTIME_SYMBOL(ReleaseOutfeedBufferAfterPopulation);
REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortPRED);
REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortS8);
REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortU8);
REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortS16);
REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortU16);
REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortF16);
REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortS32);
REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortU32);
REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortF32);
REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortS64);
REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortU64);
REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortF64);
REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSort);
registry->Register("__gnu_f2h_ieee", reinterpret_cast<void*>(__gnu_f2h_ieee));
registry->Register("__gnu_h2f_ieee", reinterpret_cast<void*>(__gnu_h2f_ieee));