Adding total-order comparison support in proto and HloInstruction.

Specifically a comparison type attribute is added to Hlo proto so that total
order comparison can be explicitly specified. A comparison expander pass is
added to all compilers to expand total order comparison into equivalent
implementations through type conversion.

PiperOrigin-RevId: 325820826
Change-Id: I7beceb2f751ddc0be7c6b7a74037e562e7580b62
This commit is contained in:
A. Unique TensorFlower 2020-08-10 09:32:46 -07:00 committed by TensorFlower Gardener
parent 9cd6a52394
commit fd87e24980
27 changed files with 472 additions and 149 deletions

View File

@ -55,9 +55,13 @@ xla_test(
cc_library(
name = "comparators",
srcs = ["comparators.cc"],
hdrs = ["comparators.h"],
hdrs = [
"comparators.h",
"//tensorflow/compiler/xla:literal_util",
],
deps = [
":constants",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto_cc",

View File

@ -32,85 +32,13 @@ limitations under the License.
namespace xla {
namespace {
using XlaOpGenerator = XlaOp (*)(XlaOp, XlaOp, absl::Span<const int64>);
XlaOp BitcastConvertFloatingPointToIntegral(const XlaOp& value,
int64 bit_width) {
PrimitiveType signed_type;
PrimitiveType unsigned_type;
XlaOp max_value;
switch (bit_width) {
case 16:
max_value =
ConstantR0(value.builder(),
static_cast<uint16>(std::numeric_limits<int16>::max()));
signed_type = S16;
unsigned_type = U16;
break;
case 32:
max_value =
ConstantR0(value.builder(),
static_cast<uint32>(std::numeric_limits<int32>::max()));
signed_type = S32;
unsigned_type = U32;
break;
case 64:
max_value =
ConstantR0(value.builder(),
static_cast<uint64>(std::numeric_limits<int64>::max()));
signed_type = S64;
unsigned_type = U64;
break;
default:
return value.builder()->ReportError(
InvalidArgument("Invalid bit width %lld for Comparator floating "
"point parameter.",
bit_width));
}
// Switch from a floating point value to a integer value in such a way that
// when using the integer value to compare, we get the same result for normal
// values, and -Nan is treated as the smallest value, and Nan is treated as
// the largest value.
// If f is a float, and
// x = bit_cast<int32>(f);
// y = x < 0 ? numeric_limits<int32>::max() - x : x;
// then y is ordered as an int32 such that finite values have the obvious
// order, -0 is ordered before 0, and -NaN and NaN appear at the beginning
// and end of the ordering.
// Note that in order to avoid -x to overflow, we calculate
// numeric_limits<int32>::max() - x as unsigned, and then convert back to
// signed.
auto signed_value = BitcastConvertType(value, signed_type);
auto unsigned_value = BitcastConvertType(value, unsigned_type);
auto flipped_value =
BitcastConvertType(Sub(max_value, unsigned_value), signed_type);
auto is_negative = Lt(signed_value, Zero(value.builder(), signed_type));
return Select(is_negative, flipped_value, signed_value);
}
void ConvertFloatingPoint(const PrimitiveType& operand_type, XlaOp* lhs_param,
XlaOp* rhs_param) {
if (primitive_util::IsFloatingPointType(operand_type)) {
PrimitiveType compare_type = operand_type;
// Special-case handling for BF16. We currently do not support direct
// comparisons with BF16, so we convert to F32 and then use the F32
// comparison logic.
if (compare_type == BF16) {
compare_type = F32;
*lhs_param = ConvertElementType(*lhs_param, F32);
*rhs_param = ConvertElementType(*rhs_param, F32);
}
int64 bit_width = primitive_util::BitWidth(compare_type);
*lhs_param = BitcastConvertFloatingPointToIntegral(*lhs_param, bit_width);
*rhs_param = BitcastConvertFloatingPointToIntegral(*rhs_param, bit_width);
}
}
using XlaCompareOp = XlaOp (*)(XlaOp, XlaOp, absl::Span<const int64>);
XlaComputation CreateScalarComparisonComputation(
const string& name, const std::vector<PrimitiveType>& operand_types,
XlaBuilder* builder, XlaOpGenerator generator) {
XlaBuilder* builder, XlaCompareOp generator) {
CHECK_NE(operand_types.size(), 0);
std::vector<absl::optional<XlaOpGenerator>> generators(operand_types.size());
std::vector<absl::optional<XlaCompareOp>> generators(operand_types.size());
generators[0] = generator;
return CreateScalarComparisonComputation(name, operand_types, generators,
builder);
@ -119,7 +47,7 @@ XlaComputation CreateScalarComparisonComputation(
XlaComputation CreateScalarComparisonComputation(
const string& name, const std::vector<PrimitiveType>& operand_types,
const std::vector<absl::optional<XlaOpGenerator>>& generators,
const std::vector<absl::optional<XlaCompareOp>>& generators,
XlaBuilder* builder) {
// Create a default computation where we compare only the first two
// parameters of type 'operand_types[0]'.
@ -146,7 +74,6 @@ XlaComputation CreateScalarComparisonComputation(
absl::StrCat("p.", parameter_count, ".lhs"));
auto rhs_param = Parameter(b.get(), parameter_count * 2 + 1, scalar_shape,
absl::StrCat("p.", parameter_count, ".rhs"));
ConvertFloatingPoint(operand_type, &lhs_param, &rhs_param);
lhs_params.emplace_back(lhs_param);
rhs_params.emplace_back(rhs_param);
if (generators[parameter_count].has_value()) {
@ -169,7 +96,8 @@ XlaComputation CreateScalarComparisonComputation(
generators[i].value()(lhs_params[i], rhs_params[i], {}),
result);
if (i != last_generator_index) {
param_equal = And(param_equal, Eq(lhs_params[i], rhs_params[i]));
param_equal =
And(param_equal, EqTotalOrder(lhs_params[i], rhs_params[i]));
}
}
}
@ -181,14 +109,14 @@ XlaComputation CreateScalarComparisonComputation(
XlaComputation CreateScalarLtComputation(
const std::vector<PrimitiveType>& operand_types, XlaBuilder* builder) {
return CreateScalarComparisonComputation("compare-less-than", operand_types,
builder, Lt);
builder, LtTotalOrder);
}
// Creates a scalar greater-than computation and returns it.
XlaComputation CreateScalarGtComputation(
const std::vector<PrimitiveType>& operand_types, XlaBuilder* builder) {
return CreateScalarComparisonComputation("compare-greater-than",
operand_types, builder, Gt);
return CreateScalarComparisonComputation(
"compare-greater-than", operand_types, builder, GtTotalOrder);
}
} // namespace xla

View File

@ -43,14 +43,13 @@ XlaComputation CreateScalarGtComputation(
const std::vector<PrimitiveType>& operand_types, XlaBuilder* builder);
// Creates a scalar comparison computation and returns it. This function takes
// an std::vector<absl::optional<XlaOpGenerator>> and compare the operands
// where the generator isn't nullopt with the specified comparator
// at that location.
// a vector of comparator functions to compare the operands where the function
// isn't nullopt with the specified comparator at that location.
XlaComputation CreateScalarComparisonComputation(
const string& name, const std::vector<PrimitiveType>& operand_types,
const std::vector<
absl::optional<XlaOp (*)(XlaOp, XlaOp, absl::Span<const int64>)>>&
generators,
comparators,
XlaBuilder* builder);
} // namespace xla

View File

@ -577,7 +577,8 @@ XlaOp XlaBuilder::UnaryOp(HloOpcode unop, XlaOp operand) {
XlaOp XlaBuilder::BinaryOp(HloOpcode binop, XlaOp lhs, XlaOp rhs,
absl::Span<const int64> broadcast_dimensions,
absl::optional<ComparisonDirection> direction) {
absl::optional<ComparisonDirection> direction,
absl::optional<Comparison::Type> type) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs));
TF_ASSIGN_OR_RETURN(const Shape* rhs_shape, GetShapePtr(rhs));
@ -635,7 +636,11 @@ XlaOp XlaBuilder::BinaryOp(HloOpcode binop, XlaOp lhs, XlaOp rhs,
return InvalidArgument(
"kCompare expects a ComparisonDirection, but none provided.");
}
return Compare(shape, updated_lhs, updated_rhs, *direction);
if (type == absl::nullopt) {
return Compare(shape, updated_lhs, updated_rhs, *direction);
} else {
return Compare(shape, updated_lhs, updated_rhs, *direction, *type);
}
}
if (direction.has_value()) {
@ -658,8 +663,16 @@ XlaOp XlaBuilder::BinaryOpNoBroadcast(HloOpcode binop, const Shape& shape,
StatusOr<XlaOp> XlaBuilder::Compare(const Shape& shape, XlaOp lhs, XlaOp rhs,
ComparisonDirection direction) {
return Compare(shape, lhs, rhs, direction,
Comparison::DefaultComparisonType(shape.element_type()));
}
StatusOr<XlaOp> XlaBuilder::Compare(const Shape& shape, XlaOp lhs, XlaOp rhs,
ComparisonDirection direction,
Comparison::Type type) {
HloInstructionProto instr;
instr.set_comparison_direction(ComparisonDirectionToString(direction));
instr.set_comparison_type(ComparisonTypeToString(type));
*instr.mutable_shape() = shape.ToProto();
return AddInstruction(std::move(instr), HloOpcode::kCompare, {lhs, rhs});
}
@ -3512,31 +3525,71 @@ XlaOp Eq(const XlaOp lhs, const XlaOp rhs,
return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kEq);
}
XlaOp EqTotalOrder(const XlaOp lhs, const XlaOp rhs,
absl::Span<const int64> broadcast_dimensions) {
auto compare_type = Comparison::Type::kFloatTotalOrder;
return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kEq,
compare_type);
}
XlaOp Ne(const XlaOp lhs, const XlaOp rhs,
absl::Span<const int64> broadcast_dimensions) {
return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kNe);
}
XlaOp NeTotalOrder(const XlaOp lhs, const XlaOp rhs,
absl::Span<const int64> broadcast_dimensions) {
auto compare_type = Comparison::Type::kFloatTotalOrder;
return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kNe,
compare_type);
}
XlaOp Ge(const XlaOp lhs, const XlaOp rhs,
absl::Span<const int64> broadcast_dimensions) {
return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kGe);
}
XlaOp GeTotalOrder(const XlaOp lhs, const XlaOp rhs,
absl::Span<const int64> broadcast_dimensions) {
auto compare_type = Comparison::Type::kFloatTotalOrder;
return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kGe,
compare_type);
}
XlaOp Gt(const XlaOp lhs, const XlaOp rhs,
absl::Span<const int64> broadcast_dimensions) {
return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kGt);
}
XlaOp GtTotalOrder(const XlaOp lhs, const XlaOp rhs,
absl::Span<const int64> broadcast_dimensions) {
auto compare_type = Comparison::Type::kFloatTotalOrder;
return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kGt,
compare_type);
}
XlaOp Le(const XlaOp lhs, const XlaOp rhs,
absl::Span<const int64> broadcast_dimensions) {
return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kLe);
}
XlaOp LeTotalOrder(const XlaOp lhs, const XlaOp rhs,
absl::Span<const int64> broadcast_dimensions) {
auto compare_type = Comparison::Type::kFloatTotalOrder;
return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kLe,
compare_type);
}
XlaOp Lt(const XlaOp lhs, const XlaOp rhs,
absl::Span<const int64> broadcast_dimensions) {
return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kLt);
}
XlaOp LtTotalOrder(const XlaOp lhs, const XlaOp rhs,
absl::Span<const int64> broadcast_dimensions) {
return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kLt,
Comparison::Type::kFloatTotalOrder);
}
XlaOp Compare(const XlaOp lhs, const XlaOp rhs,
absl::Span<const int64> broadcast_dimensions,
ComparisonDirection direction) {
@ -3544,6 +3597,13 @@ XlaOp Compare(const XlaOp lhs, const XlaOp rhs,
broadcast_dimensions, direction);
}
XlaOp Compare(const XlaOp lhs, const XlaOp rhs,
absl::Span<const int64> broadcast_dimensions,
ComparisonDirection direction, Comparison::Type compare_type) {
return lhs.builder()->BinaryOp(HloOpcode::kCompare, lhs, rhs,
broadcast_dimensions, direction, compare_type);
}
XlaOp Compare(const XlaOp lhs, const XlaOp rhs, ComparisonDirection direction) {
return Compare(lhs, rhs, {}, direction);
}

View File

@ -792,14 +792,17 @@ class XlaBuilder {
// broadcast_dimensions specifies which dimensions to use for broadcasting
// when the operation is between tensors of different ranks. The direction is
// only used if opcode is kCompare.
XlaOp BinaryOp(
HloOpcode binop, XlaOp lhs, XlaOp rhs,
absl::Span<const int64> broadcast_dimensions,
absl::optional<Comparison::Direction> direction = absl::nullopt);
XlaOp BinaryOp(HloOpcode binop, XlaOp lhs, XlaOp rhs,
absl::Span<const int64> broadcast_dimensions,
absl::optional<ComparisonDirection> direction = absl::nullopt,
absl::optional<Comparison::Type> type = absl::nullopt);
// Internal helper method for binary op compare without broadcast dimensions.
virtual StatusOr<XlaOp> Compare(const Shape& shape, XlaOp lhs, XlaOp rhs,
Comparison::Direction direction);
ComparisonDirection direction);
virtual StatusOr<XlaOp> Compare(const Shape& shape, XlaOp lhs, XlaOp rhs,
ComparisonDirection direction,
Comparison::Type type);
// Internal helper method that does the building for an arbitrary binary op
// with same ranked operands that doesn't broadcast.
@ -965,22 +968,13 @@ class XlaBuilder {
friend XlaOp Select(XlaOp pred, XlaOp on_true, XlaOp on_false);
friend XlaOp Tuple(XlaBuilder* builder, absl::Span<const XlaOp> elements);
friend XlaOp GetTupleElement(XlaOp tuple_data, int64 index);
friend XlaOp Eq(XlaOp lhs, XlaOp rhs,
absl::Span<const int64> broadcast_dimensions);
friend XlaOp Ne(XlaOp lhs, XlaOp rhs,
absl::Span<const int64> broadcast_dimensions);
friend XlaOp Ge(XlaOp lhs, XlaOp rhs,
absl::Span<const int64> broadcast_dimensions);
friend XlaOp Gt(XlaOp lhs, XlaOp rhs,
absl::Span<const int64> broadcast_dimensions);
friend XlaOp Lt(XlaOp lhs, XlaOp rhs,
absl::Span<const int64> broadcast_dimensions);
friend XlaOp Le(XlaOp lhs, XlaOp rhs,
absl::Span<const int64> broadcast_dimensions);
friend XlaOp Compare(XlaOp lhs, XlaOp rhs,
absl::Span<const int64> broadcast_dimensions,
ComparisonDirection direction);
friend XlaOp Compare(XlaOp lhs, XlaOp rhs, ComparisonDirection direction);
friend XlaOp Compare(XlaOp lhs, XlaOp rhs,
absl::Span<const int64> broadcast_dimensions,
ComparisonDirection direction,
Comparison::Type compare_type);
friend XlaOp Dot(XlaOp lhs, XlaOp rhs,
const PrecisionConfig* precision_config);
friend XlaOp DotGeneral(XlaOp lhs, XlaOp rhs,
@ -1574,29 +1568,44 @@ XlaOp GetTupleElement(XlaOp tuple_data, int64 index);
// Enqueues an equal-to comparison instruction onto the computation.
XlaOp Eq(XlaOp lhs, XlaOp rhs,
absl::Span<const int64> broadcast_dimensions = {});
XlaOp EqTotalOrder(XlaOp lhs, XlaOp rhs,
absl::Span<const int64> broadcast_dimensions = {});
// Enqueues a not-equal comparison instruction onto the computation.
XlaOp Ne(XlaOp lhs, XlaOp rhs,
absl::Span<const int64> broadcast_dimensions = {});
XlaOp NeTotalOrder(XlaOp lhs, XlaOp rhs,
absl::Span<const int64> broadcast_dimensions = {});
// Enqueues a greater-or-equal comparison instruction onto the computation.
XlaOp Ge(XlaOp lhs, XlaOp rhs,
absl::Span<const int64> broadcast_dimensions = {});
XlaOp GeTotalOrder(XlaOp lhs, XlaOp rhs,
absl::Span<const int64> broadcast_dimensions = {});
// Enqueues a greater-than comparison instruction onto the computation.
XlaOp Gt(XlaOp lhs, XlaOp rhs,
absl::Span<const int64> broadcast_dimensions = {});
XlaOp GtTotalOrder(XlaOp lhs, XlaOp rhs,
absl::Span<const int64> broadcast_dimensions = {});
// Enqueues a less-than comparison instruction onto the computation.
XlaOp Lt(XlaOp lhs, XlaOp rhs,
absl::Span<const int64> broadcast_dimensions = {});
XlaOp LtTotalOrder(XlaOp lhs, XlaOp rhs,
absl::Span<const int64> broadcast_dimensions = {});
// Enqueues a less-or-equal comparison instruction onto the computation.
XlaOp Le(XlaOp lhs, XlaOp rhs,
absl::Span<const int64> broadcast_dimensions = {});
XlaOp LeTotalOrder(XlaOp lhs, XlaOp rhs,
absl::Span<const int64> broadcast_dimensions = {});
// Enqueues a comparison instruction onto the computation (optionally without
// broadcast_dimensions for consistency with others).
XlaOp Compare(XlaOp lhs, XlaOp rhs,
absl::Span<const int64> broadcast_dimensions,
ComparisonDirection direction, Comparison::Type compare_type);
XlaOp Compare(XlaOp lhs, XlaOp rhs,
absl::Span<const int64> broadcast_dimensions,
ComparisonDirection direction);

View File

@ -54,32 +54,59 @@ StatusOr<Comparison::Direction> StringToComparisonDirection(
return it->second;
}
Comparison::Comparison(Direction dir, PrimitiveType type) : dir_(dir) {
StatusOr<Comparison::Type> StringToComparisonType(
absl::string_view compare_type_name) {
static auto* type_map = new absl::flat_hash_map<string, Comparison::Type>({
{"FLOAT", Comparison::Type::kFloat},
{"TOTALORDER", Comparison::Type::kFloatTotalOrder},
{"SIGNED", Comparison::Type::kSigned},
{"UNSIGNED", Comparison::Type::kUnsigned},
});
auto it = type_map->find(compare_type_name);
if (it == type_map->end()) {
return InvalidArgument("Unknown comparison type: %s", compare_type_name);
}
return it->second;
}
std::string ComparisonTypeToString(Comparison::Type type) {
switch (type) {
case Comparison::Type::kFloat:
return "FLOAT";
case Comparison::Type::kFloatTotalOrder:
return "TOTALORDER";
case Comparison::Type::kSigned:
return "SIGNED";
case Comparison::Type::kUnsigned:
return "UNSIGNED";
}
}
Comparison::Comparison(Direction dir, PrimitiveType type)
: dir_(dir), type_(DefaultComparisonType(type)) {}
Comparison::Type Comparison::DefaultComparisonType(PrimitiveType type) {
switch (type) {
case S8:
case S16:
case S32:
case S64:
type_ = Type::kSigned;
break;
return Type::kSigned;
case PRED:
case U8:
case U16:
case U32:
case U64:
type_ = Type::kUnsigned;
break;
return Type::kUnsigned;
case F16:
case F32:
case BF16:
case F64:
case C64:
case C128:
type_ = Type::kFloat;
break;
return Type::kFloat;
default:
LOG(FATAL) << "Unsupported comparison mode."
<< ComparisonDirectionToString(dir) << ":"
<< PrimitiveType_Name(type) << "\n";
}
}
@ -164,20 +191,6 @@ bool Comparison::IsAntireflexive() const {
}
}
/* static */ const char* Comparison::ComparisonTypeToString(
Comparison::Type type) {
switch (type) {
case Type::kFloat:
return "f";
case Type::kFloatTotalOrder:
return "ft";
case Type::kSigned:
return "s";
case Type::kUnsigned:
return "u";
}
}
std::string Comparison::ToString(std::string prefix1,
std::string prefix2) const {
return prefix1 + std::string(ComparisonDirectionToString(dir_)) + prefix2 +

View File

@ -103,11 +103,11 @@ class Comparison {
bool Compare(const T a, const T b) const {
return GetComparator<T>()(a, b);
}
static Type DefaultComparisonType(PrimitiveType t);
private:
static Direction Converse(Direction dir);
static Direction Inverse(Direction dir);
static const char* ComparisonTypeToString(Type type);
const Direction dir_;
Type type_;
@ -117,10 +117,14 @@ inline std::ostream& operator<<(std::ostream& os, const Comparison& cmp) {
return os << cmp.ToString();
}
string ComparisonDirectionToString(Comparison::Direction direction);
std::string ComparisonTypeToString(Comparison::Type type);
StatusOr<Comparison::Direction> StringToComparisonDirection(
absl::string_view direction_name);
StatusOr<Comparison::Type> StringToComparisonType(
absl::string_view compare_type_name);
using ComparisonDirection = Comparison::Direction;
} // namespace xla

View File

@ -1235,7 +1235,10 @@ floating-point types.
Where `Op` is one of `Eq` (equal-to), `Ne` (not equal-to), `Ge`
(greater-or-equal-than), `Gt` (greater-than), `Le` (less-or-equal-than), `Lt`
(less-than).
(less-than). Another set of operators, EqTotalOrder, NeTotalOrder, GeTotalOrder,
GtTotalOrder, LeTotalOrder, and LtTotalOrder, provide the same functionalities,
except that they additionally support a total order over the floating point
numbers, by enforcing -NaN < -Inf < -Finite < -0 < +0 < +Finite < +Inf < +NaN.
Arguments | Type | Semantics
--------- | ------- | ----------------------------------------

View File

@ -112,6 +112,21 @@ xla::PrimitiveType UnsignedIntegralTypeForBitWidth(int64 src_bitwidth) {
}
}
xla::PrimitiveType SignedIntegralTypeForBitWidth(int64 src_bitwidth) {
switch (src_bitwidth) {
case 8:
return xla::S8;
case 16:
return xla::S16;
case 32:
return xla::S32;
case 64:
return xla::S64;
default:
return xla::PRIMITIVE_TYPE_INVALID;
}
}
PrimitiveType ComplexComponentType(PrimitiveType complex_type) {
switch (complex_type) {
case C64:

View File

@ -153,6 +153,8 @@ int BitWidth(PrimitiveType type);
PrimitiveType UnsignedIntegralTypeForBitWidth(int64 src_bitwidth);
PrimitiveType SignedIntegralTypeForBitWidth(int64 src_bitwidth);
// Returns the real, imag component type underlying the given complex type.
// LOG(FATAL)'s if complex_type is not complex.
PrimitiveType ComplexComponentType(PrimitiveType complex_type);

View File

@ -1700,7 +1700,10 @@ cc_library(
cc_library(
name = "hlo_creation_utils",
srcs = ["hlo_creation_utils.cc"],
hdrs = ["hlo_creation_utils.h"],
hdrs = [
"hlo_creation_utils.h",
"//tensorflow/compiler/xla:literal_util",
],
deps = [
":hlo",
":hlo_module_config",
@ -1816,6 +1819,21 @@ cc_library(
],
)
cc_library(
name = "comparison_expander",
srcs = ["comparison_expander.cc"],
hdrs = ["comparison_expander.h"],
deps = [
":hlo",
":hlo_creation_utils",
":hlo_pass",
":op_expander_pass",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/client/lib:comparators",
],
)
cc_library(
name = "scatter_expander",
srcs = ["scatter_expander.cc"],

View File

@ -0,0 +1,133 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/comparison_expander.h"
#include "tensorflow/compiler/xla/client/lib/comparators.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/util.h"
namespace xla {
HloInstruction* BitcastConvertFloatingPointToIntegral(
HloComputation* computation, HloInstruction* value,
const Shape& signed_shape, const Shape& unsigned_shape,
HloInstruction* zero, HloInstruction* max_value) {
// Switch from a floating point value to a integer value in such a way that
// when using the integer value to compare, we get the same result for normal
// values, and -Nan is treated as the smallest value, and Nan is treated as
// the largest value.
// If f is a float, and
// x = bit_cast<int32>(f);
// y = x < 0 ? numeric_limits<int32>::max() - x : x;
// then y is ordered as an int32 such that finite values have the obvious
// order, -0 is ordered before 0, and -NaN and NaN appear at the beginning
// and end of the ordering.
// Note that in order to avoid -x to overflow, we calculate
// numeric_limits<int32>::max() - x as unsigned, and then convert back to
// signed.
auto signed_value = computation->AddInstruction(
HloInstruction::CreateBitcastConvert(signed_shape, value));
auto unsigned_value = computation->AddInstruction(
HloInstruction::CreateBitcastConvert(unsigned_shape, value));
auto flipped_value = computation->AddInstruction(HloInstruction::CreateBinary(
unsigned_shape, HloOpcode::kSubtract, max_value, unsigned_value));
flipped_value = computation->AddInstruction(
HloInstruction::CreateBitcastConvert(signed_shape, flipped_value));
auto compare_shape = signed_shape;
compare_shape.set_element_type(PRED);
auto is_negative = computation->AddInstruction(HloInstruction::CreateCompare(
compare_shape, signed_value, zero, ComparisonDirection::kLt));
return computation->AddInstruction(
HloInstruction::CreateTernary(signed_shape, HloOpcode::kSelect,
is_negative, flipped_value, signed_value));
}
bool ComparisonExpander::InstructionMatchesPattern(
HloInstruction* instruction) {
if (HloCompareInstruction* compare =
dynamic_cast<HloCompareInstruction*>(instruction)) {
HloInstruction* lhs = instruction->operands()[0];
if (compare->type() == Comparison::Type::kFloatTotalOrder &&
primitive_util::IsFloatingPointType(lhs->shape().element_type())) {
return true;
}
}
return false;
}
StatusOr<HloInstruction*> ComparisonExpander::ExpandInstruction(
HloInstruction* instruction) {
CHECK(instruction->opcode() == HloOpcode::kCompare);
HloCompareInstruction* compare =
static_cast<HloCompareInstruction*>(instruction);
CHECK(compare->type() == Comparison::Type::kFloatTotalOrder);
HloComputation* computation = instruction->parent();
HloInstruction* lhs = instruction->operands()[0];
HloInstruction* rhs = instruction->operands()[1];
Shape compare_shape = lhs->shape();
PrimitiveType compare_type = compare_shape.element_type();
CHECK(primitive_util::IsFloatingPointType(compare_type));
// Special-case handling for BF16. We currently do not support direct
// comparisons with BF16, so we convert to F32 and then use the F32
// comparison logic.
if (compare_type == BF16) {
compare_type = F32;
compare_shape.set_element_type(compare_type);
lhs = computation->AddInstruction(
HloInstruction::CreateConvert(compare_shape, lhs));
rhs = computation->AddInstruction(
HloInstruction::CreateConvert(compare_shape, rhs));
}
int64 bit_width = primitive_util::BitWidth(compare_type);
PrimitiveType signed_type =
primitive_util::SignedIntegralTypeForBitWidth(bit_width);
PrimitiveType unsigned_type =
primitive_util::UnsignedIntegralTypeForBitWidth(bit_width);
auto signed_shape = compare_shape;
signed_shape.set_element_type(signed_type);
auto unsigned_shape = compare_shape;
unsigned_shape.set_element_type(unsigned_type);
auto zero_value = computation->AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::Zero(signed_type)));
zero_value = computation->AddInstruction(HloInstruction::CreateBroadcast(
signed_shape, zero_value, zero_value->shape().dimensions()));
auto max_signed = computation->AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::MaxValue(signed_type)));
auto max_shape = max_signed->shape();
max_shape.set_element_type(unsigned_type);
auto max_unsigned = computation->AddInstruction(
HloInstruction::CreateConvert(max_shape, max_signed));
auto max_value = computation->AddInstruction(HloInstruction::CreateBroadcast(
unsigned_shape, max_unsigned, max_shape.dimensions()));
lhs = BitcastConvertFloatingPointToIntegral(
computation, lhs, signed_shape, unsigned_shape, zero_value, max_value);
rhs = BitcastConvertFloatingPointToIntegral(
computation, rhs, signed_shape, unsigned_shape, zero_value, max_value);
auto new_compare = computation->AddInstruction(HloInstruction::CreateCompare(
instruction->shape(), lhs, rhs, compare->direction(),
Comparison::Type::kSigned));
VLOG(2) << "New comparison instruction for total order:"
<< new_compare->ToString() << "\n";
return new_compare;
}
} // namespace xla

View File

@ -0,0 +1,47 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_COMPARISON_EXPANDER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_COMPARISON_EXPANDER_H_
#include <utility>
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
#include "tensorflow/compiler/xla/service/op_expander_pass.h"
namespace xla {
// A pass which performs expansion of the comparison operator to support total
// order comparison of floating point numbers.
class ComparisonExpander : public OpExpanderPass {
public:
explicit ComparisonExpander() = default;
~ComparisonExpander() override = default;
absl::string_view name() const override { return "comparison-expander"; }
private:
// Returns `true` if `instruction` should be expanded by this pass.
bool InstructionMatchesPattern(HloInstruction* instruction) override;
// Returns a replacement for `instruction`, or nullptr if no replacement is
// needed (e.g. only the to_apply subcomputation of the instruction was
// modified).
StatusOr<HloInstruction*> ExpandInstruction(
HloInstruction* instruction) override;
};
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_COMPARISON_EXPANDER_H_

View File

@ -145,6 +145,7 @@ cc_library(
"//tensorflow/compiler/xla/service:conditional_to_select",
"//tensorflow/compiler/xla/service:slow_operation_alarm",
"//tensorflow/compiler/xla/service:scatter_expander",
"//tensorflow/compiler/xla/service:comparison_expander",
"//tensorflow/compiler/xla/service:slice_sinker",
"//tensorflow/compiler/xla:cpu_function_runtime",
"//tensorflow/compiler/xla:literal",

View File

@ -54,6 +54,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/call_inliner.h"
#include "tensorflow/compiler/xla/service/cholesky_expander.h"
#include "tensorflow/compiler/xla/service/comparison_expander.h"
#include "tensorflow/compiler/xla/service/conditional_canonicalizer.h"
#include "tensorflow/compiler/xla/service/conditional_simplifier.h"
#include "tensorflow/compiler/xla/service/conditional_to_select.h"
@ -261,6 +262,7 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn(
pipeline.AddPass<ConditionalToSelect>();
pipeline.AddPass<MapInliner>();
pipeline.AddPass<ComparisonExpander>();
pipeline.AddPass<CholeskyExpander>();
pipeline.AddPass<TriangularSolveExpander>();

View File

@ -1168,6 +1168,7 @@ cc_library(
"//tensorflow/compiler/xla/service:batchnorm_expander",
"//tensorflow/compiler/xla/service:buffer_assignment",
"//tensorflow/compiler/xla/service:call_inliner",
"//tensorflow/compiler/xla/service:comparison_expander",
"//tensorflow/compiler/xla/service:conditional_canonicalizer",
"//tensorflow/compiler/xla/service:conditional_simplifier",
"//tensorflow/compiler/xla/service:convolution_4d_expander",

View File

@ -35,6 +35,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/batchnorm_expander.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/call_inliner.h"
#include "tensorflow/compiler/xla/service/comparison_expander.h"
#include "tensorflow/compiler/xla/service/conditional_canonicalizer.h"
#include "tensorflow/compiler/xla/service/conditional_simplifier.h"
#include "tensorflow/compiler/xla/service/convolution_4d_expander.h"
@ -140,6 +141,9 @@ Status GpuCompiler::OptimizeHloModule(
pipeline.AddPass<RngExpander>();
pipeline.AddPass<RngBitGeneratorExpander>(RandomAlgorithm::RNG_PHILOX);
// Comparison total order expander
pipeline.AddPass<ComparisonExpander>();
// Remove zero-sized HLO from the input so that other passes don't have to
// handle it.
pipeline.AddPass<ZeroSizedHloElimination>();

View File

@ -35,7 +35,7 @@ import "tensorflow/compiler/xla/xla_data.proto";
option cc_enable_arenas = true;
// Serialization of HloInstruction.
// Next ID: 72
// Next ID: 73
message HloInstructionProto {
reserved 10;
reserved "parameter_name";
@ -248,6 +248,9 @@ message HloInstructionProto {
// RNG algorithm used by kRngBitGenerator.
xla.RandomAlgorithm rng_algorithm = 70;
// The comparison type used for kCompare.
string comparison_type = 72;
}
// Serialization of HloComputation.

View File

@ -174,8 +174,19 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
comparison_direction,
StringToComparisonDirection(proto.comparison_direction()));
}
instruction =
CreateCompare(shape, operands(0), operands(1), *comparison_direction);
auto comparison_type_str = proto.comparison_type();
if (!comparison_type_str.empty()) {
// If a comparison type is specified, it *must* be valid.
TF_ASSIGN_OR_RETURN(auto comparison_type,
StringToComparisonType(comparison_type_str));
instruction = CreateCompare(shape, operands(0), operands(1),
*comparison_direction, comparison_type);
} else {
// Allow the specify of comparison type to be optional.
// The comparison type will be determined by the types of the operands.
instruction = CreateCompare(shape, operands(0), operands(1),
*comparison_direction);
}
break;
}
case HloOpcode::kTriangularSolve: {
@ -926,8 +937,9 @@ HloInstruction::CreateRngBitGenerator(const Shape& shape, HloInstruction* state,
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCompare(
const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
ComparisonDirection direction) {
return absl::make_unique<HloCompareInstruction>(shape, lhs, rhs, direction);
ComparisonDirection direction, absl::optional<Comparison::Type> type) {
return absl::make_unique<HloCompareInstruction>(shape, lhs, rhs, direction,
type);
}
/* static */ std::unique_ptr<HloInstruction>

View File

@ -595,7 +595,8 @@ class HloInstruction {
// Creates a compare op, performing the comparison specified in direction.
static std::unique_ptr<HloInstruction> CreateCompare(
const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
Comparison::Direction direction);
Comparison::Direction direction,
absl::optional<Comparison::Type> type = absl::nullopt);
static std::unique_ptr<HloInstruction> CreateTriangularSolve(
const Shape& shape, HloInstruction* a, HloInstruction* b,

View File

@ -204,12 +204,13 @@ std::unique_ptr<HloInstruction> HloFftInstruction::CloneWithNewOperandsImpl(
fft_length_);
}
HloCompareInstruction::HloCompareInstruction(const Shape& shape,
HloInstruction* lhs,
HloInstruction* rhs,
ComparisonDirection direction)
HloCompareInstruction::HloCompareInstruction(
const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
ComparisonDirection direction, absl::optional<Comparison::Type> type)
: HloInstruction(HloOpcode::kCompare, shape),
compare_(direction, lhs->shape().element_type()) {
compare_(direction, type ? (*type)
: Comparison::DefaultComparisonType(
lhs->shape().element_type())) {
AppendOperand(lhs);
AppendOperand(rhs);
}
@ -218,12 +219,21 @@ HloInstructionProto HloCompareInstruction::ToProto() const {
HloInstructionProto proto = HloInstruction::ToProto();
proto.set_comparison_direction(
ComparisonDirectionToString(compare_.GetDirection()));
proto.set_comparison_type(ComparisonTypeToString(compare_.GetType()));
return proto;
}
std::vector<string> HloCompareInstruction::ExtraAttributesToStringImpl(
const HloPrintOptions& options) const {
return {StrCat("direction=", ComparisonDirectionToString(direction()))};
std::vector<string> result;
result.push_back(
StrCat("direction=", ComparisonDirectionToString(direction())));
if (compare_.GetType() !=
Comparison::DefaultComparisonType(operand(0)->shape().element_type())) {
result.push_back(
StrCat("type=", ComparisonTypeToString(compare_.GetType())));
}
return result;
}
bool HloCompareInstruction::IdenticalSlowPath(
@ -238,8 +248,8 @@ std::unique_ptr<HloInstruction> HloCompareInstruction::CloneWithNewOperandsImpl(
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 2);
return absl::make_unique<HloCompareInstruction>(shape, new_operands[0],
new_operands[1], direction());
return absl::make_unique<HloCompareInstruction>(
shape, new_operands[0], new_operands[1], direction(), type());
}
namespace {

View File

@ -136,8 +136,10 @@ class HloCompareInstruction : public HloInstruction {
public:
explicit HloCompareInstruction(const Shape& shape, HloInstruction* lhs,
HloInstruction* rhs,
ComparisonDirection direction);
ComparisonDirection direction,
absl::optional<Comparison::Type> type);
ComparisonDirection direction() const { return compare_.GetDirection(); }
Comparison::Type type() const { return compare_.GetType(); }
HloInstructionProto ToProto() const override;
private:

View File

@ -194,6 +194,7 @@ class HloParserImpl : public HloParser {
kBracedHloComputationList,
kFftType,
kComparisonDirection,
kComparisonType,
kWindow,
kConvolutionDimensionNumbers,
kSharding,
@ -327,6 +328,7 @@ class HloParserImpl : public HloParser {
bool ParseOpcode(HloOpcode* result);
bool ParseFftType(FftType* result);
bool ParseComparisonDirection(ComparisonDirection* result);
bool ParseComparisonType(Comparison::Type* result);
bool ParseFusionKind(HloInstruction::FusionKind* result);
bool ParseRandomDistribution(RandomDistribution* result);
bool ParseRandomAlgorithm(RandomAlgorithm* result);
@ -1362,14 +1364,16 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder,
}
case HloOpcode::kCompare: {
optional<ComparisonDirection> direction;
optional<Comparison::Type> type;
attrs["direction"] = {/*required=*/true, AttrTy::kComparisonDirection,
&direction};
attrs["type"] = {/*required=*/false, AttrTy::kComparisonType, &type};
if (!ParseOperands(&operands, /*expected_size=*/2) ||
!ParseAttributes(attrs)) {
return false;
}
instruction = builder->AddInstruction(HloInstruction::CreateCompare(
shape, operands[0], operands[1], *direction));
shape, operands[0], operands[1], *direction, type));
break;
}
case HloOpcode::kCholesky: {
@ -3018,6 +3022,14 @@ bool HloParserImpl::ParseAttributeHelper(
->emplace(result);
return true;
}
case AttrTy::kComparisonType: {
Comparison::Type result;
if (!ParseComparisonType(&result)) {
return false;
}
static_cast<optional<Comparison::Type>*>(attr_out_ptr)->emplace(result);
return true;
}
case AttrTy::kEnum: {
if (lexer_.GetKind() != TokKind::kIdent) {
return TokenError("expects an enumeration value");
@ -4145,6 +4157,21 @@ bool HloParserImpl::ParseComparisonDirection(ComparisonDirection* result) {
return true;
}
bool HloParserImpl::ParseComparisonType(Comparison::Type* result) {
VLOG(1) << "ParseComparisonType";
if (lexer_.GetKind() != TokKind::kIdent) {
return TokenError("expects comparison type");
}
std::string val = lexer_.GetStrVal();
auto status_or_result = StringToComparisonType(val);
if (!status_or_result.ok()) {
return TokenError(StrFormat("expects comparison type but sees: %s", val));
}
*result = status_or_result.ValueOrDie();
lexer_.Lex();
return true;
}
bool HloParserImpl::ParseFusionKind(HloInstruction::FusionKind* result) {
VLOG(3) << "ParseFusionKind";
if (lexer_.GetKind() != TokKind::kIdent) {

View File

@ -230,7 +230,7 @@ R"(HloModule SelectR1F32WithCmpR1F32sFromParamsSmall_module
ENTRY %SelectR1F32WithCmpR1F32sFromParamsSmall.v4 (v1: f32[4], v2: f32[4]) -> f32[4] {
%v1 = f32[4]{0} parameter(0), sharding={maximal device=1}
%v2 = f32[4]{0} parameter(1), sharding={maximal device=1}
%greater-than = pred[4]{0} compare(f32[4]{0} %v1, f32[4]{0} %v2), direction=GT, sharding={replicated}
%greater-than = pred[4]{0} compare(f32[4]{0} %v1, f32[4]{0} %v2), direction=GT, type=TOTALORDER, sharding={replicated}
ROOT %select = f32[4]{0} select(pred[4]{0} %greater-than, f32[4]{0} %v1, f32[4]{0} %v2), sharding={}
}
@ -512,7 +512,7 @@ R"(HloModule R4F32OverlapSmall_module
%ge_F32.v3 (lhs: f32[], rhs: f32[]) -> pred[] {
%lhs = f32[] parameter(0)
%rhs = f32[] parameter(1)
ROOT %greater-than-or-equal-to = pred[] compare(f32[] %lhs, f32[] %rhs), direction=GE
ROOT %greater-than-or-equal-to = pred[] compare(f32[] %lhs, f32[] %rhs), direction=GE, type=TOTALORDER
}
%add_F32.v3 (lhs.1: f32[], rhs.1: f32[]) -> f32[] {

View File

@ -34,6 +34,7 @@ cc_library(
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla/service:algebraic_simplifier",
"//tensorflow/compiler/xla/service:cholesky_expander",
"//tensorflow/compiler/xla/service:comparison_expander",
"//tensorflow/compiler/xla/service:compiler",
"//tensorflow/compiler/xla/service:computation_placer",
"//tensorflow/compiler/xla/service:custom_call_target_registry",

View File

@ -21,6 +21,7 @@ limitations under the License.
#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
#include "tensorflow/compiler/xla/service/cholesky_expander.h"
#include "tensorflow/compiler/xla/service/comparison_expander.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/service/custom_call_target_registry.h"
#include "tensorflow/compiler/xla/service/dynamic_index_splitter.h"
@ -81,6 +82,7 @@ Status InterpreterCompiler::RunHloOptimization(HloModule* hlo_module) {
pipeline.AddPass<DynamicIndexSplitter>();
pipeline.AddPass<CholeskyExpander>();
pipeline.AddPass<ComparisonExpander>();
pipeline.AddPass<TriangularSolveExpander>();
pipeline.AddPass<LayoutAssignment>(
hlo_module->mutable_entry_computation_layout(),

View File

@ -1203,6 +1203,16 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareEqF32s) {
ComputeAndCompareR1<bool>(&builder, {false, false, true, false, false}, {});
}
XLA_TEST_F(ArrayElementwiseOpTest, CompareEqF32sTO) {
SetFastMathDisabled(true);
XlaBuilder builder(TestName());
auto lhs = ConstantR1<float>(&builder, {-2.5f, 25.5f, 2.25f, NAN, 6.0f});
auto rhs = ConstantR1<float>(&builder, {10.0f, 5.0f, 2.25f, NAN, NAN});
EqTotalOrder(lhs, rhs);
ComputeAndCompareR1<bool>(&builder, {false, false, true, true, false}, {});
}
XLA_TEST_F(ArrayElementwiseOpTest, CompareEqZeroElementF32s) {
XlaBuilder builder(TestName());
auto lhs = ConstantR1<float>(&builder, {});
@ -1222,6 +1232,18 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareGeF32s) {
ComputeAndCompareR1<bool>(&builder, {false, true, true, false, false}, {});
}
XLA_TEST_F(ArrayElementwiseOpTest, CompareGeF32sTO) {
SetFastMathDisabled(true);
XlaBuilder builder(TestName());
auto lhs =
ConstantR1<float>(&builder, {-2.5f, 25.5f, 2.25f, NAN, 6.0f, 6.0f});
auto rhs = ConstantR1<float>(&builder, {10.0f, 5.0f, 1.0f, 10.0f, NAN, -NAN});
GeTotalOrder(lhs, rhs);
ComputeAndCompareR1<bool>(&builder, {false, true, true, true, false, true},
{});
}
XLA_TEST_F(ArrayElementwiseOpTest, CompareGtF32s) {
SetFastMathDisabled(true);
XlaBuilder builder(TestName());