diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc
index a9febe891b5..d8878e622c0 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc
@@ -84,31 +84,8 @@ extern const char* const kReleaseOutfeedBufferAfterPopulationSymbolName =
     "__xla_cpu_runtime_ReleaseOutfeedBufferAfterPopulation";
 extern const char* const kParallelForkJoinSymbolName =
     "__xla_cpu_runtime_ParallelForkJoin";
-extern const char* const kKeyValueSortPREDSymbolName =
-    "__xla_cpu_runtime_KeyValueSortPRED";
-extern const char* const kKeyValueSortS8SymbolName =
-    "__xla_cpu_runtime_KeyValueSortS8";
-extern const char* const kKeyValueSortU8SymbolName =
-    "__xla_cpu_runtime_KeyValueSortU8";
-extern const char* const kKeyValueSortS16SymbolName =
-    "__xla_cpu_runtime_KeyValueSortS16";
-extern const char* const kKeyValueSortU16SymbolName =
-    "__xla_cpu_runtime_KeyValueSortU16";
-extern const char* const kKeyValueSortF16SymbolName =
-    "__xla_cpu_runtime_KeyValueSortF16";
-extern const char* const kKeyValueSortS32SymbolName =
-    "__xla_cpu_runtime_KeyValueSortS32";
-extern const char* const kKeyValueSortU32SymbolName =
-    "__xla_cpu_runtime_KeyValueSortU32";
-extern const char* const kKeyValueSortF32SymbolName =
-    "__xla_cpu_runtime_KeyValueSortF32";
-extern const char* const kKeyValueSortS64SymbolName =
-    "__xla_cpu_runtime_KeyValueSortS64";
-extern const char* const kKeyValueSortU64SymbolName =
-    "__xla_cpu_runtime_KeyValueSortU64";
-extern const char* const kKeyValueSortF64SymbolName =
-    "__xla_cpu_runtime_KeyValueSortF64";
-
+extern const char* const kKeyValueSortSymbolName =
+    "__xla_cpu_runtime_KeyValueSort";
 extern const char* const kXlaCpuRuntimeSymbolNamePrefix = "__xla_cpu_runtime_";
 }  // namespace runtime
 }  // namespace cpu
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.h b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h
index b2e760a224a..3a2b44d8c1a 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.h
+++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h
@@ -64,18 +64,7 @@ extern const char* const kReleaseInfeedBufferAfterDequeueSymbolName;
 extern const char* const kAcquireOutfeedBufferForPopulationSymbolName;
 extern const char* const kReleaseOutfeedBufferAfterPopulationSymbolName;
 extern const char* const kParallelForkJoinSymbolName;
-extern const char* const kKeyValueSortPREDSymbolName;
-extern const char* const kKeyValueSortS8SymbolName;
-extern const char* const kKeyValueSortU8SymbolName;
-extern const char* const kKeyValueSortS16SymbolName;
-extern const char* const kKeyValueSortU16SymbolName;
-extern const char* const kKeyValueSortF16SymbolName;
-extern const char* const kKeyValueSortS32SymbolName;
-extern const char* const kKeyValueSortU32SymbolName;
-extern const char* const kKeyValueSortF32SymbolName;
-extern const char* const kKeyValueSortS64SymbolName;
-extern const char* const kKeyValueSortU64SymbolName;
-extern const char* const kKeyValueSortF64SymbolName;
+extern const char* const kKeyValueSortSymbolName;
 
 // All symbol names for XLA CPU runtime functions need to start with this
 // prefix.
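For orientation before the emitter changes below, here are the two calling conventions side by side, both taken from the runtime_key_value_sort.h changes further down. Twelve monomorphic entry points, one per primitive type, collapse into a single type-erased one; the keys now travel through `values[0]` as raw bytes, and all type knowledge moves into the comparator supplied by the caller:

```cpp
// Before (one of twelve per-type symbols):
extern void __xla_cpu_runtime_KeyValueSortF32(
    float* keys, tensorflow::int64 a, tensorflow::int64 b, tensorflow::int64 c,
    char** values, tensorflow::int32 values_count,
    tensorflow::int32* values_primitive_type_size_in_bytes);

// After (the only remaining symbol; the comparator carries the element type):
extern void __xla_cpu_runtime_KeyValueSort(
    tensorflow::int64 a, tensorflow::int64 b, tensorflow::int64 c,
    char** values, tensorflow::int32 values_count,
    tensorflow::int32* values_primitive_type_size_in_bytes,
    bool (*less_than)(char*, char*));
```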
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
index f8a997045a6..0226e8275cb 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
@@ -495,6 +495,26 @@ Status IrEmitter::HandleSort(HloInstruction* hlo) {
   const HloSortInstruction* sort = Cast<HloSortInstruction>(hlo);
   TF_RETURN_IF_ERROR(EmitTargetAddressForOp(sort));
   Shape keys_shape = sort->keys()->shape();
+  PrimitiveType keys_type = keys_shape.element_type();
+  switch (keys_type) {
+    case PRED:
+    case S8:
+    case U8:
+    case S16:
+    case U16:
+    case F16:
+    case S32:
+    case U32:
+    case F32:
+    case S64:
+    case U64:
+    case F64:
+      break;
+    default:
+      return Unimplemented(
+          "Element type %s not supported in the Sort op on CPU.",
+          PrimitiveType_Name(keys_type));
+  }
   std::vector<llvm::Value*> destination_addresses(sort->operand_count());
   for (int64 i = 0; i < sort->operand_count(); ++i) {
     ShapeIndex shape_index =
@@ -542,105 +562,101 @@ Status IrEmitter::HandleSort(HloInstruction* hlo) {
     lower_dimensions *= normalized_keys_shape.dimensions(i);
   }
 
-  PrimitiveType keys_type = keys_shape.element_type();
-  const char* fn_name = nullptr;
-  llvm::Type* keys_native_type = nullptr;
-  switch (keys_type) {
-    case PRED:
-      fn_name = runtime::kKeyValueSortPREDSymbolName;
-      keys_native_type = b_.getInt8PtrTy();
-      break;
-    case S8:
-      fn_name = runtime::kKeyValueSortS8SymbolName;
-      keys_native_type = b_.getInt8PtrTy();
-      break;
-    case U8:
-      fn_name = runtime::kKeyValueSortU8SymbolName;
-      keys_native_type = b_.getInt8PtrTy();
-      break;
-    case S16:
-      fn_name = runtime::kKeyValueSortS16SymbolName;
-      keys_native_type = b_.getInt16Ty()->getPointerTo();
-      break;
-    case U16:
-      fn_name = runtime::kKeyValueSortU16SymbolName;
-      keys_native_type = b_.getInt16Ty()->getPointerTo();
-      break;
-    case F16:
-      fn_name = runtime::kKeyValueSortF16SymbolName;
-      keys_native_type = b_.getHalfTy()->getPointerTo();
-      break;
-    case S32:
-      fn_name = runtime::kKeyValueSortS32SymbolName;
-      keys_native_type = b_.getInt32Ty()->getPointerTo();
-      break;
-    case U32:
-      fn_name = runtime::kKeyValueSortU32SymbolName;
-      keys_native_type = b_.getInt32Ty()->getPointerTo();
-      break;
-    case F32:
-      fn_name = runtime::kKeyValueSortF32SymbolName;
-      keys_native_type = b_.getFloatTy()->getPointerTo();
-      break;
-    case S64:
-      fn_name = runtime::kKeyValueSortS64SymbolName;
-      keys_native_type = b_.getInt64Ty()->getPointerTo();
-      break;
-    case U64:
-      fn_name = runtime::kKeyValueSortU64SymbolName;
-      keys_native_type = b_.getInt64Ty()->getPointerTo();
-      break;
-    case F64:
-      fn_name = runtime::kKeyValueSortF64SymbolName;
-      keys_native_type = b_.getDoubleTy()->getPointerTo();
-      break;
-    default:
-      return Unimplemented(
-          "Element type %s not supported in the Sort op on CPU.",
-          PrimitiveType_Name(keys_type));
+  llvm::FunctionType* less_than_type = llvm::FunctionType::get(
+      b_.getInt1Ty(), {b_.getInt8PtrTy(), b_.getInt8PtrTy()},
+      /*isVarArg=*/false);
+  auto less_than_function = llvm_ir::CreateFunction(
+      less_than_type, llvm::GlobalValue::InternalLinkage,
+      /*enable_fast_math=*/false,
+      /*optimize_for_size=*/true, absl::StrCat(IrName(sort), "_comparator"),
+      module_);
+  // Emit the code for the less_than function.
+  {
+    llvm::IRBuilder<>::InsertPointGuard guard(b_);
+
+    auto* entry_bb =
+        llvm::BasicBlock::Create(b_.getContext(), "entry", less_than_function);
+
+    b_.SetInsertPoint(entry_bb);
+    auto keys_ir_type = llvm_ir::PrimitiveTypeToIrType(keys_type, module_);
+    CHECK_EQ(less_than_function->arg_size(), 2);
+    llvm::Value* keys_lhs_ptr = less_than_function->arg_begin();
+    keys_lhs_ptr = PointerCast(keys_lhs_ptr, keys_ir_type->getPointerTo());
+    llvm::Value* keys_rhs_ptr = less_than_function->arg_begin() + 1;
+    keys_rhs_ptr = PointerCast(keys_rhs_ptr, keys_ir_type->getPointerTo());
+
+    // TODO(b/122298745): Replace the custom compare logic with a call to the
+    // computation specified for the Sort op.
+    llvm::Value* keys_lhs = Load(keys_ir_type, keys_lhs_ptr);
+    llvm::Value* keys_rhs = Load(keys_ir_type, keys_rhs_ptr);
+    bool is_signed_comparison = true;
+    if (primitive_util::IsFloatingPointType(keys_type)) {
+      // We would like a total order of floating point numbers so that the
+      // sort has a predictable behavior in the presence of NaNs. Rather
+      // than using floating point comparison, we use the following trick:
+      // If f is a float, and
+      // x = bit_cast<int32>(f);
+      // y = x < 0 ? 0x7FFFFFFF - x : x;
+      // then y is ordered as an int32 such that finite values have the
+      // obvious order, -0 is ordered before 0, and -NaN and NaN appear at
+      // the beginning and end of the ordering.
+      auto k = b_.getInt(llvm::APInt::getSignedMaxValue(
+          keys_lhs->getType()->getPrimitiveSizeInBits()));
+      auto comparison_type = k->getType();
+      auto zero = llvm::ConstantInt::get(comparison_type, 0);
+      auto maybe_flip = [&](llvm::Value* v) {
+        return b_.CreateSelect(b_.CreateICmp(llvm::ICmpInst::ICMP_SLT, v, zero),
+                               b_.CreateSub(k, v), v);
+      };
+      keys_lhs = b_.CreateBitCast(keys_lhs, comparison_type);
+      keys_rhs = b_.CreateBitCast(keys_rhs, comparison_type);
+      keys_lhs = maybe_flip(keys_lhs);
+      keys_rhs = maybe_flip(keys_rhs);
+    } else if (!primitive_util::IsSignedIntegralType(keys_type)) {
+      is_signed_comparison = false;
+    }
+    llvm::Value* result =
+        b_.CreateICmp(is_signed_comparison ? llvm::ICmpInst::ICMP_SLT
+                                           : llvm::ICmpInst::ICMP_ULT,
+                      keys_lhs, keys_rhs);
+    llvm::ReturnInst::Create(b_.getContext(),
+                             /*retVal=*/result, entry_bb);
   }
 
   llvm::FunctionType* key_value_sort_type = llvm::FunctionType::get(
       b_.getVoidTy(),
-      {keys_native_type, b_.getInt64Ty(), b_.getInt64Ty(), b_.getInt64Ty(),
+      {b_.getInt64Ty(), b_.getInt64Ty(), b_.getInt64Ty(),
       b_.getInt8PtrTy()->getPointerTo(), b_.getInt32Ty(),
-       b_.getInt32Ty()->getPointerTo()},
+       b_.getInt32Ty()->getPointerTo(), less_than_function->getType()},
      /*isVarArg=*/false);
-  auto* key_value_sort_func = llvm::cast<llvm::Function>(
-      module_->getOrInsertFunction(fn_name, key_value_sort_type));
+  auto* key_value_sort_func =
+      llvm::cast<llvm::Function>(module_->getOrInsertFunction(
+          runtime::kKeyValueSortSymbolName, key_value_sort_type));
   key_value_sort_func->setCallingConv(llvm::CallingConv::C);
   key_value_sort_func->setDoesNotThrow();
-  llvm::Value* values;
-  llvm::Value* sizes;
-  if (sort->values_count() == 0) {
-    values = llvm::Constant::getNullValue(b_.getInt8PtrTy()->getPointerTo());
-    sizes = llvm::Constant::getNullValue(b_.getInt32Ty()->getPointerTo());
-  } else {
-    values = llvm_ir::EmitAllocaAtFunctionEntryWithCount(
-        b_.getInt8PtrTy(), b_.getInt32(sort->values_count()),
-        "cc_values_alloca", &b_);
-    sizes = llvm_ir::EmitAllocaAtFunctionEntryWithCount(
-        b_.getInt32Ty(), b_.getInt32(sort->values_count()), "cc_sizes_alloca",
-        &b_);
-    for (int64 i = 0; i < sort->values_count(); ++i) {
-      llvm::Value* value_as_i8ptr =
-          PointerCast(destination_addresses[i + 1], b_.getInt8PtrTy());
-      llvm::Value* slot_in_values_alloca =
-          ConstInBoundsGEP1_32(b_.getInt8PtrTy(), values, i);
-      Store(value_as_i8ptr, slot_in_values_alloca);
-      llvm::Value* slot_in_sizes_alloca =
-          ConstInBoundsGEP1_32(b_.getInt32Ty(), sizes, i);
-      llvm::Value* size = b_.getInt32(ShapeUtil::ByteSizeOfPrimitiveType(
-          sort->operand(i + 1)->shape().element_type()));
-      Store(size, slot_in_sizes_alloca);
-    }
+  llvm::Value* values = llvm_ir::EmitAllocaAtFunctionEntryWithCount(
+      b_.getInt8PtrTy(), b_.getInt32(sort->operand_count()), "cc_values_alloca",
+      &b_);
+  llvm::Value* sizes = llvm_ir::EmitAllocaAtFunctionEntryWithCount(
+      b_.getInt32Ty(), b_.getInt32(sort->operand_count()), "cc_sizes_alloca",
+      &b_);
+  for (int64 i = 0; i < sort->operand_count(); ++i) {
+    llvm::Value* value_as_i8ptr =
+        PointerCast(destination_addresses[i], b_.getInt8PtrTy());
+    llvm::Value* slot_in_values_alloca =
+        ConstInBoundsGEP1_32(b_.getInt8PtrTy(), values, i);
+    Store(value_as_i8ptr, slot_in_values_alloca);
+    llvm::Value* slot_in_sizes_alloca =
+        ConstInBoundsGEP1_32(b_.getInt32Ty(), sizes, i);
+    llvm::Value* size = b_.getInt32(ShapeUtil::ByteSizeOfPrimitiveType(
+        sort->operand(i)->shape().element_type()));
+    Store(size, slot_in_sizes_alloca);
   }
 
   Call(key_value_sort_func,
-       {PointerCast(destination_addresses[0], keys_native_type),
-        b_.getInt64(higher_dimensions), b_.getInt64(sort_dimension_elements),
+       {b_.getInt64(higher_dimensions), b_.getInt64(sort_dimension_elements),
         b_.getInt64(lower_dimensions), values,
-        b_.getInt32(sort->values_count()), sizes});
+        b_.getInt32(sort->operand_count()), sizes, less_than_function});
 
   if (sort->values_count() > 0) {
     llvm_ir::EmitTuple(GetIrArrayFor(sort), destination_addresses, &b_,
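The comparator emitted above reproduces in IR the total-order trick that previously lived in the runtime's templated Convert helper (removed below). Here is a standalone C++ sketch of the same mapping, illustrative only: it computes in uint32_t so the wrap-around that the emitted integer `sub` relies on stays well-defined, and it assumes the usual two's complement narrowing on the final cast.

```cpp
#include <cstdint>
#include <cstring>
#include <iostream>
#include <limits>

// y = x < 0 ? 0x7FFFFFFF - x : x, with x = bit_cast<int32>(f).
int32_t TotalOrderKey(float f) {
  uint32_t x;
  std::memcpy(&x, &f, sizeof(x));
  if (x & 0x80000000u) {  // sign bit set, i.e. bit_cast<int32>(f) < 0
    x = 0x7FFFFFFFu - x;  // wraps, like the emitted 'sub'
  }
  return static_cast<int32_t>(x);
}

int main() {
  const float nan = std::numeric_limits<float>::quiet_NaN();
  // Finite values keep their order, -0 sorts before +0, and +NaN lands at
  // the top end of the ordering (-NaN at the bottom).
  std::cout << (TotalOrderKey(-1.5f) < TotalOrderKey(-0.0f)) << "\n";  // 1
  std::cout << (TotalOrderKey(-0.0f) < TotalOrderKey(0.0f)) << "\n";   // 1
  std::cout << (TotalOrderKey(1.5f) < TotalOrderKey(nan)) << "\n";     // 1
}
```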
diff --git a/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.cc b/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.cc
index 722aa3120ef..a0667d0d9d1 100644
--- a/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.cc
+++ b/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.cc
@@ -15,12 +15,10 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h"
 
 #include <algorithm>
-#include <cmath>
 #include <cstring>
-#include <limits>
 #include <memory>
+#include <numeric>
 #include <string>
-#include <utility>
 
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 #include "tensorflow/core/platform/dynamic_annotations.h"
@@ -28,80 +26,14 @@ limitations under the License.
 #include "tensorflow/core/platform/types.h"
 
 namespace {
-using tensorflow::int16;
 using tensorflow::int32;
 using tensorflow::int64;
-using tensorflow::int8;
-using tensorflow::uint16;
-using tensorflow::uint32;
-using tensorflow::uint64;
-using tensorflow::uint8;
+}  // namespace
 
-template <typename KeyType>
-void KeyValueSort(std::pair<KeyType, int64>* row_to_sort, int64 num_elements) {
-  std::sort(row_to_sort, row_to_sort + num_elements);
-}
-
-// We would like a total order of floating point numbers so that the
-// sort has a predictable behavior in the presence of NaNs. Rather
-// than using floating point comparison, we use the following trick:
-// If f is a float, and
-// x = bit_cast<int32>(f);
-// y = x < 0 ? 0x7FFFFFFF - x : x;
-// then y is ordered as an int32 such that finite values have the
-// obvious order, -0 is ordered before 0, and -NaN and NaN appear at
-// the beginning and end of the ordering.
-template <typename CastType, typename KeyType>
-CastType Convert(KeyType value) {
-  CastType casted_value;
-  memcpy(&casted_value, &value, sizeof(CastType));
-  if (casted_value < 0) {
-    return static_cast<CastType>(std::numeric_limits<CastType>::max()) -
-           casted_value;
-  }
-  return casted_value;
-}
-
-template <typename CastType, typename KeyType>
-bool LessThan(KeyType lhs, KeyType rhs) {
-  return Convert<CastType, KeyType>(lhs) <
-         Convert<CastType, KeyType>(rhs);
-}
-
-template <>
-void KeyValueSort(std::pair<float, int64>* row_to_sort, int64 num_elements) {
-  std::stable_sort(row_to_sort, row_to_sort + num_elements,
-                   [](const std::pair<float, int64>& lhs,
-                      const std::pair<float, int64>& rhs) -> bool {
-                     return LessThan<int32>(lhs.first, rhs.first);
-                   });
-}
-
-template <>
-void KeyValueSort(std::pair<double, int64>* row_to_sort, int64 num_elements) {
-  std::stable_sort(row_to_sort, row_to_sort + num_elements,
-                   [](const std::pair<double, int64>& lhs,
-                      const std::pair<double, int64>& rhs) -> bool {
-                     return LessThan<int64>(lhs.first, rhs.first);
-                   });
-}
-
-template <>
-void KeyValueSort(std::pair<Eigen::half, int64>* row_to_sort,
-                  int64 num_elements) {
-  std::stable_sort(row_to_sort, row_to_sort + num_elements,
-                   [](const std::pair<Eigen::half, int64>& lhs,
-                      const std::pair<Eigen::half, int64>& rhs) -> bool {
-                     return LessThan<int32>(
-                         Eigen::half_impl::half_to_float(lhs.first),
-                         Eigen::half_impl::half_to_float(rhs.first));
-                   });
-}
-
-template <typename KeyType>
-void KeyValueSortImpl(KeyType* keys, int64 a, int64 b, int64 c, char** values,
-                      int32 values_count,
-                      int32* values_primitive_type_size_in_bytes) {
+TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSort(
+    int64 a, int64 b, int64 c, char** values, int32 values_count,
+    int32* values_primitive_type_size_in_bytes,
+    bool (*less_than)(char*, char*)) {
   // 'values' and 'values_primitive_type_size_in_bytes' are managed by the JIT
   // code, so msan can't tell they are initialized.
   TF_ANNOTATE_MEMORY_IS_INITIALIZED(values, values_count * sizeof(char*));
@@ -121,8 +53,8 @@ void KeyValueSortImpl(KeyType* keys, int64 a, int64 b, int64 c, char** values,
   int64 num_iteration_elements = a * c;
   int64 sort_dimension_offset = c;
 
-  std::unique_ptr<std::pair<KeyType, int64>[]> row_to_sort(
-      new std::pair<KeyType, int64>[sort_dimension_elements]);
+  std::unique_ptr<int64[]> indices(new int64[sort_dimension_elements]);
+  std::iota(indices.get(), indices.get() + sort_dimension_elements, 0);
   std::unique_ptr<std::string[]> reordered_values(
       new std::string[sort_dimension_elements]);
   for (int64 index = 0; index < num_iteration_elements; ++index) {
@@ -135,24 +67,22 @@ void KeyValueSortImpl(KeyType* keys, int64 a, int64 b, int64 c, char** values,
     int64 base_offset =
         index % sort_dimension_offset +
         (index - index % sort_dimension_offset) * sort_dimension_elements;
-    // TODO(b/26783907): We could define a custom iterator class that references
-    // all arrays. Then we could avoid the intermediate copy. However this
-    // would become more complicated, and it is not clear if the benefit is high
-    // enough.
-    for (int64 i = 0; i < sort_dimension_elements; ++i) {
-      row_to_sort[i] =
-          std::make_pair(keys[base_offset + i * sort_dimension_offset], i);
-    }
-    KeyValueSort(row_to_sort.get(), sort_dimension_elements);
-    for (int64 i = 0; i < sort_dimension_elements; ++i) {
-      keys[base_offset + i * sort_dimension_offset] = row_to_sort[i].first;
-    }
+    std::stable_sort(
+        indices.get(), indices.get() + sort_dimension_elements,
+        [&](int64 a, int64 b) {
+          int64 memory_index_lhs = (base_offset + a * sort_dimension_offset) *
+                                   values_primitive_type_size_in_bytes[0];
+          int64 memory_index_rhs = (base_offset + b * sort_dimension_offset) *
+                                   values_primitive_type_size_in_bytes[0];
+          return less_than(values[0] + memory_index_lhs,
+                           values[0] + memory_index_rhs);
+        });
 
-    // Reorder the values according to the order defined by the keys.
+    // Reorder the values according to the order defined by 'indices'.
     for (int32 idx = 0; idx < values_count; ++idx) {
       for (int64 i = 0; i < sort_dimension_elements; ++i) {
         int64 memory_index =
-            (base_offset + row_to_sort[i].second * sort_dimension_offset) *
+            (base_offset + indices[i] * sort_dimension_offset) *
             values_primitive_type_size_in_bytes[idx];
 
         reordered_values[i] =
@@ -168,88 +98,3 @@ void KeyValueSortImpl(KeyType* keys, int64 a, int64 b, int64 c, char** values,
     }
   }
 }
-}  // namespace
-
-TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortPRED(
-    bool* keys, int64 a, int64 b, int64 c, char** values, int32 values_count,
-    int32* values_primitive_type_size_in_bytes) {
-  KeyValueSortImpl(keys, a, b, c, values, values_count,
-                   values_primitive_type_size_in_bytes);
-}
-
-TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortS8(
-    int8* keys, int64 a, int64 b, int64 c, char** values, int32 values_count,
-    int32* values_primitive_type_size_in_bytes) {
-  KeyValueSortImpl(keys, a, b, c, values, values_count,
-                   values_primitive_type_size_in_bytes);
-}
-
-TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortU8(
-    uint8* keys, int64 a, int64 b, int64 c, char** values, int32 values_count,
-    int32* values_primitive_type_size_in_bytes) {
-  KeyValueSortImpl(keys, a, b, c, values, values_count,
-                   values_primitive_type_size_in_bytes);
-}
-
-TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortS16(
-    int16* keys, int64 a, int64 b, int64 c, char** values, int32 values_count,
-    int32* values_primitive_type_size_in_bytes) {
-  KeyValueSortImpl(keys, a, b, c, values, values_count,
-                   values_primitive_type_size_in_bytes);
-}
-
-TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortU16(
-    uint16* keys, int64 a, int64 b, int64 c, char** values, int32 values_count,
-    int32* values_primitive_type_size_in_bytes) {
-  KeyValueSortImpl(keys, a, b, c, values, values_count,
-                   values_primitive_type_size_in_bytes);
-}
-
-TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortF16(
-    Eigen::half* keys, int64 a, int64 b, int64 c, char** values,
-    int32 values_count, int32* values_primitive_type_size_in_bytes) {
-  KeyValueSortImpl(keys, a, b, c, values, values_count,
-                   values_primitive_type_size_in_bytes);
-}
-
-TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortS32(
-    int32* keys, int64 a, int64 b, int64 c, char** values, int32 values_count,
-    int32* values_primitive_type_size_in_bytes) {
-  KeyValueSortImpl(keys, a, b, c, values, values_count,
-                   values_primitive_type_size_in_bytes);
-}
-
-TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortU32(
-    uint32* keys, int64 a, int64 b, int64 c, char** values, int32 values_count,
-    int32* values_primitive_type_size_in_bytes) {
-  KeyValueSortImpl(keys, a, b, c, values, values_count,
-                   values_primitive_type_size_in_bytes);
-}
-
-TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortF32(
-    float* keys, int64 a, int64 b, int64 c, char** values, int32 values_count,
-    int32* values_primitive_type_size_in_bytes) {
-  KeyValueSortImpl(keys, a, b, c, values, values_count,
-                   values_primitive_type_size_in_bytes);
-}
-
-TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortS64(
-    int64* keys, int64 a, int64 b, int64 c, char** values, int32 values_count,
-    int32* values_primitive_type_size_in_bytes) {
-  KeyValueSortImpl(keys, a, b, c, values, values_count,
-                   values_primitive_type_size_in_bytes);
-}
-
-TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortU64(
-    uint64* keys, int64 a, int64 b, int64 c, char** values, int32 values_count,
-    int32* values_primitive_type_size_in_bytes) {
-  KeyValueSortImpl(keys, a, b, c, values, values_count,
-                   values_primitive_type_size_in_bytes);
-}
-
-TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortF64(
-    double* keys, int64 a, int64 b, int64 c, char** values, int32 values_count,
-    int32* values_primitive_type_size_in_bytes) {
-  KeyValueSortImpl(keys, a, b, c, values, values_count,
-                   values_primitive_type_size_in_bytes);
-}
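With the per-type instantiations gone, the whole runtime reduces to the permutation sort shown in the hunk above. Below is a compilable distillation with hypothetical names, specialized to a == c == 1 so the [a, b, c] indexing drops out: `values[0]` holds the keys, and every other operand buffer is reordered the same way. Note that the new path is stable for all key types, where the old integral path used `std::sort`.

```cpp
#include <algorithm>
#include <cstdint>
#include <cstring>
#include <numeric>
#include <vector>

// Sort a permutation of row indices via the type-erased comparator, then
// apply that permutation to every operand buffer.
void KeyValueSortRows(char** values, const int32_t* elem_sizes,
                      int32_t values_count, int64_t num_elements,
                      bool (*less_than)(char*, char*)) {
  std::vector<int64_t> indices(num_elements);
  std::iota(indices.begin(), indices.end(), 0);
  std::stable_sort(indices.begin(), indices.end(), [&](int64_t l, int64_t r) {
    return less_than(values[0] + l * elem_sizes[0],
                     values[0] + r * elem_sizes[0]);
  });
  for (int32_t op = 0; op < values_count; ++op) {
    std::vector<char> tmp(num_elements * elem_sizes[op]);
    for (int64_t i = 0; i < num_elements; ++i) {
      std::memcpy(tmp.data() + i * elem_sizes[op],
                  values[op] + indices[i] * elem_sizes[op], elem_sizes[op]);
    }
    std::memcpy(values[op], tmp.data(), tmp.size());
  }
}
```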
diff --git a/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h b/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h
index 78210993869..5460af3485b 100644
--- a/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h
+++ b/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h
@@ -21,76 +21,19 @@ limitations under the License.
 
 extern "C" {
 
-// 'keys' represents a 3-dimensional shape with dimensions [a, b, c]. The 'b'
-// dimension of 'keys' is sorted into ascending order. If 'values_count' is <=
-// 0, 'values' and 'values_primitive_type_size_in_bytes' can be nullptr.
-// If 'values_count' > 0, they contain exactly 'values_count' many elements.
-// Each element of 'values' also represents a 3-dimensional shape with
-// dimensions [a, b, c], and the size of the primitive type of the i-th shape
-// has exactly 'values_primitive_type_size_in_bytes[i]' bytes. The elements in
-// each 'values' shape are reordered in such a way that if the element at index
-// 'i' in 'keys' was moved to index 'j', the element at index 'i' in a 'values'
-// shape is also moved to index 'j' (which means that the same elements
-// correspond to each other as before).
-extern void __xla_cpu_runtime_KeyValueSortPRED(
-    bool* keys, tensorflow::int64 a, tensorflow::int64 b, tensorflow::int64 c,
+// Each entry in 'values' represents a 3-dimensional shape with dimensions
+// [a, b, c]. The 'b' dimension of the first shape is sorted into ascending
+// order according to the results of comparisons using the provided 'less_than'
+// function. 'values_count' must be > 0 and specifies the number of entries in
+// 'values' and 'values_primitive_type_size_in_bytes'. The size of the primitive
+// type of the i-th shape has exactly 'values_primitive_type_size_in_bytes[i]'
+// bytes. The elements in each 'values' shape are reordered in the same way
+// according to the comparisons using the first shape.
+extern void __xla_cpu_runtime_KeyValueSort(
+    tensorflow::int64 a, tensorflow::int64 b, tensorflow::int64 c,
     char** values, tensorflow::int32 values_count,
-    tensorflow::int32* values_primitive_type_size_in_bytes);
-
-extern void __xla_cpu_runtime_KeyValueSortS8(
-    tensorflow::int8* keys, tensorflow::int64 a, tensorflow::int64 b,
-    tensorflow::int64 c, char** values, tensorflow::int32 values_count,
-    tensorflow::int32* values_primitive_type_size_in_bytes);
-
-extern void __xla_cpu_runtime_KeyValueSortU8(
-    tensorflow::uint8* keys, tensorflow::int64 a, tensorflow::int64 b,
-    tensorflow::int64 c, char** values, tensorflow::int32 values_count,
-    tensorflow::int32* values_primitive_type_size_in_bytes);
-
-extern void __xla_cpu_runtime_KeyValueSortS16(
-    tensorflow::int16* keys, tensorflow::int64 a, tensorflow::int64 b,
-    tensorflow::int64 c, char** values, tensorflow::int32 values_count,
-    tensorflow::int32* values_primitive_type_size_in_bytes);
-
-extern void __xla_cpu_runtime_KeyValueSortU16(
-    tensorflow::uint16* keys, tensorflow::int64 a, tensorflow::int64 b,
-    tensorflow::int64 c, char** values, tensorflow::int32 values_count,
-    tensorflow::int32* values_primitive_type_size_in_bytes);
-
-extern void __xla_cpu_runtime_KeyValueSortF16(
-    Eigen::half* keys, tensorflow::int64 a, tensorflow::int64 b,
-    tensorflow::int64 c, char** values, tensorflow::int32 values_count,
-    tensorflow::int32* values_primitive_type_size_in_bytes);
-
-extern void __xla_cpu_runtime_KeyValueSortS32(
-    tensorflow::int32* keys, tensorflow::int64 a, tensorflow::int64 b,
-    tensorflow::int64 c, char** values, tensorflow::int32 values_count,
-    tensorflow::int32* values_primitive_type_size_in_bytes);
-
-extern void __xla_cpu_runtime_KeyValueSortU32(
-    tensorflow::uint32* keys, tensorflow::int64 a, tensorflow::int64 b,
-    tensorflow::int64 c, char** values, tensorflow::int32 values_count,
-    tensorflow::int32* values_primitive_type_size_in_bytes);
-
-extern void __xla_cpu_runtime_KeyValueSortF32(
-    float* keys, tensorflow::int64 a, tensorflow::int64 b, tensorflow::int64 c,
-    char** values, tensorflow::int32 values_count,
-    tensorflow::int32* values_primitive_type_size_in_bytes);
-
-extern void __xla_cpu_runtime_KeyValueSortS64(
-    tensorflow::int64* keys, tensorflow::int64 a, tensorflow::int64 b,
-    tensorflow::int64 c, char** values, tensorflow::int32 values_count,
-    tensorflow::int32* values_primitive_type_size_in_bytes);
-
-extern void __xla_cpu_runtime_KeyValueSortU64(
-    tensorflow::uint64* keys, tensorflow::int64 a, tensorflow::int64 b,
-    tensorflow::int64 c, char** values, tensorflow::int32 values_count,
-    tensorflow::int32* values_primitive_type_size_in_bytes);
-
-extern void __xla_cpu_runtime_KeyValueSortF64(
-    double* keys, tensorflow::int64 a, tensorflow::int64 b, tensorflow::int64 c,
-    char** values, tensorflow::int32 values_count,
-    tensorflow::int32* values_primitive_type_size_in_bytes);
+    tensorflow::int32* values_primitive_type_size_in_bytes,
+    bool (*less_than)(char*, char*));
 }
 
 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_KEY_VALUE_SORT_H_
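A hypothetical host-side call against the declaration above, sorting a single F32 operand of shape [1, 4, 1]. The comparator here uses a plain `<` for brevity; the JIT-emitted one applies the bit-cast total-order trick instead.

```cpp
#include <cstring>

#include "tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h"

static bool F32LessThan(char* lhs, char* rhs) {
  float l, r;
  std::memcpy(&l, lhs, sizeof(float));
  std::memcpy(&r, rhs, sizeof(float));
  return l < r;
}

void Example() {
  float data[] = {3.0f, 1.0f, 2.0f, 0.5f};
  char* values[] = {reinterpret_cast<char*>(data)};
  tensorflow::int32 sizes[] = {4};  // sizeof(float)
  // a = c = 1, b = 4: sort the 'b' dimension of the only operand in place.
  __xla_cpu_runtime_KeyValueSort(/*a=*/1, /*b=*/4, /*c=*/1, values,
                                 /*values_count=*/1, sizes, &F32LessThan);
  // data is now {0.5f, 1.0f, 2.0f, 3.0f}.
}
```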
diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc
index 12ab6360c56..9c2685674fb 100644
--- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc
+++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc
@@ -240,18 +240,7 @@ bool RegisterKnownJITSymbols() {
   REGISTER_CPU_RUNTIME_SYMBOL(ParallelForkJoin);
   REGISTER_CPU_RUNTIME_SYMBOL(ReleaseInfeedBufferAfterDequeue);
   REGISTER_CPU_RUNTIME_SYMBOL(ReleaseOutfeedBufferAfterPopulation);
-  REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortPRED);
-  REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortS8);
-  REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortU8);
-  REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortS16);
-  REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortU16);
-  REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortF16);
-  REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortS32);
-  REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortU32);
-  REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortF32);
-  REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortS64);
-  REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortU64);
-  REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortF64);
+  REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSort);
 
   registry->Register("__gnu_f2h_ieee", reinterpret_cast<void*>(__gnu_f2h_ieee));
   registry->Register("__gnu_h2f_ieee", reinterpret_cast<void*>(__gnu_h2f_ieee));
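For context, the registration above is what lets the ORC JIT resolve the emitted call to `__xla_cpu_runtime_KeyValueSort` against the host-compiled function. Conceptually it maintains a name-to-address map; a sketch with a hypothetical registry shape (not the actual REGISTER_CPU_RUNTIME_SYMBOL expansion):

```cpp
#include <map>
#include <string>

#include "tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h"

// Hypothetical stand-in for the JIT's symbol registry: a single
// name-to-address entry now covers every key type, where twelve were
// registered before.
std::map<std::string, void*> known_symbols = {
    {"__xla_cpu_runtime_KeyValueSort",
     reinterpret_cast<void*>(&__xla_cpu_runtime_KeyValueSort)},
};
```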