Pass a compare function to the KeyValueSort runtime function.

This is in preparation for supporting calling an HloComputation for comparisons.

PiperOrigin-RevId: 229924424
This commit is contained in:
parent c861dc1dcf
commit faad607b60
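The change below replaces twelve type-specialized runtime sort symbols with a single __xla_cpu_runtime_KeyValueSort entry point that takes an opaque bool (*less_than)(char*, char*) comparator: the emitted comparator (eventually a compiled HloComputation) decides the order, and the runtime only moves bytes. A minimal standalone sketch of that pattern, with hypothetical names (SortRows, F32LessThan) rather than the actual XLA code:

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <numeric>
#include <vector>

// Sorts 'num_elements' fixed-size elements in place. The sort routine never
// interprets the bytes itself; the ordering comes entirely from the caller's
// comparator, which is the essence of the new runtime interface.
void SortRows(char* data, int64_t num_elements, int64_t element_size,
              bool (*less_than)(char*, char*)) {
  // Sort an index array rather than the data, mirroring the
  // std::stable_sort-on-indices approach the patch introduces in
  // runtime_key_value_sort.cc, then apply the permutation once.
  std::vector<int64_t> indices(num_elements);
  std::iota(indices.begin(), indices.end(), 0);
  std::stable_sort(indices.begin(), indices.end(), [&](int64_t a, int64_t b) {
    return less_than(data + a * element_size, data + b * element_size);
  });
  std::vector<char> reordered(num_elements * element_size);
  for (int64_t i = 0; i < num_elements; ++i) {
    std::memcpy(&reordered[i * element_size], data + indices[i] * element_size,
                element_size);
  }
  std::memcpy(data, reordered.data(), reordered.size());
}

// A type-aware comparator; in the patch this role is played by a small
// LLVM IR function emitted per Sort op.
bool F32LessThan(char* lhs, char* rhs) {
  float l, r;
  std::memcpy(&l, lhs, sizeof(float));
  std::memcpy(&r, rhs, sizeof(float));
  return l < r;
}

int main() {
  float keys[] = {3.0f, 1.0f, 2.0f};
  SortRows(reinterpret_cast<char*>(keys), 3, sizeof(float), F32LessThan);
  std::printf("%g %g %g\n", keys[0], keys[1], keys[2]);  // prints: 1 2 3
}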
@@ -84,31 +84,8 @@ extern const char* const kReleaseOutfeedBufferAfterPopulationSymbolName =
     "__xla_cpu_runtime_ReleaseOutfeedBufferAfterPopulation";
 extern const char* const kParallelForkJoinSymbolName =
     "__xla_cpu_runtime_ParallelForkJoin";
-extern const char* const kKeyValueSortPREDSymbolName =
-    "__xla_cpu_runtime_KeyValueSortPRED";
-extern const char* const kKeyValueSortS8SymbolName =
-    "__xla_cpu_runtime_KeyValueSortS8";
-extern const char* const kKeyValueSortU8SymbolName =
-    "__xla_cpu_runtime_KeyValueSortU8";
-extern const char* const kKeyValueSortS16SymbolName =
-    "__xla_cpu_runtime_KeyValueSortS16";
-extern const char* const kKeyValueSortU16SymbolName =
-    "__xla_cpu_runtime_KeyValueSortU16";
-extern const char* const kKeyValueSortF16SymbolName =
-    "__xla_cpu_runtime_KeyValueSortF16";
-extern const char* const kKeyValueSortS32SymbolName =
-    "__xla_cpu_runtime_KeyValueSortS32";
-extern const char* const kKeyValueSortU32SymbolName =
-    "__xla_cpu_runtime_KeyValueSortU32";
-extern const char* const kKeyValueSortF32SymbolName =
-    "__xla_cpu_runtime_KeyValueSortF32";
-extern const char* const kKeyValueSortS64SymbolName =
-    "__xla_cpu_runtime_KeyValueSortS64";
-extern const char* const kKeyValueSortU64SymbolName =
-    "__xla_cpu_runtime_KeyValueSortU64";
-extern const char* const kKeyValueSortF64SymbolName =
-    "__xla_cpu_runtime_KeyValueSortF64";
-
+extern const char* const kKeyValueSortSymbolName =
+    "__xla_cpu_runtime_KeyValueSort";
 extern const char* const kXlaCpuRuntimeSymbolNamePrefix = "__xla_cpu_runtime_";
 } // namespace runtime
 } // namespace cpu
@@ -64,18 +64,7 @@ extern const char* const kReleaseInfeedBufferAfterDequeueSymbolName;
 extern const char* const kAcquireOutfeedBufferForPopulationSymbolName;
 extern const char* const kReleaseOutfeedBufferAfterPopulationSymbolName;
 extern const char* const kParallelForkJoinSymbolName;
-extern const char* const kKeyValueSortPREDSymbolName;
-extern const char* const kKeyValueSortS8SymbolName;
-extern const char* const kKeyValueSortU8SymbolName;
-extern const char* const kKeyValueSortS16SymbolName;
-extern const char* const kKeyValueSortU16SymbolName;
-extern const char* const kKeyValueSortF16SymbolName;
-extern const char* const kKeyValueSortS32SymbolName;
-extern const char* const kKeyValueSortU32SymbolName;
-extern const char* const kKeyValueSortF32SymbolName;
-extern const char* const kKeyValueSortS64SymbolName;
-extern const char* const kKeyValueSortU64SymbolName;
-extern const char* const kKeyValueSortF64SymbolName;
+extern const char* const kKeyValueSortSymbolName;

 // All symbol names for XLA CPU runtime functions need to start with this
 // prefix.
@@ -495,6 +495,26 @@ Status IrEmitter::HandleSort(HloInstruction* hlo) {
   const HloSortInstruction* sort = Cast<HloSortInstruction>(hlo);
   TF_RETURN_IF_ERROR(EmitTargetAddressForOp(sort));
   Shape keys_shape = sort->keys()->shape();
+  PrimitiveType keys_type = keys_shape.element_type();
+  switch (keys_type) {
+    case PRED:
+    case S8:
+    case U8:
+    case S16:
+    case U16:
+    case F16:
+    case S32:
+    case U32:
+    case F32:
+    case S64:
+    case U64:
+    case F64:
+      break;
+    default:
+      return Unimplemented(
+          "Element type %s not supported in the Sort op on CPU.",
+          PrimitiveType_Name(keys_type));
+  }
   std::vector<llvm::Value*> destination_addresses(sort->operand_count());
   for (int64 i = 0; i < sort->operand_count(); ++i) {
     ShapeIndex shape_index =
@@ -542,105 +562,101 @@ Status IrEmitter::HandleSort(HloInstruction* hlo) {
     lower_dimensions *= normalized_keys_shape.dimensions(i);
   }

-  PrimitiveType keys_type = keys_shape.element_type();
-  const char* fn_name = nullptr;
-  llvm::Type* keys_native_type = nullptr;
-  switch (keys_type) {
-    case PRED:
-      fn_name = runtime::kKeyValueSortPREDSymbolName;
-      keys_native_type = b_.getInt8PtrTy();
-      break;
-    case S8:
-      fn_name = runtime::kKeyValueSortS8SymbolName;
-      keys_native_type = b_.getInt8PtrTy();
-      break;
-    case U8:
-      fn_name = runtime::kKeyValueSortU8SymbolName;
-      keys_native_type = b_.getInt8PtrTy();
-      break;
-    case S16:
-      fn_name = runtime::kKeyValueSortS16SymbolName;
-      keys_native_type = b_.getInt16Ty()->getPointerTo();
-      break;
-    case U16:
-      fn_name = runtime::kKeyValueSortU16SymbolName;
-      keys_native_type = b_.getInt16Ty()->getPointerTo();
-      break;
-    case F16:
-      fn_name = runtime::kKeyValueSortF16SymbolName;
-      keys_native_type = b_.getHalfTy()->getPointerTo();
-      break;
-    case S32:
-      fn_name = runtime::kKeyValueSortS32SymbolName;
-      keys_native_type = b_.getInt32Ty()->getPointerTo();
-      break;
-    case U32:
-      fn_name = runtime::kKeyValueSortU32SymbolName;
-      keys_native_type = b_.getInt32Ty()->getPointerTo();
-      break;
-    case F32:
-      fn_name = runtime::kKeyValueSortF32SymbolName;
-      keys_native_type = b_.getFloatTy()->getPointerTo();
-      break;
-    case S64:
-      fn_name = runtime::kKeyValueSortS64SymbolName;
-      keys_native_type = b_.getInt64Ty()->getPointerTo();
-      break;
-    case U64:
-      fn_name = runtime::kKeyValueSortU64SymbolName;
-      keys_native_type = b_.getInt64Ty()->getPointerTo();
-      break;
-    case F64:
-      fn_name = runtime::kKeyValueSortF64SymbolName;
-      keys_native_type = b_.getDoubleTy()->getPointerTo();
-      break;
-    default:
-      return Unimplemented(
-          "Element type %s not supported in the Sort op on CPU.",
-          PrimitiveType_Name(keys_type));
-  }
+  llvm::FunctionType* less_than_type = llvm::FunctionType::get(
+      b_.getInt1Ty(), {b_.getInt8PtrTy(), b_.getInt8PtrTy()},
+      /*isVarArg=*/false);
+  auto less_than_function = llvm_ir::CreateFunction(
+      less_than_type, llvm::GlobalValue::InternalLinkage,
+      /*enable_fast_math=*/false,
+      /*optimize_for_size=*/true, absl::StrCat(IrName(sort), "_comparator"),
+      module_);
+  // Emit the code for the less_than function.
+  {
+    llvm::IRBuilder<>::InsertPointGuard guard(b_);
+
+    auto* entry_bb =
+        llvm::BasicBlock::Create(b_.getContext(), "entry", less_than_function);
+
+    b_.SetInsertPoint(entry_bb);
+    auto keys_ir_type = llvm_ir::PrimitiveTypeToIrType(keys_type, module_);
+    CHECK_EQ(less_than_function->arg_size(), 2);
+    llvm::Value* keys_lhs_ptr = less_than_function->arg_begin();
+    keys_lhs_ptr = PointerCast(keys_lhs_ptr, keys_ir_type->getPointerTo());
+    llvm::Value* keys_rhs_ptr = less_than_function->arg_begin() + 1;
+    keys_rhs_ptr = PointerCast(keys_rhs_ptr, keys_ir_type->getPointerTo());
+
+    // TODO(b/122298745): Replace the custom compare logic with a call to the
+    // computation specified for the Sort op.
+    llvm::Value* keys_lhs = Load(keys_ir_type, keys_lhs_ptr);
+    llvm::Value* keys_rhs = Load(keys_ir_type, keys_rhs_ptr);
+    bool is_signed_comparison = true;
+    if (primitive_util::IsFloatingPointType(keys_type)) {
+      // We would like a total order of floating point numbers so that the
+      // sort has a predictable behavior in the presence of NaNs. Rather
+      // than using floating point comparison, we use the following trick:
+      // If f is a float, and
+      // x = bit_cast<int32>(f);
+      // y = x < 0 ? 0x7FFFFFFF - x : x;
+      // then y is ordered as an int32 such that finite values have the
+      // obvious order, -0 is ordered before 0, and -NaN and NaN appear at
+      // the beginning and end of the ordering.
+      auto k = b_.getInt(llvm::APInt::getSignedMaxValue(
+          keys_lhs->getType()->getPrimitiveSizeInBits()));
+      auto comparison_type = k->getType();
+      auto zero = llvm::ConstantInt::get(comparison_type, 0);
+      auto maybe_flip = [&](llvm::Value* v) {
+        return b_.CreateSelect(b_.CreateICmp(llvm::ICmpInst::ICMP_SLT, v, zero),
+                               b_.CreateSub(k, v), v);
+      };
+      keys_lhs = b_.CreateBitCast(keys_lhs, comparison_type);
+      keys_rhs = b_.CreateBitCast(keys_rhs, comparison_type);
+      keys_lhs = maybe_flip(keys_lhs);
+      keys_rhs = maybe_flip(keys_rhs);
+    } else if (!primitive_util::IsSignedIntegralType(keys_type)) {
+      is_signed_comparison = false;
+    }
+    llvm::Value* result =
+        b_.CreateICmp(is_signed_comparison ? llvm::ICmpInst::ICMP_SLT
+                                           : llvm::ICmpInst::ICMP_ULT,
+                      keys_lhs, keys_rhs);
+    llvm::ReturnInst::Create(b_.getContext(),
+                             /*retVal=*/result, entry_bb);
+  }

   llvm::FunctionType* key_value_sort_type = llvm::FunctionType::get(
       b_.getVoidTy(),
-      {keys_native_type, b_.getInt64Ty(), b_.getInt64Ty(), b_.getInt64Ty(),
+      {b_.getInt64Ty(), b_.getInt64Ty(), b_.getInt64Ty(),
        b_.getInt8PtrTy()->getPointerTo(), b_.getInt32Ty(),
-       b_.getInt32Ty()->getPointerTo()},
+       b_.getInt32Ty()->getPointerTo(), less_than_function->getType()},
       /*isVarArg=*/false);
-  auto* key_value_sort_func = llvm::cast<llvm::Function>(
-      module_->getOrInsertFunction(fn_name, key_value_sort_type));
+  auto* key_value_sort_func =
+      llvm::cast<llvm::Function>(module_->getOrInsertFunction(
+          runtime::kKeyValueSortSymbolName, key_value_sort_type));
   key_value_sort_func->setCallingConv(llvm::CallingConv::C);
   key_value_sort_func->setDoesNotThrow();
-  llvm::Value* values;
-  llvm::Value* sizes;
-  if (sort->values_count() == 0) {
-    values = llvm::Constant::getNullValue(b_.getInt8PtrTy()->getPointerTo());
-    sizes = llvm::Constant::getNullValue(b_.getInt32Ty()->getPointerTo());
-  } else {
-    values = llvm_ir::EmitAllocaAtFunctionEntryWithCount(
-        b_.getInt8PtrTy(), b_.getInt32(sort->values_count()),
-        "cc_values_alloca", &b_);
-    sizes = llvm_ir::EmitAllocaAtFunctionEntryWithCount(
-        b_.getInt32Ty(), b_.getInt32(sort->values_count()), "cc_sizes_alloca",
-        &b_);
-    for (int64 i = 0; i < sort->values_count(); ++i) {
-      llvm::Value* value_as_i8ptr =
-          PointerCast(destination_addresses[i + 1], b_.getInt8PtrTy());
-      llvm::Value* slot_in_values_alloca =
-          ConstInBoundsGEP1_32(b_.getInt8PtrTy(), values, i);
-      Store(value_as_i8ptr, slot_in_values_alloca);
-      llvm::Value* slot_in_sizes_alloca =
-          ConstInBoundsGEP1_32(b_.getInt32Ty(), sizes, i);
-      llvm::Value* size = b_.getInt32(ShapeUtil::ByteSizeOfPrimitiveType(
-          sort->operand(i + 1)->shape().element_type()));
-      Store(size, slot_in_sizes_alloca);
-    }
-  }
+  llvm::Value* values = llvm_ir::EmitAllocaAtFunctionEntryWithCount(
+      b_.getInt8PtrTy(), b_.getInt32(sort->operand_count()), "cc_values_alloca",
+      &b_);
+  llvm::Value* sizes = llvm_ir::EmitAllocaAtFunctionEntryWithCount(
+      b_.getInt32Ty(), b_.getInt32(sort->operand_count()), "cc_sizes_alloca",
+      &b_);
+  for (int64 i = 0; i < sort->operand_count(); ++i) {
+    llvm::Value* value_as_i8ptr =
+        PointerCast(destination_addresses[i], b_.getInt8PtrTy());
+    llvm::Value* slot_in_values_alloca =
+        ConstInBoundsGEP1_32(b_.getInt8PtrTy(), values, i);
+    Store(value_as_i8ptr, slot_in_values_alloca);
+    llvm::Value* slot_in_sizes_alloca =
+        ConstInBoundsGEP1_32(b_.getInt32Ty(), sizes, i);
+    llvm::Value* size = b_.getInt32(ShapeUtil::ByteSizeOfPrimitiveType(
+        sort->operand(i)->shape().element_type()));
+    Store(size, slot_in_sizes_alloca);
+  }

   Call(key_value_sort_func,
-       {PointerCast(destination_addresses[0], keys_native_type),
-        b_.getInt64(higher_dimensions), b_.getInt64(sort_dimension_elements),
+       {b_.getInt64(higher_dimensions), b_.getInt64(sort_dimension_elements),
         b_.getInt64(lower_dimensions), values,
-        b_.getInt32(sort->values_count()), sizes});
+        b_.getInt32(sort->operand_count()), sizes, less_than_function});

   if (sort->values_count() > 0) {
     llvm_ir::EmitTuple(GetIrArrayFor(sort), destination_addresses, &b_,
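The comparator emitted above encodes a total order for floating-point keys via the bit trick described in the comment. A small self-contained check of that transformation, assuming IEEE-754 binary32 and a hypothetical helper name (TotalOrderKey); the subtraction is done in unsigned arithmetic because the wrapping sub that LLVM emits would be signed overflow (undefined behavior) in plain C++:

#include <cstdint>
#include <cstdio>
#include <cstring>
#include <limits>

// Maps a float's bits to an int32 whose signed order is a total order over
// floats: finite values in the usual order, -0 strictly before +0, and the
// NaNs at the two ends of the range.
int32_t TotalOrderKey(float f) {
  int32_t x;
  std::memcpy(&x, &f, sizeof(f));  // x = bit_cast<int32>(f)
  if (x < 0) {
    // y = 0x7FFFFFFF - x, computed with wrapping unsigned arithmetic.
    uint32_t y = 0x7FFFFFFFu - static_cast<uint32_t>(x);
    std::memcpy(&x, &y, sizeof(y));
  }
  return x;
}

int main() {
  const float inf = std::numeric_limits<float>::infinity();
  const float nan = std::numeric_limits<float>::quiet_NaN();
  std::printf("%d %d %d\n",
              TotalOrderKey(-0.0f) < TotalOrderKey(0.0f),  // -0 before +0
              TotalOrderKey(inf) < TotalOrderKey(nan),     // +NaN is largest
              TotalOrderKey(-nan) < TotalOrderKey(-inf));  // -NaN is smallest
  // prints: 1 1 1
}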
@@ -15,12 +15,10 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h"

 #include <algorithm>
-#include <cmath>
 #include <cstring>
-#include <limits>
 #include <memory>
+#include <numeric>
 #include <string>
-#include <utility>

 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 #include "tensorflow/core/platform/dynamic_annotations.h"
@@ -28,80 +26,14 @@ limitations under the License.
 #include "tensorflow/core/platform/types.h"

 namespace {
-using tensorflow::int16;
 using tensorflow::int32;
 using tensorflow::int64;
-using tensorflow::int8;
-using tensorflow::uint16;
-using tensorflow::uint32;
-using tensorflow::uint64;
-using tensorflow::uint8;
+} // namespace

-template <typename KeyType>
-void KeyValueSort(std::pair<KeyType, int64>* row_to_sort, int64 num_elements) {
-  std::sort(row_to_sort, row_to_sort + num_elements);
-}
-
-// We would like a total order of floating point numbers so that the
-// sort has a predictable behavior in the presence of NaNs. Rather
-// than using floating point comparison, we use the following trick:
-// If f is a float, and
-// x = bit_cast<int32>(f);
-// y = x < 0 ? 0x7FFFFFFF - x : x;
-// then y is ordered as an int32 such that finite values have the
-// obvious order, -0 is ordered before 0, and -NaN and NaN appear at
-// the beginning and end of the ordering.
-template <typename CastType, typename UnsignedCastType, typename KeyType>
-CastType Convert(KeyType value) {
-  CastType casted_value;
-  memcpy(&casted_value, &value, sizeof(CastType));
-  if (casted_value < 0) {
-    return static_cast<UnsignedCastType>(std::numeric_limits<CastType>::max()) -
-           casted_value;
-  }
-  return casted_value;
-}
-
-template <typename CastType, typename UnsignedCastType, typename KeyType>
-bool LessThan(KeyType lhs, KeyType rhs) {
-  return Convert<CastType, UnsignedCastType>(lhs) <
-         Convert<CastType, UnsignedCastType>(rhs);
-}
-
-template <>
-void KeyValueSort(std::pair<double, int64>* row_to_sort, int64 num_elements) {
-  std::stable_sort(row_to_sort, row_to_sort + num_elements,
-                   [](const std::pair<double, int64>& lhs,
-                      const std::pair<double, int64>& rhs) -> bool {
-                     return LessThan<int64, uint64>(lhs.first, rhs.first);
-                   });
-}
-
-template <>
-void KeyValueSort(std::pair<float, int64>* row_to_sort, int64 num_elements) {
-  std::stable_sort(row_to_sort, row_to_sort + num_elements,
-                   [](const std::pair<float, int64>& lhs,
-                      const std::pair<float, int64>& rhs) -> bool {
-                     return LessThan<int32, uint32>(lhs.first, rhs.first);
-                   });
-}
-
-template <>
-void KeyValueSort(std::pair<Eigen::half, int64>* row_to_sort,
-                  int64 num_elements) {
-  std::stable_sort(row_to_sort, row_to_sort + num_elements,
-                   [](const std::pair<Eigen::half, int64>& lhs,
-                      const std::pair<Eigen::half, int64>& rhs) -> bool {
-                     return LessThan<int32, uint32>(
-                         Eigen::half_impl::half_to_float(lhs.first),
-                         Eigen::half_impl::half_to_float(rhs.first));
-                   });
-}
-
-template <typename KeyType>
-void KeyValueSortImpl(KeyType* keys, int64 a, int64 b, int64 c, char** values,
-                      int32 values_count,
-                      int32* values_primitive_type_size_in_bytes) {
+TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSort(
+    int64 a, int64 b, int64 c, char** values, int32 values_count,
+    int32* values_primitive_type_size_in_bytes,
+    bool (*less_than)(char*, char*)) {
   // 'values' and 'values_primitive_type_size_in_bytes' are managed by the JIT
   // code, so msan can't tell they are initialized.
   TF_ANNOTATE_MEMORY_IS_INITIALIZED(values, values_count * sizeof(char*));
@@ -121,8 +53,8 @@ void KeyValueSortImpl(KeyType* keys, int64 a, int64 b, int64 c, char** values,
   int64 num_iteration_elements = a * c;
   int64 sort_dimension_offset = c;

-  std::unique_ptr<std::pair<KeyType, int64>[]> row_to_sort(
-      new std::pair<KeyType, int64>[sort_dimension_elements]);
+  std::unique_ptr<int64[]> indices(new int64[sort_dimension_elements]);
+  std::iota(indices.get(), indices.get() + sort_dimension_elements, 0);
   std::unique_ptr<std::string[]> reordered_values(
       new std::string[sort_dimension_elements]);
   for (int64 index = 0; index < num_iteration_elements; ++index) {
@@ -135,24 +67,22 @@ void KeyValueSortImpl(KeyType* keys, int64 a, int64 b, int64 c, char** values,
     int64 base_offset =
         index % sort_dimension_offset +
         (index - index % sort_dimension_offset) * sort_dimension_elements;
-    // TODO(b/26783907): We could define a custom iterator class that references
-    // all arrays. Then we could avoid the intermediate copy. However this
-    // would become more complicated, and it is not clear if the benefit is high
-    // enough.
-    for (int64 i = 0; i < sort_dimension_elements; ++i) {
-      row_to_sort[i] =
-          std::make_pair(keys[base_offset + i * sort_dimension_offset], i);
-    }
-    KeyValueSort(row_to_sort.get(), sort_dimension_elements);
-    for (int64 i = 0; i < sort_dimension_elements; ++i) {
-      keys[base_offset + i * sort_dimension_offset] = row_to_sort[i].first;
-    }
+    std::stable_sort(
+        indices.get(), indices.get() + sort_dimension_elements,
+        [&](int64 a, int64 b) {
+          int64 memory_index_lhs = (base_offset + a * sort_dimension_offset) *
+                                   values_primitive_type_size_in_bytes[0];
+          int64 memory_index_rhs = (base_offset + b * sort_dimension_offset) *
+                                   values_primitive_type_size_in_bytes[0];
+          return less_than(values[0] + memory_index_lhs,
+                           values[0] + memory_index_rhs);
+        });

-    // Reorder the values according to the order defined by the keys.
+    // Reorder the values according to the order defined by 'indices'.
     for (int32 idx = 0; idx < values_count; ++idx) {
       for (int64 i = 0; i < sort_dimension_elements; ++i) {
         int64 memory_index =
-            (base_offset + row_to_sort[i].second * sort_dimension_offset) *
+            (base_offset + indices[i] * sort_dimension_offset) *
             values_primitive_type_size_in_bytes[idx];

         reordered_values[i] =
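The base_offset arithmetic above selects each of the a*c independent rows along the sort dimension 'b'; element i of a row then lives at base_offset + i * sort_dimension_offset. A tiny standalone check of that indexing for a row-major [2, 3, 2] shape (illustrative only, not XLA code):

#include <cstdint>
#include <cstdio>

int main() {
  // A row-major [a, b, c] = [2, 3, 2] tensor; we sort along dimension b,
  // so there are a*c = 4 independent rows of length 3.
  const int64_t a = 2, b = 3, c = 2;
  const int64_t sort_dimension_elements = b;
  const int64_t num_iteration_elements = a * c;
  const int64_t sort_dimension_offset = c;
  for (int64_t index = 0; index < num_iteration_elements; ++index) {
    // Same formula as in the runtime: keep the 'c' part of the index as-is
    // and stretch the 'a' part across whole [b, c] blocks.
    int64_t base_offset =
        index % sort_dimension_offset +
        (index - index % sort_dimension_offset) * sort_dimension_elements;
    std::printf("row %lld:", static_cast<long long>(index));
    for (int64_t i = 0; i < sort_dimension_elements; ++i) {
      std::printf(
          " %lld",
          static_cast<long long>(base_offset + i * sort_dimension_offset));
    }
    std::printf("\n");
  }
  // Output: row 0: 0 2 4 / row 1: 1 3 5 / row 2: 6 8 10 / row 3: 7 9 11
}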
@@ -168,88 +98,3 @@ void KeyValueSortImpl(KeyType* keys, int64 a, int64 b, int64 c, char** values,
       }
     }
   }
-} // namespace
-
-TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortPRED(
-    bool* keys, int64 a, int64 b, int64 c, char** values, int32 values_count,
-    int32* values_primitive_type_size_in_bytes) {
-  KeyValueSortImpl(keys, a, b, c, values, values_count,
-                   values_primitive_type_size_in_bytes);
-}
-
-TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortS8(
-    int8* keys, int64 a, int64 b, int64 c, char** values, int32 values_count,
-    int32* values_primitive_type_size_in_bytes) {
-  KeyValueSortImpl(keys, a, b, c, values, values_count,
-                   values_primitive_type_size_in_bytes);
-}
-
-TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortU8(
-    uint8* keys, int64 a, int64 b, int64 c, char** values, int32 values_count,
-    int32* values_primitive_type_size_in_bytes) {
-  KeyValueSortImpl(keys, a, b, c, values, values_count,
-                   values_primitive_type_size_in_bytes);
-}
-
-TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortS16(
-    int16* keys, int64 a, int64 b, int64 c, char** values, int32 values_count,
-    int32* values_primitive_type_size_in_bytes) {
-  KeyValueSortImpl(keys, a, b, c, values, values_count,
-                   values_primitive_type_size_in_bytes);
-}
-
-TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortU16(
-    uint16* keys, int64 a, int64 b, int64 c, char** values, int32 values_count,
-    int32* values_primitive_type_size_in_bytes) {
-  KeyValueSortImpl(keys, a, b, c, values, values_count,
-                   values_primitive_type_size_in_bytes);
-}
-
-TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortF16(
-    Eigen::half* keys, int64 a, int64 b, int64 c, char** values,
-    int32 values_count, int32* values_primitive_type_size_in_bytes) {
-  KeyValueSortImpl(keys, a, b, c, values, values_count,
-                   values_primitive_type_size_in_bytes);
-}
-
-TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortS32(
-    int32* keys, int64 a, int64 b, int64 c, char** values, int32 values_count,
-    int32* values_primitive_type_size_in_bytes) {
-  KeyValueSortImpl(keys, a, b, c, values, values_count,
-                   values_primitive_type_size_in_bytes);
-}
-
-TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortU32(
-    uint32* keys, int64 a, int64 b, int64 c, char** values, int32 values_count,
-    int32* values_primitive_type_size_in_bytes) {
-  KeyValueSortImpl(keys, a, b, c, values, values_count,
-                   values_primitive_type_size_in_bytes);
-}
-
-TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortF32(
-    float* keys, int64 a, int64 b, int64 c, char** values, int32 values_count,
-    int32* values_primitive_type_size_in_bytes) {
-  KeyValueSortImpl(keys, a, b, c, values, values_count,
-                   values_primitive_type_size_in_bytes);
-}
-
-TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortS64(
-    int64* keys, int64 a, int64 b, int64 c, char** values, int32 values_count,
-    int32* values_primitive_type_size_in_bytes) {
-  KeyValueSortImpl(keys, a, b, c, values, values_count,
-                   values_primitive_type_size_in_bytes);
-}
-
-TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortU64(
-    uint64* keys, int64 a, int64 b, int64 c, char** values, int32 values_count,
-    int32* values_primitive_type_size_in_bytes) {
-  KeyValueSortImpl(keys, a, b, c, values, values_count,
-                   values_primitive_type_size_in_bytes);
-}
-
-TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortF64(
-    double* keys, int64 a, int64 b, int64 c, char** values, int32 values_count,
-    int32* values_primitive_type_size_in_bytes) {
-  KeyValueSortImpl(keys, a, b, c, values, values_count,
-                   values_primitive_type_size_in_bytes);
-}
@@ -21,76 +21,19 @@ limitations under the License.

 extern "C" {

-// 'keys' represents a 3-dimensional shape with dimensions [a, b, c]. The 'b'
-// dimension of 'keys' is sorted into ascending order. If 'values_count' is <=
-// 0, 'values' and 'values_primitive_type_size_in_bytes' can be nullptr.
-// If 'values_count' > 0, they contain exactly 'values_count' many elements.
-// Each element of 'values' also represents a 3-dimensional shape with
-// dimensions [a, b, c], and the size of the primitive type of the i-th shape
-// has exactly 'values_primitive_type_size_in_bytes[i]' bytes. The elements in
-// each 'values' shape are reordered in such a way that if the element at index
-// 'i' in 'keys' was moved to index 'j', the element at index 'i' in a 'values'
-// shape is also moved to index 'j' (which means that the same elements
-// correspond to each other as before).
-extern void __xla_cpu_runtime_KeyValueSortPRED(
-    bool* keys, tensorflow::int64 a, tensorflow::int64 b, tensorflow::int64 c,
+// Each entry in 'values' represents a 3-dimensional shape with dimensions
+// [a, b, c]. The 'b' dimension of the first shape is sorted into ascending
+// order according to the results of comparisons using the provided 'less_than'
+// function. 'values_count' must be > 0 and specifies the number of entries in
+// 'values' and 'values_primitive_type_size_in_bytes'. The size of the primitive
+// type of the i-th shape has exactly 'values_primitive_type_size_in_bytes[i]'
+// bytes. The elements in each 'values' shape are reordered in the same way
+// according to the comparisons using the first shape.
+extern void __xla_cpu_runtime_KeyValueSort(
+    tensorflow::int64 a, tensorflow::int64 b, tensorflow::int64 c,
     char** values, tensorflow::int32 values_count,
-    tensorflow::int32* values_primitive_type_size_in_bytes);
-
-extern void __xla_cpu_runtime_KeyValueSortS8(
-    tensorflow::int8* keys, tensorflow::int64 a, tensorflow::int64 b,
-    tensorflow::int64 c, char** values, tensorflow::int32 values_count,
-    tensorflow::int32* values_primitive_type_size_in_bytes);
-
-extern void __xla_cpu_runtime_KeyValueSortU8(
-    tensorflow::uint8* keys, tensorflow::int64 a, tensorflow::int64 b,
-    tensorflow::int64 c, char** values, tensorflow::int32 values_count,
-    tensorflow::int32* values_primitive_type_size_in_bytes);
-
-extern void __xla_cpu_runtime_KeyValueSortS16(
-    tensorflow::int16* keys, tensorflow::int64 a, tensorflow::int64 b,
-    tensorflow::int64 c, char** values, tensorflow::int32 values_count,
-    tensorflow::int32* values_primitive_type_size_in_bytes);
-
-extern void __xla_cpu_runtime_KeyValueSortU16(
-    tensorflow::uint16* keys, tensorflow::int64 a, tensorflow::int64 b,
-    tensorflow::int64 c, char** values, tensorflow::int32 values_count,
-    tensorflow::int32* values_primitive_type_size_in_bytes);
-
-extern void __xla_cpu_runtime_KeyValueSortF16(
-    Eigen::half* keys, tensorflow::int64 a, tensorflow::int64 b,
-    tensorflow::int64 c, char** values, tensorflow::int32 values_count,
-    tensorflow::int32* values_primitive_type_size_in_bytes);
-
-extern void __xla_cpu_runtime_KeyValueSortS32(
-    tensorflow::int32* keys, tensorflow::int64 a, tensorflow::int64 b,
-    tensorflow::int64 c, char** values, tensorflow::int32 values_count,
-    tensorflow::int32* values_primitive_type_size_in_bytes);
-
-extern void __xla_cpu_runtime_KeyValueSortU32(
-    tensorflow::uint32* keys, tensorflow::int64 a, tensorflow::int64 b,
-    tensorflow::int64 c, char** values, tensorflow::int32 values_count,
-    tensorflow::int32* values_primitive_type_size_in_bytes);
-
-extern void __xla_cpu_runtime_KeyValueSortF32(
-    float* keys, tensorflow::int64 a, tensorflow::int64 b, tensorflow::int64 c,
-    char** values, tensorflow::int32 values_count,
-    tensorflow::int32* values_primitive_type_size_in_bytes);
-
-extern void __xla_cpu_runtime_KeyValueSortS64(
-    tensorflow::int64* keys, tensorflow::int64 a, tensorflow::int64 b,
-    tensorflow::int64 c, char** values, tensorflow::int32 values_count,
-    tensorflow::int32* values_primitive_type_size_in_bytes);
-
-extern void __xla_cpu_runtime_KeyValueSortU64(
-    tensorflow::uint64* keys, tensorflow::int64 a, tensorflow::int64 b,
-    tensorflow::int64 c, char** values, tensorflow::int32 values_count,
-    tensorflow::int32* values_primitive_type_size_in_bytes);
-
-extern void __xla_cpu_runtime_KeyValueSortF64(
-    double* keys, tensorflow::int64 a, tensorflow::int64 b, tensorflow::int64 c,
-    char** values, tensorflow::int32 values_count,
-    tensorflow::int32* values_primitive_type_size_in_bytes);
+    tensorflow::int32* values_primitive_type_size_in_bytes,
+    bool (*less_than)(char*, char*));
 }

 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_KEY_VALUE_SORT_H_
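Under the new contract the keys are simply values[0]. A hedged host-side sketch of a direct call, assuming the XLA CPU runtime object is linked in and using int64_t/int32_t in place of the tensorflow typedefs; F32Less is a hypothetical stand-in for the JIT-emitted comparator:

#include <cstdint>
#include <cstdio>
#include <cstring>

// Declaration matching the new extern "C" contract above.
extern "C" void __xla_cpu_runtime_KeyValueSort(
    int64_t a, int64_t b, int64_t c, char** values, int32_t values_count,
    int32_t* values_primitive_type_size_in_bytes,
    bool (*less_than)(char*, char*));

bool F32Less(char* lhs, char* rhs) {
  float l, r;
  std::memcpy(&l, lhs, sizeof l);
  std::memcpy(&r, rhs, sizeof r);
  return l < r;
}

int main() {
  // values[0] is the "keys" shape [1, 4, 1]; values[1] is carried along and
  // reordered by the same permutation.
  float keys[4] = {4.f, 2.f, 3.f, 1.f};
  int32_t payload[4] = {40, 20, 30, 10};
  char* values[2] = {reinterpret_cast<char*>(keys),
                     reinterpret_cast<char*>(payload)};
  int32_t sizes[2] = {sizeof(float), sizeof(int32_t)};
  __xla_cpu_runtime_KeyValueSort(/*a=*/1, /*b=*/4, /*c=*/1, values,
                                 /*values_count=*/2, sizes, F32Less);
  for (int i = 0; i < 4; ++i) std::printf("%g:%d ", keys[i], payload[i]);
  // Expected: 1:10 2:20 3:30 4:40
}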
@@ -240,18 +240,7 @@ bool RegisterKnownJITSymbols() {
   REGISTER_CPU_RUNTIME_SYMBOL(ParallelForkJoin);
   REGISTER_CPU_RUNTIME_SYMBOL(ReleaseInfeedBufferAfterDequeue);
   REGISTER_CPU_RUNTIME_SYMBOL(ReleaseOutfeedBufferAfterPopulation);
-  REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortPRED);
-  REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortS8);
-  REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortU8);
-  REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortS16);
-  REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortU16);
-  REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortF16);
-  REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortS32);
-  REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortU32);
-  REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortF32);
-  REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortS64);
-  REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortU64);
-  REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortF64);
+  REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSort);

   registry->Register("__gnu_f2h_ieee", reinterpret_cast<void*>(__gnu_f2h_ieee));
   registry->Register("__gnu_h2f_ieee", reinterpret_cast<void*>(__gnu_h2f_ieee));
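REGISTER_CPU_RUNTIME_SYMBOL presumably pairs the "__xla_cpu_runtime_"-prefixed name with the function's address in the JIT symbol registry; the macro's definition is not part of this diff. A toy sketch of that mechanism under those assumptions (hypothetical registry, not the TensorFlow implementation):

#include <cstdio>
#include <map>
#include <string>

// Toy stand-in for the JIT symbol registry: maps C symbol names to
// addresses so the JIT can resolve calls in generated code.
static std::map<std::string, void*> registry;

// Mirrors the shape of REGISTER_CPU_RUNTIME_SYMBOL(base_name): prepend the
// runtime prefix and record the function pointer under that name.
#define REGISTER_CPU_RUNTIME_SYMBOL(base_name)          \
  registry["__xla_cpu_runtime_" #base_name] =           \
      reinterpret_cast<void*>(&__xla_cpu_runtime_##base_name)

// Dummy runtime function standing in for the real one so the sketch links.
extern "C" void __xla_cpu_runtime_KeyValueSort() {}

int main() {
  REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSort);
  std::printf("registered: %s\n",
              registry.count("__xla_cpu_runtime_KeyValueSort") ? "yes" : "no");
}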