Support arbitrarily many values in KeyValueSort on the CPU backend.

PiperOrigin-RevId: 217398356
This commit is contained in:
Adrian Kuegel 2018-10-16 15:27:50 -07:00 committed by TensorFlower Gardener
parent 8c3d9ae5de
commit e4e19db364
3 changed files with 167 additions and 133 deletions

View File

@ -54,6 +54,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/cpu/simple_orc_jit.h"
#include "tensorflow/compiler/xla/service/elemental_ir_emitter.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h"
@ -493,53 +494,44 @@ Status IrEmitter::HandleOutfeed(HloInstruction* outfeed) {
return Status::OK();
}
Status IrEmitter::HandleSort(HloInstruction* sort) {
Status IrEmitter::HandleSort(HloInstruction* hlo) {
const HloSortInstruction* sort = Cast<HloSortInstruction>(hlo);
TF_RETURN_IF_ERROR(EmitTargetAddressForOp(sort));
auto keys = sort->operand(0);
auto values = sort->operand_count() > 1 ? sort->operand(1) : nullptr;
ShapeIndex keys_shape_index({});
ShapeIndex values_shape_index({});
if (values != nullptr) {
keys_shape_index = ShapeIndex({0});
values_shape_index = ShapeIndex({1});
}
auto keys_destination = GetAllocationSlice(*sort, keys_shape_index);
auto keys_destination_address =
EmitBufferPointer(keys_destination, keys->shape());
auto values_destination = GetAllocationSlice(*sort, values_shape_index);
llvm::Value* values_destination_address = nullptr;
Shape keys_shape = sort->keys()->shape();
std::vector<llvm::Value*> destination_addresses(sort->operand_count());
for (int64 i = 0; i < sort->operand_count(); ++i) {
ShapeIndex shape_index =
sort->values_count() > 0 ? ShapeIndex({i}) : ShapeIndex({});
const HloInstruction* operand = sort->operand(i);
// We assume that the layout of all involved operands and outputs is the
// same.
TF_RET_CHECK(
LayoutUtil::LayoutsInShapesEqual(keys_shape, operand->shape()));
TF_RET_CHECK(LayoutUtil::LayoutsInShapesEqual(
keys_shape, ShapeUtil::GetSubshape(sort->shape(), shape_index)));
// The sort is implemented in-place, therefore we first copy the operand
// buffer to the output buffer if they are not the same.
if (keys_destination != GetAllocationSlice(*keys)) {
int64 primitive_type_size =
ShapeUtil::ByteSizeOfPrimitiveType(keys->shape().element_type());
auto source_buffer = GetEmittedValueFor(keys);
int64 keys_size = ByteSizeOf(keys->shape());
MemCpy(keys_destination_address, /*DstAlign=*/primitive_type_size,
source_buffer,
/*SrcAlign=*/primitive_type_size, keys_size);
}
if (values != nullptr) {
values_destination_address =
EmitBufferPointer(values_destination, values->shape());
if (values_destination != GetAllocationSlice(*values)) {
// The sort is implemented in-place, therefore we first copy the operand
// buffer to the output buffer if they are not the same.
auto destination_buffer = GetAllocationSlice(*sort, shape_index);
destination_addresses[i] =
EmitBufferPointer(destination_buffer, operand->shape());
auto source_address = GetAllocationSlice(*operand);
if (destination_buffer != source_address) {
int64 primitive_type_size =
ShapeUtil::ByteSizeOfPrimitiveType(values->shape().element_type());
auto source_buffer = GetEmittedValueFor(values);
int64 values_size = ByteSizeOf(values->shape());
MemCpy(values_destination_address, /*DstAlign=*/primitive_type_size,
ShapeUtil::ByteSizeOfPrimitiveType(operand->shape().element_type());
auto source_buffer = GetEmittedValueFor(operand);
int64 size = ByteSizeOf(operand->shape());
MemCpy(destination_addresses[i], /*DstAlign=*/primitive_type_size,
source_buffer,
/*SrcAlign=*/primitive_type_size, values_size);
/*SrcAlign=*/primitive_type_size, size);
}
}
// Normalize the shape and the dimension to sort.
Shape normalized_keys_shape =
ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
keys->shape());
ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(keys_shape);
int64 physical_dimension_to_sort = LayoutUtil::MakeLogicalToPhysical(
keys->shape().layout())[sort->dimensions(0)];
keys_shape.layout())[sort->sort_dimension()];
int64 sort_dimension_elements =
normalized_keys_shape.dimensions(physical_dimension_to_sort);
@ -553,7 +545,7 @@ Status IrEmitter::HandleSort(HloInstruction* sort) {
lower_dimensions *= normalized_keys_shape.dimensions(i);
}
PrimitiveType keys_type = keys->shape().element_type();
PrimitiveType keys_type = keys_shape.element_type();
const char* fn_name = nullptr;
llvm::Type* keys_native_type = nullptr;
switch (keys_type) {
@ -614,28 +606,49 @@ Status IrEmitter::HandleSort(HloInstruction* sort) {
llvm::FunctionType* key_value_sort_type = llvm::FunctionType::get(
b_.getVoidTy(),
{keys_native_type, b_.getInt64Ty(), b_.getInt64Ty(), b_.getInt64Ty(),
b_.getInt8PtrTy(), b_.getInt32Ty()},
b_.getInt8PtrTy()->getPointerTo(), b_.getInt32Ty(),
b_.getInt32Ty()->getPointerTo()},
/*isVarArg=*/false);
auto* key_value_sort_func = llvm::cast<llvm::Function>(
module_->getOrInsertFunction(fn_name, key_value_sort_type));
key_value_sort_func->setCallingConv(llvm::CallingConv::C);
key_value_sort_func->setDoesNotThrow();
key_value_sort_func->setOnlyAccessesArgMemory();
Call(key_value_sort_func,
{PointerCast(keys_destination_address, keys_native_type),
b_.getInt64(higher_dimensions), b_.getInt64(sort_dimension_elements),
b_.getInt64(lower_dimensions),
values != nullptr
? PointerCast(values_destination_address, b_.getInt8PtrTy())
: llvm::Constant::getNullValue(b_.getInt8PtrTy()),
b_.getInt32(values != nullptr ? ShapeUtil::ByteSizeOfPrimitiveType(
values->shape().element_type())
: 0)});
llvm::Value* values;
llvm::Value* sizes;
if (sort->values_count() == 0) {
values = llvm::Constant::getNullValue(b_.getInt8PtrTy()->getPointerTo());
sizes = llvm::Constant::getNullValue(b_.getInt32Ty()->getPointerTo());
} else {
values = llvm_ir::EmitAllocaAtFunctionEntryWithCount(
b_.getInt8PtrTy(), b_.getInt32(sort->values_count()),
"cc_values_alloca", &b_);
sizes = llvm_ir::EmitAllocaAtFunctionEntryWithCount(
b_.getInt32Ty(), b_.getInt32(sort->values_count()), "cc_sizes_alloca",
&b_);
for (int64 i = 0; i < sort->values_count(); ++i) {
llvm::Value* value_as_i8ptr =
PointerCast(destination_addresses[i + 1], b_.getInt8PtrTy());
llvm::Value* slot_in_values_alloca =
ConstInBoundsGEP1_32(b_.getInt8PtrTy(), values, i);
Store(value_as_i8ptr, slot_in_values_alloca);
llvm::Value* slot_in_sizes_alloca =
ConstInBoundsGEP1_32(b_.getInt32Ty(), sizes, i);
llvm::Value* size = b_.getInt32(ShapeUtil::ByteSizeOfPrimitiveType(
sort->operand(i + 1)->shape().element_type()));
Store(size, slot_in_sizes_alloca);
}
}
if (values != nullptr) {
llvm_ir::EmitTuple(GetIrArrayFor(sort),
{keys_destination_address, values_destination_address},
&b_, module_);
Call(key_value_sort_func,
{PointerCast(destination_addresses[0], keys_native_type),
b_.getInt64(higher_dimensions), b_.getInt64(sort_dimension_elements),
b_.getInt64(lower_dimensions), values,
b_.getInt32(sort->values_count()), sizes});
if (sort->values_count() > 0) {
llvm_ir::EmitTuple(GetIrArrayFor(sort), destination_addresses, &b_,
module_);
}
return Status::OK();
}

View File

@ -99,8 +99,9 @@ void KeyValueSort(std::pair<Eigen::half, int64>* row_to_sort,
}
template <typename KeyType>
void KeyValueSortImpl(KeyType* keys, int64 a, int64 b, int64 c, char* values,
int32 values_primitive_type_size_in_bytes) {
void KeyValueSortImpl(KeyType* keys, int64 a, int64 b, int64 c, char** values,
int32 values_count,
int32* values_primitive_type_size_in_bytes) {
// High-level idea of the iteration/sorting logic:
// Conceptually we have a 3-dimensional shape [a, b, c]. b corresponds to the
// dimension to sort, c is the product of the more minor dimensions (set to 1
@ -129,7 +130,7 @@ void KeyValueSortImpl(KeyType* keys, int64 a, int64 b, int64 c, char* values,
index % sort_dimension_offset +
(index - index % sort_dimension_offset) * sort_dimension_elements;
// TODO(b/26783907): We could define a custom iterator class that references
// both arrays. Then we could avoid the intermediate copy. However this
// all arrays. Then we could avoid the intermediate copy. However this
// would become more complicated, and it is not clear if the benefit is high
// enough.
for (int64 i = 0; i < sort_dimension_elements; ++i) {
@ -140,97 +141,109 @@ void KeyValueSortImpl(KeyType* keys, int64 a, int64 b, int64 c, char* values,
for (int64 i = 0; i < sort_dimension_elements; ++i) {
keys[base_offset + i * sort_dimension_offset] = row_to_sort[i].first;
}
if (values == nullptr) {
continue;
}
// Reorder the values according to the order defined by the keys.
for (int64 i = 0; i < sort_dimension_elements; ++i) {
int64 memory_index =
(base_offset + row_to_sort[i].second * sort_dimension_offset) *
values_primitive_type_size_in_bytes;
for (int32 idx = 0; idx < values_count; ++idx) {
for (int64 i = 0; i < sort_dimension_elements; ++i) {
int64 memory_index =
(base_offset + row_to_sort[i].second * sort_dimension_offset) *
values_primitive_type_size_in_bytes[idx];
reordered_values[i] = std::string(values + memory_index,
values_primitive_type_size_in_bytes);
}
for (int64 i = 0; i < sort_dimension_elements; ++i) {
int64 memory_index = (base_offset + i * sort_dimension_offset) *
values_primitive_type_size_in_bytes;
memcpy(values + memory_index, reordered_values[i].c_str(),
values_primitive_type_size_in_bytes);
reordered_values[i] =
std::string(values[idx] + memory_index,
values_primitive_type_size_in_bytes[idx]);
}
for (int64 i = 0; i < sort_dimension_elements; ++i) {
int64 memory_index = (base_offset + i * sort_dimension_offset) *
values_primitive_type_size_in_bytes[idx];
memcpy(values[idx] + memory_index, reordered_values[i].c_str(),
values_primitive_type_size_in_bytes[idx]);
}
}
}
}
} // namespace
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortPRED(
bool* keys, int64 a, int64 b, int64 c, char* values,
int32 values_primitive_type_size_in_bytes) {
KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes);
bool* keys, int64 a, int64 b, int64 c, char** values, int32 values_count,
int32* values_primitive_type_size_in_bytes) {
KeyValueSortImpl(keys, a, b, c, values, values_count,
values_primitive_type_size_in_bytes);
}
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortS8(
int8* keys, int64 a, int64 b, int64 c, char* values,
int32 values_primitive_type_size_in_bytes) {
KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes);
int8* keys, int64 a, int64 b, int64 c, char** values, int32 values_count,
int32* values_primitive_type_size_in_bytes) {
KeyValueSortImpl(keys, a, b, c, values, values_count,
values_primitive_type_size_in_bytes);
}
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortU8(
uint8* keys, int64 a, int64 b, int64 c, char* values,
int32 values_primitive_type_size_in_bytes) {
KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes);
uint8* keys, int64 a, int64 b, int64 c, char** values, int32 values_count,
int32* values_primitive_type_size_in_bytes) {
KeyValueSortImpl(keys, a, b, c, values, values_count,
values_primitive_type_size_in_bytes);
}
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortS16(
int16* keys, int64 a, int64 b, int64 c, char* values,
int32 values_primitive_type_size_in_bytes) {
KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes);
int16* keys, int64 a, int64 b, int64 c, char** values, int32 values_count,
int32* values_primitive_type_size_in_bytes) {
KeyValueSortImpl(keys, a, b, c, values, values_count,
values_primitive_type_size_in_bytes);
}
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortU16(
uint16* keys, int64 a, int64 b, int64 c, char* values,
int32 values_primitive_type_size_in_bytes) {
KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes);
uint16* keys, int64 a, int64 b, int64 c, char** values, int32 values_count,
int32* values_primitive_type_size_in_bytes) {
KeyValueSortImpl(keys, a, b, c, values, values_count,
values_primitive_type_size_in_bytes);
}
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortF16(
Eigen::half* keys, int64 a, int64 b, int64 c, char* values,
int32 values_primitive_type_size_in_bytes) {
KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes);
Eigen::half* keys, int64 a, int64 b, int64 c, char** values,
int32 values_count, int32* values_primitive_type_size_in_bytes) {
KeyValueSortImpl(keys, a, b, c, values, values_count,
values_primitive_type_size_in_bytes);
}
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortS32(
int32* keys, int64 a, int64 b, int64 c, char* values,
int32 values_primitive_type_size_in_bytes) {
KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes);
int32* keys, int64 a, int64 b, int64 c, char** values, int32 values_count,
int32* values_primitive_type_size_in_bytes) {
KeyValueSortImpl(keys, a, b, c, values, values_count,
values_primitive_type_size_in_bytes);
}
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortU32(
uint32* keys, int64 a, int64 b, int64 c, char* values,
int32 values_primitive_type_size_in_bytes) {
KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes);
uint32* keys, int64 a, int64 b, int64 c, char** values, int32 values_count,
int32* values_primitive_type_size_in_bytes) {
KeyValueSortImpl(keys, a, b, c, values, values_count,
values_primitive_type_size_in_bytes);
}
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortF32(
float* keys, int64 a, int64 b, int64 c, char* values,
int32 values_primitive_type_size_in_bytes) {
KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes);
float* keys, int64 a, int64 b, int64 c, char** values, int32 values_count,
int32* values_primitive_type_size_in_bytes) {
KeyValueSortImpl(keys, a, b, c, values, values_count,
values_primitive_type_size_in_bytes);
}
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortS64(
int64* keys, int64 a, int64 b, int64 c, char* values,
int32 values_primitive_type_size_in_bytes) {
KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes);
int64* keys, int64 a, int64 b, int64 c, char** values, int32 values_count,
int32* values_primitive_type_size_in_bytes) {
KeyValueSortImpl(keys, a, b, c, values, values_count,
values_primitive_type_size_in_bytes);
}
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortU64(
uint64* keys, int64 a, int64 b, int64 c, char* values,
int32 values_primitive_type_size_in_bytes) {
KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes);
uint64* keys, int64 a, int64 b, int64 c, char** values, int32 values_count,
int32* values_primitive_type_size_in_bytes) {
KeyValueSortImpl(keys, a, b, c, values, values_count,
values_primitive_type_size_in_bytes);
}
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortF64(
double* keys, int64 a, int64 b, int64 c, char* values,
int32 values_primitive_type_size_in_bytes) {
KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes);
double* keys, int64 a, int64 b, int64 c, char** values, int32 values_count,
int32* values_primitive_type_size_in_bytes) {
KeyValueSortImpl(keys, a, b, c, values, values_count,
values_primitive_type_size_in_bytes);
}

View File

@ -22,67 +22,75 @@ limitations under the License.
extern "C" {
// 'keys' represents a 3-dimensional shape with dimensions [a, b, c]. The 'b'
// dimension of 'keys' is sorted into ascending order. 'values' can be nullptr.
// If 'values' is not nullptr, the elements in 'values' are reordered in such a
// way that if the element at index 'i' in 'keys' was moved to index 'j', the
// element at index 'i' in 'values' is also moved to index 'j' (which means that
// the same elements correspond to each other as before).
// dimension of 'keys' is sorted into ascending order. If 'values_count' is <=
// 0, 'values' and 'values_primitive_type_size_in_bytes' can be nullptr.
// If 'values_count' > 0, both contain exactly 'values_count' elements.
// Each element of 'values' also represents a 3-dimensional shape with
// dimensions [a, b, c], and the primitive type of the i-th shape is exactly
// 'values_primitive_type_size_in_bytes[i]' bytes in size. The elements in
// each 'values' shape are reordered in such a way that if the element at index
// 'i' in 'keys' was moved to index 'j', the element at index 'i' in a 'values'
// shape is also moved to index 'j' (which means that the same elements
// correspond to each other as before).
extern void __xla_cpu_runtime_KeyValueSortPRED(
bool* keys, tensorflow::int64 a, tensorflow::int64 b, tensorflow::int64 c,
char* values, tensorflow::int32 values_primitive_type_size_in_bytes);
char** values, tensorflow::int32 values_count,
tensorflow::int32* values_primitive_type_size_in_bytes);
extern void __xla_cpu_runtime_KeyValueSortS8(
tensorflow::int8* keys, tensorflow::int64 a, tensorflow::int64 b,
tensorflow::int64 c, char* values,
tensorflow::int32 values_primitive_type_size_in_bytes);
tensorflow::int64 c, char** values, tensorflow::int32 values_count,
tensorflow::int32* values_primitive_type_size_in_bytes);
extern void __xla_cpu_runtime_KeyValueSortU8(
tensorflow::uint8* keys, tensorflow::int64 a, tensorflow::int64 b,
tensorflow::int64 c, char* values,
tensorflow::int32 values_primitive_type_size_in_bytes);
tensorflow::int64 c, char** values, tensorflow::int32 values_count,
tensorflow::int32* values_primitive_type_size_in_bytes);
extern void __xla_cpu_runtime_KeyValueSortS16(
tensorflow::int16* keys, tensorflow::int64 a, tensorflow::int64 b,
tensorflow::int64 c, char* values,
tensorflow::int32 values_primitive_type_size_in_bytes);
tensorflow::int64 c, char** values, tensorflow::int32 values_count,
tensorflow::int32* values_primitive_type_size_in_bytes);
extern void __xla_cpu_runtime_KeyValueSortU16(
tensorflow::uint16* keys, tensorflow::int64 a, tensorflow::int64 b,
tensorflow::int64 c, char* values,
tensorflow::int32 values_primitive_type_size_in_bytes);
tensorflow::int64 c, char** values, tensorflow::int32 values_count,
tensorflow::int32* values_primitive_type_size_in_bytes);
extern void __xla_cpu_runtime_KeyValueSortF16(
Eigen::half* keys, tensorflow::int64 a, tensorflow::int64 b,
tensorflow::int64 c, char* values,
tensorflow::int32 values_primitive_type_size_in_bytes);
tensorflow::int64 c, char** values, tensorflow::int32 values_count,
tensorflow::int32* values_primitive_type_size_in_bytes);
extern void __xla_cpu_runtime_KeyValueSortS32(
tensorflow::int32* keys, tensorflow::int64 a, tensorflow::int64 b,
tensorflow::int64 c, char* values,
tensorflow::int32 values_primitive_type_size_in_bytes);
tensorflow::int64 c, char** values, tensorflow::int32 values_count,
tensorflow::int32* values_primitive_type_size_in_bytes);
extern void __xla_cpu_runtime_KeyValueSortU32(
tensorflow::uint32* keys, tensorflow::int64 a, tensorflow::int64 b,
tensorflow::int64 c, char* values,
tensorflow::int32 values_primitive_type_size_in_bytes);
tensorflow::int64 c, char** values, tensorflow::int32 values_count,
tensorflow::int32* values_primitive_type_size_in_bytes);
extern void __xla_cpu_runtime_KeyValueSortF32(
float* keys, tensorflow::int64 a, tensorflow::int64 b, tensorflow::int64 c,
char* values, tensorflow::int32 values_primitive_type_size_in_bytes);
char** values, tensorflow::int32 values_count,
tensorflow::int32* values_primitive_type_size_in_bytes);
extern void __xla_cpu_runtime_KeyValueSortS64(
tensorflow::int64* keys, tensorflow::int64 a, tensorflow::int64 b,
tensorflow::int64 c, char* values,
tensorflow::int32 values_primitive_type_size_in_bytes);
tensorflow::int64 c, char** values, tensorflow::int32 values_count,
tensorflow::int32* values_primitive_type_size_in_bytes);
extern void __xla_cpu_runtime_KeyValueSortU64(
tensorflow::uint64* keys, tensorflow::int64 a, tensorflow::int64 b,
tensorflow::int64 c, char* values,
tensorflow::int32 values_primitive_type_size_in_bytes);
tensorflow::int64 c, char** values, tensorflow::int32 values_count,
tensorflow::int32* values_primitive_type_size_in_bytes);
extern void __xla_cpu_runtime_KeyValueSortF64(
double* keys, tensorflow::int64 a, tensorflow::int64 b, tensorflow::int64 c,
char* values, tensorflow::int32 values_primitive_type_size_in_bytes);
char** values, tensorflow::int32 values_count,
tensorflow::int32* values_primitive_type_size_in_bytes);
}
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_KEY_VALUE_SORT_H_