[XLA:CPU] Allow C64 and C128 types in Sort().

These seem to have been omitted mostly as an oversight; the logic in Sort() doesn't seem to be data-type specific.

PiperOrigin-RevId: 311595522
Change-Id: I6264bbe6556a0823e8a88e2025c4886182aad6bf
This commit is contained in:
Peter Hawkins 2020-05-14 13:27:42 -07:00 committed by TensorFlower Gardener
parent 45d18ddb7e
commit 66769844a5

View File

@ -17,6 +17,7 @@ limitations under the License.
#include <stddef.h> #include <stddef.h>
#include <stdint.h> #include <stdint.h>
#include <algorithm> #include <algorithm>
#include <iterator> #include <iterator>
#include <limits> #include <limits>
@ -570,24 +571,8 @@ Status IrEmitter::HandleSort(HloInstruction* hlo) {
TF_RETURN_IF_ERROR(EmitTargetAddressForOp(sort)); TF_RETURN_IF_ERROR(EmitTargetAddressForOp(sort));
Shape keys_shape = sort->keys()->shape(); Shape keys_shape = sort->keys()->shape();
PrimitiveType keys_type = keys_shape.element_type(); PrimitiveType keys_type = keys_shape.element_type();
switch (keys_type) { if (!primitive_util::IsArrayType(keys_type)) {
case PRED: return Unimplemented("Element type %s not supported in the Sort op on CPU.",
case S8:
case U8:
case S16:
case U16:
case BF16:
case F16:
case S32:
case U32:
case F32:
case S64:
case U64:
case F64:
break;
default:
return Unimplemented(
"Element type %s not supported in the Sort op on CPU.",
PrimitiveType_Name(keys_type)); PrimitiveType_Name(keys_type));
} }
std::vector<llvm::Value*> destination_addresses(sort->operand_count()); std::vector<llvm::Value*> destination_addresses(sort->operand_count());