Support arbitrarily many values in KeyValueSort on the CPU backend.
PiperOrigin-RevId: 217398356
parent 8c3d9ae5de
commit e4e19db364

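The change below widens the CPU runtime sort from a single optional value operand to arbitrarily many: the emitted call now passes an array of value-buffer pointers plus a parallel array of per-buffer element sizes. As a rough standalone illustration of that contract (not code from this commit; the function and variable names here are made up), the following sketch sorts one row of keys and applies the same permutation to several value arrays of different element widths, much like KeyValueSortImpl does per row in the diff below.

#include <algorithm>
#include <cstdint>
#include <cstring>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

// Sketch: sort one contiguous row of 'n' keys and reorder 'values_count'
// extra arrays the same way. Each array values[idx] holds 'n' elements of
// values_primitive_type_size_in_bytes[idx] bytes apiece.
void SortRowWithValues(float* keys, int64_t n, char** values,
                       int32_t values_count,
                       const int32_t* values_primitive_type_size_in_bytes) {
  // Pair each key with its original position, then sort by key.
  std::vector<std::pair<float, int64_t>> row(n);
  for (int64_t i = 0; i < n; ++i) row[i] = {keys[i], i};
  std::sort(row.begin(), row.end());
  for (int64_t i = 0; i < n; ++i) keys[i] = row[i].first;

  // Apply the permutation to every value array, byte-wise, using the
  // per-array element size (different arrays may have different widths).
  for (int32_t idx = 0; idx < values_count; ++idx) {
    const int32_t size = values_primitive_type_size_in_bytes[idx];
    std::vector<std::string> reordered(n);
    for (int64_t i = 0; i < n; ++i) {
      reordered[i] = std::string(values[idx] + row[i].second * size, size);
    }
    for (int64_t i = 0; i < n; ++i) {
      std::memcpy(values[idx] + i * size, reordered[i].data(), size);
    }
  }
}

int main() {
  float keys[] = {3.f, 1.f, 2.f};
  int32_t v0[] = {30, 10, 20};    // 4-byte elements
  double v1[] = {3.0, 1.0, 2.0};  // 8-byte elements
  char* values[] = {reinterpret_cast<char*>(v0), reinterpret_cast<char*>(v1)};
  int32_t sizes[] = {sizeof(int32_t), sizeof(double)};
  SortRowWithValues(keys, 3, values, /*values_count=*/2, sizes);
  std::cout << keys[0] << " " << v0[0] << " " << v1[0] << "\n";  // 1 10 1
  return 0;
}
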
tensorflow/compiler/xla/service/cpu/ir_emitter.cc
@@ -54,6 +54,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/cpu/simple_orc_jit.h"
#include "tensorflow/compiler/xla/service/elemental_ir_emitter.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h"

@@ -493,53 +494,44 @@ Status IrEmitter::HandleOutfeed(HloInstruction* outfeed) {
   return Status::OK();
 }
 
-Status IrEmitter::HandleSort(HloInstruction* sort) {
+Status IrEmitter::HandleSort(HloInstruction* hlo) {
+  const HloSortInstruction* sort = Cast<HloSortInstruction>(hlo);
   TF_RETURN_IF_ERROR(EmitTargetAddressForOp(sort));
-  auto keys = sort->operand(0);
-  auto values = sort->operand_count() > 1 ? sort->operand(1) : nullptr;
-  ShapeIndex keys_shape_index({});
-  ShapeIndex values_shape_index({});
-  if (values != nullptr) {
-    keys_shape_index = ShapeIndex({0});
-    values_shape_index = ShapeIndex({1});
-  }
-  auto keys_destination = GetAllocationSlice(*sort, keys_shape_index);
-  auto keys_destination_address =
-      EmitBufferPointer(keys_destination, keys->shape());
-  auto values_destination = GetAllocationSlice(*sort, values_shape_index);
-  llvm::Value* values_destination_address = nullptr;
-
-  // The sort is implemented in-place, therefore we first copy the operand
-  // buffer to the output buffer if they are not the same.
-  if (keys_destination != GetAllocationSlice(*keys)) {
-    int64 primitive_type_size =
-        ShapeUtil::ByteSizeOfPrimitiveType(keys->shape().element_type());
-    auto source_buffer = GetEmittedValueFor(keys);
-    int64 keys_size = ByteSizeOf(keys->shape());
-    MemCpy(keys_destination_address, /*DstAlign=*/primitive_type_size,
-           source_buffer,
-           /*SrcAlign=*/primitive_type_size, keys_size);
-  }
-  if (values != nullptr) {
-    values_destination_address =
-        EmitBufferPointer(values_destination, values->shape());
-    if (values_destination != GetAllocationSlice(*values)) {
-      // The sort is implemented in-place, therefore we first copy the operand
-      // buffer to the output buffer if they are not the same.
-      int64 primitive_type_size =
-          ShapeUtil::ByteSizeOfPrimitiveType(values->shape().element_type());
-      auto source_buffer = GetEmittedValueFor(values);
-      int64 values_size = ByteSizeOf(values->shape());
-      MemCpy(values_destination_address, /*DstAlign=*/primitive_type_size,
-             source_buffer,
-             /*SrcAlign=*/primitive_type_size, values_size);
+  Shape keys_shape = sort->keys()->shape();
+  std::vector<llvm::Value*> destination_addresses(sort->operand_count());
+  for (int64 i = 0; i < sort->operand_count(); ++i) {
+    ShapeIndex shape_index =
+        sort->values_count() > 0 ? ShapeIndex({i}) : ShapeIndex({});
+    const HloInstruction* operand = sort->operand(i);
+    // We assume that the layout of all involved operands and outputs is the
+    // same.
+    TF_RET_CHECK(
+        LayoutUtil::LayoutsInShapesEqual(keys_shape, operand->shape()));
+    TF_RET_CHECK(LayoutUtil::LayoutsInShapesEqual(
+        keys_shape, ShapeUtil::GetSubshape(sort->shape(), shape_index)));
+
+    // The sort is implemented in-place, therefore we first copy the operand
+    // buffer to the output buffer if they are not the same.
+    auto destination_buffer = GetAllocationSlice(*sort, shape_index);
+    destination_addresses[i] =
+        EmitBufferPointer(destination_buffer, operand->shape());
+    auto source_address = GetAllocationSlice(*operand);
+    if (destination_buffer != source_address) {
+      int64 primitive_type_size =
+          ShapeUtil::ByteSizeOfPrimitiveType(operand->shape().element_type());
+      auto source_buffer = GetEmittedValueFor(operand);
+      int64 size = ByteSizeOf(operand->shape());
+      MemCpy(destination_addresses[i], /*DstAlign=*/primitive_type_size,
+             source_buffer,
+             /*SrcAlign=*/primitive_type_size, size);
     }
   }
 
   // Normalize the shape and the dimension to sort.
   Shape normalized_keys_shape =
-      ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
-          keys->shape());
+      ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(keys_shape);
   int64 physical_dimension_to_sort = LayoutUtil::MakeLogicalToPhysical(
-      keys->shape().layout())[sort->dimensions(0)];
+      keys_shape.layout())[sort->sort_dimension()];
 
   int64 sort_dimension_elements =
       normalized_keys_shape.dimensions(physical_dimension_to_sort);

@@ -553,7 +545,7 @@ Status IrEmitter::HandleSort(HloInstruction* sort) {
     lower_dimensions *= normalized_keys_shape.dimensions(i);
   }
 
-  PrimitiveType keys_type = keys->shape().element_type();
+  PrimitiveType keys_type = keys_shape.element_type();
   const char* fn_name = nullptr;
   llvm::Type* keys_native_type = nullptr;
   switch (keys_type) {

@@ -614,28 +606,49 @@ Status IrEmitter::HandleSort(HloInstruction* sort) {
   llvm::FunctionType* key_value_sort_type = llvm::FunctionType::get(
       b_.getVoidTy(),
       {keys_native_type, b_.getInt64Ty(), b_.getInt64Ty(), b_.getInt64Ty(),
-       b_.getInt8PtrTy(), b_.getInt32Ty()},
+       b_.getInt8PtrTy()->getPointerTo(), b_.getInt32Ty(),
+       b_.getInt32Ty()->getPointerTo()},
       /*isVarArg=*/false);
   auto* key_value_sort_func = llvm::cast<llvm::Function>(
       module_->getOrInsertFunction(fn_name, key_value_sort_type));
   key_value_sort_func->setCallingConv(llvm::CallingConv::C);
   key_value_sort_func->setDoesNotThrow();
   key_value_sort_func->setOnlyAccessesArgMemory();
-  Call(key_value_sort_func,
-       {PointerCast(keys_destination_address, keys_native_type),
-        b_.getInt64(higher_dimensions), b_.getInt64(sort_dimension_elements),
-        b_.getInt64(lower_dimensions),
-        values != nullptr
-            ? PointerCast(values_destination_address, b_.getInt8PtrTy())
-            : llvm::Constant::getNullValue(b_.getInt8PtrTy()),
-        b_.getInt32(values != nullptr ? ShapeUtil::ByteSizeOfPrimitiveType(
-                                            values->shape().element_type())
-                                      : 0)});
+  llvm::Value* values;
+  llvm::Value* sizes;
+  if (sort->values_count() == 0) {
+    values = llvm::Constant::getNullValue(b_.getInt8PtrTy()->getPointerTo());
+    sizes = llvm::Constant::getNullValue(b_.getInt32Ty()->getPointerTo());
+  } else {
+    values = llvm_ir::EmitAllocaAtFunctionEntryWithCount(
+        b_.getInt8PtrTy(), b_.getInt32(sort->values_count()),
+        "cc_values_alloca", &b_);
+    sizes = llvm_ir::EmitAllocaAtFunctionEntryWithCount(
+        b_.getInt32Ty(), b_.getInt32(sort->values_count()), "cc_sizes_alloca",
+        &b_);
+    for (int64 i = 0; i < sort->values_count(); ++i) {
+      llvm::Value* value_as_i8ptr =
+          PointerCast(destination_addresses[i + 1], b_.getInt8PtrTy());
+      llvm::Value* slot_in_values_alloca =
+          ConstInBoundsGEP1_32(b_.getInt8PtrTy(), values, i);
+      Store(value_as_i8ptr, slot_in_values_alloca);
+      llvm::Value* slot_in_sizes_alloca =
+          ConstInBoundsGEP1_32(b_.getInt32Ty(), sizes, i);
+      llvm::Value* size = b_.getInt32(ShapeUtil::ByteSizeOfPrimitiveType(
+          sort->operand(i + 1)->shape().element_type()));
+      Store(size, slot_in_sizes_alloca);
+    }
+  }
 
-  if (values != nullptr) {
-    llvm_ir::EmitTuple(GetIrArrayFor(sort),
-                       {keys_destination_address, values_destination_address},
-                       &b_, module_);
+  Call(key_value_sort_func,
+       {PointerCast(destination_addresses[0], keys_native_type),
+        b_.getInt64(higher_dimensions), b_.getInt64(sort_dimension_elements),
+        b_.getInt64(lower_dimensions), values,
+        b_.getInt32(sort->values_count()), sizes});
+
+  if (sort->values_count() > 0) {
+    llvm_ir::EmitTuple(GetIrArrayFor(sort), destination_addresses, &b_,
+                       module_);
   }
   return Status::OK();
 }

tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.cc
@@ -99,8 +99,9 @@ void KeyValueSort(std::pair<Eigen::half, int64>* row_to_sort,
 }
 
 template <typename KeyType>
-void KeyValueSortImpl(KeyType* keys, int64 a, int64 b, int64 c, char* values,
-                      int32 values_primitive_type_size_in_bytes) {
+void KeyValueSortImpl(KeyType* keys, int64 a, int64 b, int64 c, char** values,
+                      int32 values_count,
+                      int32* values_primitive_type_size_in_bytes) {
   // High-level idea of the iteration/sorting logic:
   // Conceptually we have a 3-dimensional shape [a, b, c]. b corresponds to the
   // dimension to sort, c is the product of the more minor dimensions (set to 1

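For readers unfamiliar with the [a, b, c] normalization described in the comment above: the sort dimension b has its elements strided by c in memory, and the runtime visits a*c logical rows. The short standalone check below (illustrative only, not from the commit) uses the same base_offset formula that appears in the surrounding code and verifies it against the usual row-major indexing.

#include <cassert>
#include <cstdint>
#include <iostream>

int main() {
  // Shape [a, b, c] in row-major order; b is the dimension to sort.
  const int64_t a = 2, b = 3, c = 4;
  const int64_t sort_dimension_elements = b;  // elements per logical row
  const int64_t sort_dimension_offset = c;    // stride between row elements
  const int64_t num_rows = a * c;

  for (int64_t index = 0; index < num_rows; ++index) {
    // Same formula the runtime uses to locate the start of a row.
    const int64_t base_offset =
        index % sort_dimension_offset +
        (index - index % sort_dimension_offset) * sort_dimension_elements;
    std::cout << "row " << index << ":";
    for (int64_t i = 0; i < sort_dimension_elements; ++i) {
      const int64_t linear = base_offset + i * sort_dimension_offset;
      // Cross-check against the usual outer*b*c + i*c + minor indexing.
      const int64_t outer = index / c, minor = index % c;
      assert(linear == outer * b * c + i * c + minor);
      std::cout << " " << linear;
    }
    std::cout << "\n";
  }
  return 0;
}
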
@@ -129,7 +130,7 @@ void KeyValueSortImpl(KeyType* keys, int64 a, int64 b, int64 c, char* values,
         index % sort_dimension_offset +
         (index - index % sort_dimension_offset) * sort_dimension_elements;
     // TODO(b/26783907): We could define a custom iterator class that references
-    // both arrays. Then we could avoid the intermediate copy. However this
+    // all arrays. Then we could avoid the intermediate copy. However this
     // would become more complicated, and it is not clear if the benefit is high
     // enough.
     for (int64 i = 0; i < sort_dimension_elements; ++i) {

@@ -140,97 +141,109 @@ void KeyValueSortImpl(KeyType* keys, int64 a, int64 b, int64 c, char* values,
     for (int64 i = 0; i < sort_dimension_elements; ++i) {
       keys[base_offset + i * sort_dimension_offset] = row_to_sort[i].first;
     }
     if (values == nullptr) {
       continue;
     }
 
     // Reorder the values according to the order defined by the keys.
-    for (int64 i = 0; i < sort_dimension_elements; ++i) {
-      int64 memory_index =
-          (base_offset + row_to_sort[i].second * sort_dimension_offset) *
-          values_primitive_type_size_in_bytes;
-      reordered_values[i] = std::string(values + memory_index,
-                                        values_primitive_type_size_in_bytes);
-    }
-    for (int64 i = 0; i < sort_dimension_elements; ++i) {
-      int64 memory_index = (base_offset + i * sort_dimension_offset) *
-                           values_primitive_type_size_in_bytes;
-      memcpy(values + memory_index, reordered_values[i].c_str(),
-             values_primitive_type_size_in_bytes);
+    for (int32 idx = 0; idx < values_count; ++idx) {
+      for (int64 i = 0; i < sort_dimension_elements; ++i) {
+        int64 memory_index =
+            (base_offset + row_to_sort[i].second * sort_dimension_offset) *
+            values_primitive_type_size_in_bytes[idx];
+
+        reordered_values[i] =
+            std::string(values[idx] + memory_index,
+                        values_primitive_type_size_in_bytes[idx]);
+      }
+      for (int64 i = 0; i < sort_dimension_elements; ++i) {
+        int64 memory_index = (base_offset + i * sort_dimension_offset) *
+                             values_primitive_type_size_in_bytes[idx];
+        memcpy(values[idx] + memory_index, reordered_values[i].c_str(),
+               values_primitive_type_size_in_bytes[idx]);
+      }
     }
   }
 }
 }  // namespace
 
 TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortPRED(
-    bool* keys, int64 a, int64 b, int64 c, char* values,
-    int32 values_primitive_type_size_in_bytes) {
-  KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes);
+    bool* keys, int64 a, int64 b, int64 c, char** values, int32 values_count,
+    int32* values_primitive_type_size_in_bytes) {
+  KeyValueSortImpl(keys, a, b, c, values, values_count,
+                   values_primitive_type_size_in_bytes);
 }
 
 TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortS8(
-    int8* keys, int64 a, int64 b, int64 c, char* values,
-    int32 values_primitive_type_size_in_bytes) {
-  KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes);
+    int8* keys, int64 a, int64 b, int64 c, char** values, int32 values_count,
+    int32* values_primitive_type_size_in_bytes) {
+  KeyValueSortImpl(keys, a, b, c, values, values_count,
+                   values_primitive_type_size_in_bytes);
 }
 
 TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortU8(
-    uint8* keys, int64 a, int64 b, int64 c, char* values,
-    int32 values_primitive_type_size_in_bytes) {
-  KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes);
+    uint8* keys, int64 a, int64 b, int64 c, char** values, int32 values_count,
+    int32* values_primitive_type_size_in_bytes) {
+  KeyValueSortImpl(keys, a, b, c, values, values_count,
+                   values_primitive_type_size_in_bytes);
 }
 
 TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortS16(
-    int16* keys, int64 a, int64 b, int64 c, char* values,
-    int32 values_primitive_type_size_in_bytes) {
-  KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes);
+    int16* keys, int64 a, int64 b, int64 c, char** values, int32 values_count,
+    int32* values_primitive_type_size_in_bytes) {
+  KeyValueSortImpl(keys, a, b, c, values, values_count,
+                   values_primitive_type_size_in_bytes);
 }
 
 TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortU16(
-    uint16* keys, int64 a, int64 b, int64 c, char* values,
-    int32 values_primitive_type_size_in_bytes) {
-  KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes);
+    uint16* keys, int64 a, int64 b, int64 c, char** values, int32 values_count,
+    int32* values_primitive_type_size_in_bytes) {
+  KeyValueSortImpl(keys, a, b, c, values, values_count,
+                   values_primitive_type_size_in_bytes);
 }
 
 TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortF16(
-    Eigen::half* keys, int64 a, int64 b, int64 c, char* values,
-    int32 values_primitive_type_size_in_bytes) {
-  KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes);
+    Eigen::half* keys, int64 a, int64 b, int64 c, char** values,
+    int32 values_count, int32* values_primitive_type_size_in_bytes) {
+  KeyValueSortImpl(keys, a, b, c, values, values_count,
+                   values_primitive_type_size_in_bytes);
 }
 
 TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortS32(
-    int32* keys, int64 a, int64 b, int64 c, char* values,
-    int32 values_primitive_type_size_in_bytes) {
-  KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes);
+    int32* keys, int64 a, int64 b, int64 c, char** values, int32 values_count,
+    int32* values_primitive_type_size_in_bytes) {
+  KeyValueSortImpl(keys, a, b, c, values, values_count,
+                   values_primitive_type_size_in_bytes);
 }
 
 TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortU32(
-    uint32* keys, int64 a, int64 b, int64 c, char* values,
-    int32 values_primitive_type_size_in_bytes) {
-  KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes);
+    uint32* keys, int64 a, int64 b, int64 c, char** values, int32 values_count,
+    int32* values_primitive_type_size_in_bytes) {
+  KeyValueSortImpl(keys, a, b, c, values, values_count,
+                   values_primitive_type_size_in_bytes);
 }
 
 TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortF32(
-    float* keys, int64 a, int64 b, int64 c, char* values,
-    int32 values_primitive_type_size_in_bytes) {
-  KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes);
+    float* keys, int64 a, int64 b, int64 c, char** values, int32 values_count,
+    int32* values_primitive_type_size_in_bytes) {
+  KeyValueSortImpl(keys, a, b, c, values, values_count,
+                   values_primitive_type_size_in_bytes);
 }
 
 TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortS64(
-    int64* keys, int64 a, int64 b, int64 c, char* values,
-    int32 values_primitive_type_size_in_bytes) {
-  KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes);
+    int64* keys, int64 a, int64 b, int64 c, char** values, int32 values_count,
+    int32* values_primitive_type_size_in_bytes) {
+  KeyValueSortImpl(keys, a, b, c, values, values_count,
+                   values_primitive_type_size_in_bytes);
 }
 
 TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortU64(
-    uint64* keys, int64 a, int64 b, int64 c, char* values,
-    int32 values_primitive_type_size_in_bytes) {
-  KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes);
+    uint64* keys, int64 a, int64 b, int64 c, char** values, int32 values_count,
+    int32* values_primitive_type_size_in_bytes) {
+  KeyValueSortImpl(keys, a, b, c, values, values_count,
+                   values_primitive_type_size_in_bytes);
 }
 
 TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortF64(
-    double* keys, int64 a, int64 b, int64 c, char* values,
-    int32 values_primitive_type_size_in_bytes) {
-  KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes);
+    double* keys, int64 a, int64 b, int64 c, char** values, int32 values_count,
+    int32* values_primitive_type_size_in_bytes) {
+  KeyValueSortImpl(keys, a, b, c, values, values_count,
+                   values_primitive_type_size_in_bytes);
 }

tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h
@@ -22,67 +22,75 @@ limitations under the License.
 extern "C" {
 
 // 'keys' represents a 3-dimensional shape with dimensions [a, b, c]. The 'b'
-// dimension of 'keys' is sorted into ascending order. 'values' can be nullptr.
-// If 'values' is not nullptr, the elements in 'values' are reordered in such a
-// way that if the element at index 'i' in 'keys' was moved to index 'j', the
-// element at index 'i' in 'values' is also moved to index 'j' (which means that
-// the same elements correspond to each other as before).
+// dimension of 'keys' is sorted into ascending order. If 'values_count' is <=
+// 0, 'values' and 'values_primitive_type_size_in_bytes' can be nullptr.
+// If 'values_count' > 0, they contain exactly 'values_count' many elements.
+// Each element of 'values' also represents a 3-dimensional shape with
+// dimensions [a, b, c], and the size of the primitive type of the i-th shape
+// has exactly 'values_primitive_type_size_in_bytes[i]' bytes. The elements in
+// each 'values' shape are reordered in such a way that if the element at index
+// 'i' in 'keys' was moved to index 'j', the element at index 'i' in a 'values'
+// shape is also moved to index 'j' (which means that the same elements
+// correspond to each other as before).
 extern void __xla_cpu_runtime_KeyValueSortPRED(
     bool* keys, tensorflow::int64 a, tensorflow::int64 b, tensorflow::int64 c,
-    char* values, tensorflow::int32 values_primitive_type_size_in_bytes);
+    char** values, tensorflow::int32 values_count,
+    tensorflow::int32* values_primitive_type_size_in_bytes);
 
 extern void __xla_cpu_runtime_KeyValueSortS8(
     tensorflow::int8* keys, tensorflow::int64 a, tensorflow::int64 b,
-    tensorflow::int64 c, char* values,
-    tensorflow::int32 values_primitive_type_size_in_bytes);
+    tensorflow::int64 c, char** values, tensorflow::int32 values_count,
+    tensorflow::int32* values_primitive_type_size_in_bytes);
 
 extern void __xla_cpu_runtime_KeyValueSortU8(
     tensorflow::uint8* keys, tensorflow::int64 a, tensorflow::int64 b,
-    tensorflow::int64 c, char* values,
-    tensorflow::int32 values_primitive_type_size_in_bytes);
+    tensorflow::int64 c, char** values, tensorflow::int32 values_count,
+    tensorflow::int32* values_primitive_type_size_in_bytes);
 
 extern void __xla_cpu_runtime_KeyValueSortS16(
     tensorflow::int16* keys, tensorflow::int64 a, tensorflow::int64 b,
-    tensorflow::int64 c, char* values,
-    tensorflow::int32 values_primitive_type_size_in_bytes);
+    tensorflow::int64 c, char** values, tensorflow::int32 values_count,
+    tensorflow::int32* values_primitive_type_size_in_bytes);
 
 extern void __xla_cpu_runtime_KeyValueSortU16(
     tensorflow::uint16* keys, tensorflow::int64 a, tensorflow::int64 b,
-    tensorflow::int64 c, char* values,
-    tensorflow::int32 values_primitive_type_size_in_bytes);
+    tensorflow::int64 c, char** values, tensorflow::int32 values_count,
+    tensorflow::int32* values_primitive_type_size_in_bytes);
 
 extern void __xla_cpu_runtime_KeyValueSortF16(
     Eigen::half* keys, tensorflow::int64 a, tensorflow::int64 b,
-    tensorflow::int64 c, char* values,
-    tensorflow::int32 values_primitive_type_size_in_bytes);
+    tensorflow::int64 c, char** values, tensorflow::int32 values_count,
+    tensorflow::int32* values_primitive_type_size_in_bytes);
 
 extern void __xla_cpu_runtime_KeyValueSortS32(
     tensorflow::int32* keys, tensorflow::int64 a, tensorflow::int64 b,
-    tensorflow::int64 c, char* values,
-    tensorflow::int32 values_primitive_type_size_in_bytes);
+    tensorflow::int64 c, char** values, tensorflow::int32 values_count,
+    tensorflow::int32* values_primitive_type_size_in_bytes);
 
 extern void __xla_cpu_runtime_KeyValueSortU32(
     tensorflow::uint32* keys, tensorflow::int64 a, tensorflow::int64 b,
-    tensorflow::int64 c, char* values,
-    tensorflow::int32 values_primitive_type_size_in_bytes);
+    tensorflow::int64 c, char** values, tensorflow::int32 values_count,
+    tensorflow::int32* values_primitive_type_size_in_bytes);
 
 extern void __xla_cpu_runtime_KeyValueSortF32(
     float* keys, tensorflow::int64 a, tensorflow::int64 b, tensorflow::int64 c,
-    char* values, tensorflow::int32 values_primitive_type_size_in_bytes);
+    char** values, tensorflow::int32 values_count,
+    tensorflow::int32* values_primitive_type_size_in_bytes);
 
 extern void __xla_cpu_runtime_KeyValueSortS64(
     tensorflow::int64* keys, tensorflow::int64 a, tensorflow::int64 b,
-    tensorflow::int64 c, char* values,
-    tensorflow::int32 values_primitive_type_size_in_bytes);
+    tensorflow::int64 c, char** values, tensorflow::int32 values_count,
+    tensorflow::int32* values_primitive_type_size_in_bytes);
 
 extern void __xla_cpu_runtime_KeyValueSortU64(
     tensorflow::uint64* keys, tensorflow::int64 a, tensorflow::int64 b,
-    tensorflow::int64 c, char* values,
-    tensorflow::int32 values_primitive_type_size_in_bytes);
+    tensorflow::int64 c, char** values, tensorflow::int32 values_count,
+    tensorflow::int32* values_primitive_type_size_in_bytes);
 
 extern void __xla_cpu_runtime_KeyValueSortF64(
     double* keys, tensorflow::int64 a, tensorflow::int64 b, tensorflow::int64 c,
-    char* values, tensorflow::int32 values_primitive_type_size_in_bytes);
+    char** values, tensorflow::int32 values_count,
+    tensorflow::int32* values_primitive_type_size_in_bytes);
 }
 
 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_KEY_VALUE_SORT_H_
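
A hedged sketch of what a call site that honors this contract might look like. The callee below is a local stand-in that merely echoes its arguments; the real entry points are the __xla_cpu_runtime_KeyValueSort* functions declared above (which use tensorflow::int32/int64 rather than the fixed-width aliases here) and they sort in place.

#include <cstdint>
#include <iostream>

// Stand-in with the same parameter list as __xla_cpu_runtime_KeyValueSortF32;
// it only prints what it was handed instead of sorting.
void KeyValueSortF32StandIn(float* keys, int64_t a, int64_t b, int64_t c,
                            char** values, int32_t values_count,
                            int32_t* values_primitive_type_size_in_bytes) {
  std::cout << "keys shape [" << a << ", " << b << ", " << c << "], "
            << values_count << " value array(s)\n";
  for (int32_t i = 0; i < values_count; ++i) {
    std::cout << "  values[" << i << "]: "
              << values_primitive_type_size_in_bytes[i] << "-byte elements\n";
  }
}

int main() {
  float keys[6] = {3, 1, 2, 6, 5, 4};  // logical shape [1, 6, 1], sort along b
  int32_t v0[6] = {30, 10, 20, 60, 50, 40};
  double v1[6] = {3, 1, 2, 6, 5, 4};

  // Keys-only call: with values_count <= 0 the trailing pointers may be null.
  KeyValueSortF32StandIn(keys, 1, 6, 1, nullptr, 0, nullptr);

  // Keys plus two value arrays, each with its own primitive-type size.
  char* values[] = {reinterpret_cast<char*>(v0), reinterpret_cast<char*>(v1)};
  int32_t sizes[] = {sizeof(int32_t), sizeof(double)};
  KeyValueSortF32StandIn(keys, 1, 6, 1, values, 2, sizes);
  return 0;
}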