[XLA] Replace individual comparison HLO ops with a single compare op.
PiperOrigin-RevId: 237318727
This commit is contained in:
parent
add7a1a911
commit
57467ada28
@ -57,6 +57,24 @@ xla_proto_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "comparison_util",
|
||||
srcs = [
|
||||
"comparison_util.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"comparison_util.h",
|
||||
],
|
||||
visibility = [":friends"],
|
||||
deps = [
|
||||
":statusor",
|
||||
":types",
|
||||
":util",
|
||||
"@com_google_absl//absl/base:core_headers",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "execution_options_util",
|
||||
srcs = [
|
||||
|
@ -212,6 +212,7 @@ cc_library(
|
||||
":padding",
|
||||
":sharding_builder",
|
||||
":xla_computation",
|
||||
"//tensorflow/compiler/xla:comparison_util",
|
||||
"//tensorflow/compiler/xla:execution_options_util",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
|
@ -480,7 +480,8 @@ XlaOp XlaBuilder::UnaryOp(HloOpcode unop, const XlaOp& operand) {
|
||||
}
|
||||
|
||||
XlaOp XlaBuilder::BinaryOp(HloOpcode binop, const XlaOp& lhs, const XlaOp& rhs,
|
||||
absl::Span<const int64> broadcast_dimensions) {
|
||||
absl::Span<const int64> broadcast_dimensions,
|
||||
absl::optional<ComparisonDirection> direction) {
|
||||
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||
HloInstructionProto instr;
|
||||
TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs));
|
||||
@ -489,6 +490,17 @@ XlaOp XlaBuilder::BinaryOp(HloOpcode binop, const XlaOp& lhs, const XlaOp& rhs,
|
||||
ShapeInference::InferBinaryOpShape(
|
||||
binop, lhs_shape, rhs_shape, broadcast_dimensions));
|
||||
*instr.mutable_shape() = shape.ToProto();
|
||||
if (binop == HloOpcode::kCompare) {
|
||||
if (!direction.has_value()) {
|
||||
return InvalidArgument(
|
||||
"kCompare expects a ComparisonDirection, but none provided.");
|
||||
}
|
||||
instr.set_comparison_direction(ComparisonDirectionToString(*direction));
|
||||
} else if (direction.has_value()) {
|
||||
return InvalidArgument(
|
||||
"A comparison direction is provided for a non-compare opcode: %s.",
|
||||
HloOpcodeString(binop));
|
||||
}
|
||||
|
||||
const int64 lhs_rank = lhs_shape.rank();
|
||||
const int64 rhs_rank = rhs_shape.rank();
|
||||
@ -2908,38 +2920,39 @@ XlaOp GetTupleElement(const XlaOp& tuple_data, int64 index) {
|
||||
|
||||
XlaOp Eq(const XlaOp& lhs, const XlaOp& rhs,
|
||||
absl::Span<const int64> broadcast_dimensions) {
|
||||
return lhs.builder()->BinaryOp(HloOpcode::kEq, lhs, rhs,
|
||||
broadcast_dimensions);
|
||||
return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kEq);
|
||||
}
|
||||
|
||||
XlaOp Ne(const XlaOp& lhs, const XlaOp& rhs,
|
||||
absl::Span<const int64> broadcast_dimensions) {
|
||||
return lhs.builder()->BinaryOp(HloOpcode::kNe, lhs, rhs,
|
||||
broadcast_dimensions);
|
||||
return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kNe);
|
||||
}
|
||||
|
||||
XlaOp Ge(const XlaOp& lhs, const XlaOp& rhs,
|
||||
absl::Span<const int64> broadcast_dimensions) {
|
||||
return lhs.builder()->BinaryOp(HloOpcode::kGe, lhs, rhs,
|
||||
broadcast_dimensions);
|
||||
return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kGe);
|
||||
}
|
||||
|
||||
XlaOp Gt(const XlaOp& lhs, const XlaOp& rhs,
|
||||
absl::Span<const int64> broadcast_dimensions) {
|
||||
return lhs.builder()->BinaryOp(HloOpcode::kGt, lhs, rhs,
|
||||
broadcast_dimensions);
|
||||
return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kGt);
|
||||
}
|
||||
|
||||
XlaOp Le(const XlaOp& lhs, const XlaOp& rhs,
|
||||
absl::Span<const int64> broadcast_dimensions) {
|
||||
return lhs.builder()->BinaryOp(HloOpcode::kLe, lhs, rhs,
|
||||
broadcast_dimensions);
|
||||
return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kLe);
|
||||
}
|
||||
|
||||
XlaOp Lt(const XlaOp& lhs, const XlaOp& rhs,
|
||||
absl::Span<const int64> broadcast_dimensions) {
|
||||
return lhs.builder()->BinaryOp(HloOpcode::kLt, lhs, rhs,
|
||||
broadcast_dimensions);
|
||||
return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kLt);
|
||||
}
|
||||
|
||||
// Builds a single kCompare instruction for `lhs` and `rhs` on the builder that
// owns `lhs`. `direction` selects which comparison is performed and
// `broadcast_dimensions` is forwarded unchanged to XlaBuilder::BinaryOp.
XlaOp Compare(const XlaOp& lhs, const XlaOp& rhs,
              absl::Span<const int64> broadcast_dimensions,
              ComparisonDirection direction) {
  auto* const owning_builder = lhs.builder();
  return owning_builder->BinaryOp(HloOpcode::kCompare, lhs, rhs,
                                  broadcast_dimensions, direction);
}
|
||||
|
||||
XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs,
|
||||
|
@ -27,6 +27,7 @@ limitations under the License.
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/compiler/xla/client/padding.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_computation.h"
|
||||
#include "tensorflow/compiler/xla/comparison_util.h"
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/compiler/xla/service/dynamic_parameter_binding.h"
|
||||
@ -596,9 +597,11 @@ class XlaBuilder {
|
||||
|
||||
// Internal helper method that does the building for an arbitrary binary op.
|
||||
// broadcast_dimensions specifies which dimensions to use for broadcasting
|
||||
// when the operation is between tensors of different ranks.
|
||||
// when the operation is between tensors of different ranks. The direction is
|
||||
// only used if opcode is kCompare.
|
||||
XlaOp BinaryOp(HloOpcode binop, const XlaOp& lhs, const XlaOp& rhs,
|
||||
absl::Span<const int64> broadcast_dimensions);
|
||||
absl::Span<const int64> broadcast_dimensions,
|
||||
absl::optional<ComparisonDirection> direction = absl::nullopt);
|
||||
|
||||
// Internal helper method that does the building for an arbitrary ternary op.
|
||||
XlaOp TernaryOp(HloOpcode triop, const XlaOp& lhs, const XlaOp& rhs,
|
||||
@ -767,6 +770,9 @@ class XlaBuilder {
|
||||
absl::Span<const int64> broadcast_dimensions);
|
||||
friend XlaOp Le(const XlaOp& lhs, const XlaOp& rhs,
|
||||
absl::Span<const int64> broadcast_dimensions);
|
||||
friend XlaOp Compare(const XlaOp& lhs, const XlaOp& rhs,
|
||||
absl::Span<const int64> broadcast_dimensions,
|
||||
ComparisonDirection direction);
|
||||
friend XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs,
|
||||
const PrecisionConfig* precision_config);
|
||||
friend XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs,
|
||||
@ -1279,6 +1285,11 @@ XlaOp Lt(const XlaOp& lhs, const XlaOp& rhs,
|
||||
XlaOp Le(const XlaOp& lhs, const XlaOp& rhs,
|
||||
absl::Span<const int64> broadcast_dimensions = {});
|
||||
|
||||
// Enqueues a comparison instruction onto the computation.
|
||||
XlaOp Compare(const XlaOp& lhs, const XlaOp& rhs,
|
||||
absl::Span<const int64> broadcast_dimensions,
|
||||
ComparisonDirection direction);
|
||||
|
||||
// Enqueues a dot instruction onto the computation.
|
||||
XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs,
|
||||
const PrecisionConfig* precision_config = nullptr);
|
||||
|
57
tensorflow/compiler/xla/comparison_util.cc
Normal file
57
tensorflow/compiler/xla/comparison_util.cc
Normal file
@ -0,0 +1,57 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/comparison_util.h"
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
// Returns the upper-case two-letter name of `direction`: "EQ", "NE", "GE",
// "GT", "LE" or "LT". These are the same spellings that
// StringToComparisonDirection accepts.
//
// If `direction` holds a value that is not one of the declared enumerators
// (possible via a cast, since the enum has an integral underlying type), the
// function returns "UNKNOWN" instead of running off the end of the function,
// which would be undefined behaviour. "UNKNOWN" is not accepted by
// StringToComparisonDirection, so a corrupted value cannot silently round-trip.
std::string ComparisonDirectionToString(ComparisonDirection direction) {
  switch (direction) {
    case ComparisonDirection::kEq:
      return "EQ";
    case ComparisonDirection::kNe:
      return "NE";
    case ComparisonDirection::kGe:
      return "GE";
    case ComparisonDirection::kGt:
      return "GT";
    case ComparisonDirection::kLe:
      return "LE";
    case ComparisonDirection::kLt:
      return "LT";
  }
  // No `default:` label above on purpose: leaving it out lets the compiler
  // flag a newly added enumerator that is missing from the switch. This
  // return only runs for out-of-range values.
  return "UNKNOWN";
}
|
||||
|
||||
// Parses the textual name of a comparison direction. Only the exact spellings
// "EQ", "NE", "GE", "GT", "LE" and "LT" are recognised; any other input yields
// an InvalidArgument error that echoes the offending name.
StatusOr<ComparisonDirection> StringToComparisonDirection(
    absl::string_view direction_name) {
  using DirectionTable = absl::flat_hash_map<string, ComparisonDirection>;
  // Built once on first use and never destroyed.
  static const DirectionTable* const directions_by_name = new DirectionTable({
      {"EQ", ComparisonDirection::kEq},
      {"NE", ComparisonDirection::kNe},
      {"GE", ComparisonDirection::kGe},
      {"GT", ComparisonDirection::kGt},
      {"LE", ComparisonDirection::kLe},
      {"LT", ComparisonDirection::kLt},
  });
  const auto entry = directions_by_name->find(direction_name);
  if (entry != directions_by_name->end()) {
    return entry->second;
  }
  return InvalidArgument("Unknown comparison direction: %s", direction_name);
}
|
||||
|
||||
} // namespace xla
|
42
tensorflow/compiler/xla/comparison_util.h
Normal file
42
tensorflow/compiler/xla/comparison_util.h
Normal file
@ -0,0 +1,42 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_COMPARISON_UTIL_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_COMPARISON_UTIL_H_
|
||||
|
||||
#include "absl/base/macros.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
// Represents different comparison operations.
// Each enumerator is the direction carried by a kCompare instruction. The
// string form of each value (see ComparisonDirectionToString) is noted beside
// it. The underlying type is uint8.
|
||||
enum class ComparisonDirection : uint8 {
|
||||
kEq,  // "EQ"; selected by the Eq() builder op.
|
||||
kNe,  // "NE"; selected by the Ne() builder op.
|
||||
kGe,  // "GE"; selected by the Ge() builder op.
|
||||
kGt,  // "GT"; selected by the Gt() builder op.
|
||||
kLe,  // "LE"; selected by the Le() builder op.
|
||||
kLt,  // "LT"; selected by the Lt() builder op.
|
||||
};
|
||||
|
||||
// Returns the two-letter upper-case name of `direction` ("EQ", "NE", "GE",
// "GT", "LE" or "LT").
string ComparisonDirectionToString(ComparisonDirection direction);
|
||||
|
||||
// Converts `direction_name` to the matching ComparisonDirection. The name must
// be exactly one of "EQ", "NE", "GE", "GT", "LE" or "LT"; anything else
// produces an InvalidArgument error status.
StatusOr<ComparisonDirection> StringToComparisonDirection(
|
||||
absl::string_view direction_name);
|
||||
|
||||
} // namespace xla
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_XLA_COMPARISON_UTIL_H_
|
@ -334,6 +334,7 @@ cc_library(
|
||||
":hlo_proto",
|
||||
":name_uniquer",
|
||||
"//tensorflow/compiler/xla:array",
|
||||
"//tensorflow/compiler/xla:comparison_util",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:protobuf_util",
|
||||
@ -3301,6 +3302,7 @@ tf_cc_test(
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
"//tensorflow/compiler/xla:xla_proto",
|
||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||
"//tensorflow/compiler/xla/tests:test_utils",
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep
|
||||
"//tensorflow/core:lib",
|
||||
|
@ -873,9 +873,9 @@ std::unique_ptr<HloInstruction> TryDivideToShift(HloInstruction* divide,
|
||||
computation, a->shape().element_type(), a->shape().dimensions());
|
||||
|
||||
auto* dividend_is_negative =
|
||||
computation->AddInstruction(HloInstruction::CreateBinary(
|
||||
ShapeUtil::ChangeElementType(a->shape(), PRED), HloOpcode::kLt, a,
|
||||
zero_like_a));
|
||||
computation->AddInstruction(HloInstruction::CreateCompare(
|
||||
ShapeUtil::ChangeElementType(a->shape(), PRED), a, zero_like_a,
|
||||
ComparisonDirection::kLt));
|
||||
|
||||
auto* negated_dividend = computation->AddInstruction(
|
||||
HloInstruction::CreateUnary(a->shape(), HloOpcode::kNegate, a));
|
||||
@ -2475,9 +2475,9 @@ std::unique_ptr<HloInstruction> TryRemainderToAnd(HloInstruction* remainder,
|
||||
computation, a->shape().element_type(), a->shape().dimensions());
|
||||
|
||||
auto* dividend_is_negative =
|
||||
computation->AddInstruction(HloInstruction::CreateBinary(
|
||||
ShapeUtil::ChangeElementType(a->shape(), PRED), HloOpcode::kLt, a,
|
||||
zero_like_a));
|
||||
computation->AddInstruction(HloInstruction::CreateCompare(
|
||||
ShapeUtil::ChangeElementType(a->shape(), PRED), a, zero_like_a,
|
||||
ComparisonDirection::kLt));
|
||||
|
||||
auto* negated_dividend = computation->AddInstruction(
|
||||
HloInstruction::CreateUnary(a->shape(), HloOpcode::kNegate, a));
|
||||
|
@ -221,7 +221,7 @@ HloModule foobar
|
||||
%x = (f32[2,2], f32[2,2]) parameter(0)
|
||||
%constant.0 = s32[] constant(0)
|
||||
%constant.1 = s32[] constant(1)
|
||||
ROOT %greater-than = pred[] greater-than(s32[] %constant.1, s32[] %constant.0)
|
||||
ROOT %greater-than = pred[] compare(s32[] %constant.1, s32[] %constant.0), direction=GT
|
||||
}
|
||||
|
||||
%body (x: (f32[2,2], f32[2,2])) -> (f32[2,2], f32[2,2]) {
|
||||
@ -258,7 +258,7 @@ HloModule foobar
|
||||
%x = (f32[2,2], f32[2,2]) parameter(0)
|
||||
%constant.0 = s32[] constant(0)
|
||||
%constant.1 = s32[] constant(1)
|
||||
ROOT %greater-than = pred[] greater-than(s32[] %constant.1, s32[] %constant.0)
|
||||
ROOT %greater-than = pred[] compare(s32[] %constant.1, s32[] %constant.0), direction=GT
|
||||
}
|
||||
|
||||
%body (x: (f32[2,2], f32[2,2])) -> (f32[2,2], f32[2,2]) {
|
||||
@ -296,7 +296,7 @@ HloModule foobar
|
||||
%x = (f32[2,2], f32[2,2]) parameter(0)
|
||||
%constant.0 = s32[] constant(0)
|
||||
%constant.1 = s32[] constant(1)
|
||||
ROOT %greater-than = pred[] greater-than(s32[] %constant.1, s32[] %constant.0)
|
||||
ROOT %greater-than = pred[] compare(s32[] %constant.1, s32[] %constant.0), direction=GT
|
||||
}
|
||||
|
||||
%body (x: (f32[2,2], f32[2,2])) -> (f32[2,2], f32[2,2]) {
|
||||
|
@ -109,8 +109,8 @@ TEST_F(BFloat16PropagationTest, PropagateThroughSelectButNotAdd) {
|
||||
HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a, b));
|
||||
HloInstruction* add1 = builder.AddInstruction(
|
||||
HloInstruction::CreateBinary(shape, HloOpcode::kAdd, add0, b));
|
||||
HloInstruction* pred = builder.AddInstruction(HloInstruction::CreateBinary(
|
||||
ShapeUtil::MakeShape(PRED, {2, 4}), HloOpcode::kEq, a, b));
|
||||
HloInstruction* pred = builder.AddInstruction(HloInstruction::CreateCompare(
|
||||
ShapeUtil::MakeShape(PRED, {2, 4}), a, b, ComparisonDirection::kEq));
|
||||
HloInstruction* sel = builder.AddInstruction(
|
||||
HloInstruction::CreateTernary(shape, HloOpcode::kSelect, pred, c, add1));
|
||||
HloInstruction* xpose =
|
||||
@ -574,8 +574,8 @@ TEST_F(BFloat16PropagationTest, PropagateThroughSimpleWhile) {
|
||||
HloInstruction::CreateParameter(0, shape, "cond_param"));
|
||||
auto cond_dot =
|
||||
builder_cond.AddInstruction(CreateDot(shape, cond_param, cond_param));
|
||||
auto cond_root = builder_cond.AddInstruction(HloInstruction::CreateBinary(
|
||||
ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt,
|
||||
auto cond_root = builder_cond.AddInstruction(HloInstruction::CreateCompare(
|
||||
ShapeUtil::MakeShape(PRED, {}),
|
||||
builder_cond.AddInstruction(HloInstruction::CreateReshape(
|
||||
ShapeUtil::MakeShape(F32, {}),
|
||||
builder_cond.AddInstruction(
|
||||
@ -583,9 +583,10 @@ TEST_F(BFloat16PropagationTest, PropagateThroughSimpleWhile) {
|
||||
cond_dot, {0, 0}, {1, 1}, {1, 1})))),
|
||||
builder_cond.AddInstruction(HloInstruction::CreateReshape(
|
||||
ShapeUtil::MakeShape(F32, {}),
|
||||
builder_cond.AddInstruction(HloInstruction::CreateSlice(
|
||||
ShapeUtil::MakeShape(F32, {1, 1}), cond_dot, {1, 1}, {2, 2},
|
||||
{1, 1}))))));
|
||||
builder_cond.AddInstruction(
|
||||
HloInstruction::CreateSlice(ShapeUtil::MakeShape(F32, {1, 1}),
|
||||
cond_dot, {1, 1}, {2, 2}, {1, 1})))),
|
||||
ComparisonDirection::kGt));
|
||||
auto cond = module->AddEmbeddedComputation(builder_cond.Build());
|
||||
|
||||
auto builder_body = HloComputation::Builder("body");
|
||||
@ -631,8 +632,8 @@ TEST_F(BFloat16PropagationTest,
|
||||
auto builder_cond = HloComputation::Builder("cond");
|
||||
auto cond_param = builder_cond.AddInstruction(
|
||||
HloInstruction::CreateParameter(0, shape, "cond_param"));
|
||||
builder_cond.AddInstruction(HloInstruction::CreateBinary(
|
||||
ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt,
|
||||
builder_cond.AddInstruction(HloInstruction::CreateCompare(
|
||||
ShapeUtil::MakeShape(PRED, {}),
|
||||
builder_cond.AddInstruction(HloInstruction::CreateReshape(
|
||||
ShapeUtil::MakeShape(F32, {}),
|
||||
builder_cond.AddInstruction(HloInstruction::CreateSlice(
|
||||
@ -642,7 +643,8 @@ TEST_F(BFloat16PropagationTest,
|
||||
ShapeUtil::MakeShape(F32, {}),
|
||||
builder_cond.AddInstruction(HloInstruction::CreateSlice(
|
||||
ShapeUtil::MakeShape(F32, {1, 1}), cond_param, {1, 1}, {2, 2},
|
||||
{1, 1}))))));
|
||||
{1, 1})))),
|
||||
ComparisonDirection::kGt));
|
||||
auto cond = module->AddEmbeddedComputation(builder_cond.Build());
|
||||
|
||||
auto builder_body = HloComputation::Builder("body");
|
||||
@ -705,8 +707,8 @@ TEST_F(BFloat16PropagationTest, PropagateThroughTupleWhile) {
|
||||
HloInstruction::CreateBinary(shape, HloOpcode::kAdd, cond_rhs, cond_rhs));
|
||||
auto cond_dot =
|
||||
builder_cond.AddInstruction(CreateDot(shape, cond_lhs, cond_add_rhs));
|
||||
builder_cond.AddInstruction(HloInstruction::CreateBinary(
|
||||
ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt,
|
||||
builder_cond.AddInstruction(HloInstruction::CreateCompare(
|
||||
ShapeUtil::MakeShape(PRED, {}),
|
||||
builder_cond.AddInstruction(HloInstruction::CreateReshape(
|
||||
ShapeUtil::MakeShape(F32, {}),
|
||||
builder_cond.AddInstruction(
|
||||
@ -714,9 +716,10 @@ TEST_F(BFloat16PropagationTest, PropagateThroughTupleWhile) {
|
||||
cond_dot, {0, 0}, {1, 1}, {1, 1})))),
|
||||
builder_cond.AddInstruction(HloInstruction::CreateReshape(
|
||||
ShapeUtil::MakeShape(F32, {}),
|
||||
builder_cond.AddInstruction(HloInstruction::CreateSlice(
|
||||
ShapeUtil::MakeShape(F32, {1, 1}), cond_dot, {1, 1}, {2, 2},
|
||||
{1, 1}))))));
|
||||
builder_cond.AddInstruction(
|
||||
HloInstruction::CreateSlice(ShapeUtil::MakeShape(F32, {1, 1}),
|
||||
cond_dot, {1, 1}, {2, 2}, {1, 1})))),
|
||||
ComparisonDirection::kGt));
|
||||
auto cond = module->AddEmbeddedComputation(builder_cond.Build());
|
||||
|
||||
auto builder_body = HloComputation::Builder("body");
|
||||
@ -800,8 +803,8 @@ TEST_F(BFloat16PropagationTest, DoNotPropagateWhilesCallingSameComputation) {
|
||||
shape, HloOpcode::kAdd, cond0_rhs, cond0_rhs));
|
||||
auto cond0_dot =
|
||||
builder_cond0.AddInstruction(CreateDot(shape, cond0_lhs, cond0_add_rhs));
|
||||
builder_cond0.AddInstruction(HloInstruction::CreateBinary(
|
||||
ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt,
|
||||
builder_cond0.AddInstruction(HloInstruction::CreateCompare(
|
||||
ShapeUtil::MakeShape(PRED, {}),
|
||||
builder_cond0.AddInstruction(HloInstruction::CreateReshape(
|
||||
ShapeUtil::MakeShape(F32, {}),
|
||||
builder_cond0.AddInstruction(
|
||||
@ -809,9 +812,10 @@ TEST_F(BFloat16PropagationTest, DoNotPropagateWhilesCallingSameComputation) {
|
||||
cond0_dot, {0, 0}, {1, 1}, {1, 1})))),
|
||||
builder_cond0.AddInstruction(HloInstruction::CreateReshape(
|
||||
ShapeUtil::MakeShape(F32, {}),
|
||||
builder_cond0.AddInstruction(HloInstruction::CreateSlice(
|
||||
ShapeUtil::MakeShape(F32, {1, 1}), cond0_dot, {1, 1}, {2, 2},
|
||||
{1, 1}))))));
|
||||
builder_cond0.AddInstruction(
|
||||
HloInstruction::CreateSlice(ShapeUtil::MakeShape(F32, {1, 1}),
|
||||
cond0_dot, {1, 1}, {2, 2}, {1, 1})))),
|
||||
ComparisonDirection::kGt));
|
||||
auto cond0 = module->AddEmbeddedComputation(builder_cond0.Build());
|
||||
|
||||
// Condition computation for the second while.
|
||||
@ -828,8 +832,8 @@ TEST_F(BFloat16PropagationTest, DoNotPropagateWhilesCallingSameComputation) {
|
||||
shape, HloOpcode::kAdd, cond1_lhs, cond1_lhs));
|
||||
auto cond1_dot =
|
||||
builder_cond1.AddInstruction(CreateDot(shape, cond1_add_lhs, cond1_rhs));
|
||||
builder_cond1.AddInstruction(HloInstruction::CreateBinary(
|
||||
ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt,
|
||||
builder_cond1.AddInstruction(HloInstruction::CreateCompare(
|
||||
ShapeUtil::MakeShape(PRED, {}),
|
||||
builder_cond1.AddInstruction(HloInstruction::CreateReshape(
|
||||
ShapeUtil::MakeShape(F32, {}),
|
||||
builder_cond1.AddInstruction(
|
||||
@ -837,9 +841,10 @@ TEST_F(BFloat16PropagationTest, DoNotPropagateWhilesCallingSameComputation) {
|
||||
cond1_dot, {0, 0}, {1, 1}, {1, 1})))),
|
||||
builder_cond1.AddInstruction(HloInstruction::CreateReshape(
|
||||
ShapeUtil::MakeShape(F32, {}),
|
||||
builder_cond1.AddInstruction(HloInstruction::CreateSlice(
|
||||
ShapeUtil::MakeShape(F32, {1, 1}), cond1_dot, {1, 1}, {2, 2},
|
||||
{1, 1}))))));
|
||||
builder_cond1.AddInstruction(
|
||||
HloInstruction::CreateSlice(ShapeUtil::MakeShape(F32, {1, 1}),
|
||||
cond1_dot, {1, 1}, {2, 2}, {1, 1})))),
|
||||
ComparisonDirection::kGt));
|
||||
auto cond1 = module->AddEmbeddedComputation(builder_cond1.Build());
|
||||
|
||||
// Body computation shared by both whiles.
|
||||
|
@ -190,8 +190,9 @@ class BufferAssignmentTest : public HloTestBase {
|
||||
HloInstruction::CreateParameter(0, t_s32_f32v4_, "x"));
|
||||
auto index = builder.AddInstruction(
|
||||
HloInstruction::CreateGetTupleElement(const4->shape(), param, 0));
|
||||
builder.AddInstruction(HloInstruction::CreateBinary(
|
||||
ShapeUtil::MakeShape(PRED, {}), HloOpcode::kLt, index, const4));
|
||||
builder.AddInstruction(
|
||||
HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), index,
|
||||
const4, ComparisonDirection::kLt));
|
||||
return builder.Build();
|
||||
}
|
||||
|
||||
@ -1863,8 +1864,8 @@ class WhileBufferAssignmentTest : public HloTestBase {
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(0)));
|
||||
auto ten = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(10)));
|
||||
builder.AddInstruction(HloInstruction::CreateBinary(
|
||||
ShapeUtil::MakeShape(PRED, {}), HloOpcode::kLt, zero, ten));
|
||||
builder.AddInstruction(HloInstruction::CreateCompare(
|
||||
ShapeUtil::MakeShape(PRED, {}), zero, ten, ComparisonDirection::kLt));
|
||||
return builder.Build();
|
||||
}
|
||||
|
||||
@ -2135,8 +2136,9 @@ TEST_F(WhileBufferAssignmentTest, ColocatedBuffers) {
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(4)));
|
||||
auto param =
|
||||
builder.AddInstruction(HloInstruction::CreateParameter(0, r0s32, "x"));
|
||||
builder.AddInstruction(HloInstruction::CreateBinary(
|
||||
ShapeUtil::MakeShape(PRED, {}), HloOpcode::kLt, param, const4));
|
||||
builder.AddInstruction(
|
||||
HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), param,
|
||||
const4, ComparisonDirection::kLt));
|
||||
return builder.Build();
|
||||
};
|
||||
|
||||
@ -2530,7 +2532,7 @@ while_condition {
|
||||
state = (s32[], f32[1280,1,128]{2,1,0}) parameter(0)
|
||||
get-tuple-element = s32[] get-tuple-element(state), index=0
|
||||
get-tuple-element.1 = s32[] constant(3)
|
||||
ROOT less-than.339.338 = pred[] less-than(get-tuple-element, get-tuple-element.1)
|
||||
ROOT less-than.339.338 = pred[] compare(get-tuple-element, get-tuple-element.1), direction=LT
|
||||
}
|
||||
|
||||
ENTRY entry_computation {
|
||||
|
@ -83,8 +83,9 @@ class CallGraphTest : public HloTestBase {
|
||||
HloInstruction::CreateParameter(0, kScalarShape, "param0"));
|
||||
HloInstruction* zero = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
|
||||
builder.AddInstruction(HloInstruction::CreateBinary(
|
||||
ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt, param0, zero));
|
||||
builder.AddInstruction(
|
||||
HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), param0,
|
||||
zero, ComparisonDirection::kGt));
|
||||
return builder.Build();
|
||||
}
|
||||
|
||||
|
@ -191,8 +191,9 @@ HloInstruction* GetExpandedFilterMask(
|
||||
// linspace to create a diagonal predicate.
|
||||
Shape predicate_shape = ShapeUtil::MakeShape(
|
||||
PRED, AsInt64Slice(expanded_filter_shape.dimensions()));
|
||||
return add_instruction(HloInstruction::CreateBinary(
|
||||
predicate_shape, HloOpcode::kEq, broadcasted_mask1, broadcasted_mask2));
|
||||
return add_instruction(HloInstruction::CreateCompare(
|
||||
predicate_shape, broadcasted_mask1, broadcasted_mask2,
|
||||
ComparisonDirection::kEq));
|
||||
}
|
||||
|
||||
// This function handles batch_group_counts which are relevant only for
|
||||
|
@ -420,9 +420,9 @@ class WhileCopyInsertionTest : public CopyInsertionTest {
|
||||
auto induction_variable =
|
||||
builder.AddInstruction(HloInstruction::CreateGetTupleElement(
|
||||
limit_const->shape(), loop_state, 0));
|
||||
builder.AddInstruction(
|
||||
HloInstruction::CreateBinary(condition_result_shape_, HloOpcode::kLt,
|
||||
induction_variable, limit_const));
|
||||
builder.AddInstruction(HloInstruction::CreateCompare(
|
||||
condition_result_shape_, induction_variable, limit_const,
|
||||
ComparisonDirection::kLt));
|
||||
return builder.Build();
|
||||
}
|
||||
|
||||
@ -1842,7 +1842,7 @@ HloModule TokensShouldNotBeCopied
|
||||
%param = (s32[], token[]) parameter(0)
|
||||
%get-tuple-element = s32[] get-tuple-element((s32[], token[]) %param), index=0
|
||||
%constant = s32[] constant(42)
|
||||
ROOT %less-than = pred[] less-than(s32[] %get-tuple-element, s32[] %constant)
|
||||
ROOT %less-than = pred[] compare(s32[] %get-tuple-element, s32[] %constant), direction=LT
|
||||
}
|
||||
|
||||
ENTRY %TokensShouldNotBeCopied () -> s32[] {
|
||||
@ -2060,7 +2060,7 @@ if-condition.v4 {
|
||||
p.2 = (s32[], (s32[], s32[], s32[]), (s32[])) parameter(0)
|
||||
get-tuple-element.67 = s32[] get-tuple-element(p.2), index=0
|
||||
constant.4 = s32[] constant(0)
|
||||
ROOT equal-to = pred[] equal-to(get-tuple-element.67, constant.4)
|
||||
ROOT equal-to = pred[] compare(get-tuple-element.67, constant.4), direction=EQ
|
||||
}
|
||||
|
||||
_functionalize_body_1__.v28 {
|
||||
@ -2070,7 +2070,7 @@ _functionalize_body_1__.v28 {
|
||||
add.4 = s32[] add(get-tuple-element.68, constant.7)
|
||||
get-tuple-element.69 = s32[] get-tuple-element(arg_tuple.4), index=1
|
||||
get-tuple-element.70 = s32[] get-tuple-element(arg_tuple.4), index=2
|
||||
less-than-or-equal-to = pred[] less-than-or-equal-to(get-tuple-element.69, get-tuple-element.70)
|
||||
less-than-or-equal-to = pred[] compare(get-tuple-element.69, get-tuple-element.70), direction=LE
|
||||
constant.8 = s32[] constant(0)
|
||||
select = s32[] select(less-than-or-equal-to, constant.8, constant.7)
|
||||
get-tuple-element.71 = s32[] get-tuple-element(arg_tuple.4), index=3
|
||||
@ -2087,7 +2087,7 @@ cond_wrapper.v3.1 {
|
||||
inputs.1 = (s32[], s32[], s32[], s32[]) parameter(0)
|
||||
get-tuple-element.75 = s32[] get-tuple-element(inputs.1), index=0
|
||||
constant.11 = s32[] constant(7)
|
||||
ROOT less-than.2 = pred[] less-than(get-tuple-element.75, constant.11)
|
||||
ROOT less-than.2 = pred[] compare(get-tuple-element.75, constant.11), direction=LT
|
||||
}
|
||||
|
||||
_functionalize_body_2__.v25 {
|
||||
@ -2110,7 +2110,7 @@ cond_wrapper.v3.2 {
|
||||
inputs.2 = (s32[], s32[], s32[], s32[], s32[]) parameter(0)
|
||||
get-tuple-element.83 = s32[] get-tuple-element(inputs.2), index=1
|
||||
constant.13 = s32[] constant(5)
|
||||
ROOT less-than.3 = pred[] less-than(get-tuple-element.83, constant.13)
|
||||
ROOT less-than.3 = pred[] compare(get-tuple-element.83, constant.13), direction=LT
|
||||
}
|
||||
|
||||
ENTRY TestComputation {
|
||||
@ -2142,7 +2142,7 @@ if-condition.v4 {
|
||||
p.2 = (s32[], (s32[], s32[], s32[]), (s32[])) parameter(0)
|
||||
get-tuple-element.67 = s32[] get-tuple-element(p.2), index=0
|
||||
constant.4 = s32[] constant(0)
|
||||
ROOT equal-to = pred[] equal-to(get-tuple-element.67, constant.4)
|
||||
ROOT equal-to = pred[] compare(get-tuple-element.67, constant.4), direction=EQ
|
||||
}
|
||||
|
||||
if-body.v5.1 {
|
||||
@ -2159,7 +2159,7 @@ if-condition.v4.1 {
|
||||
p.4 = (s32[], (s32[], s32[], s32[]), (s32[])) parameter(0)
|
||||
get-tuple-element.71 = s32[] get-tuple-element(p.4), index=0
|
||||
constant.6 = s32[] constant(1)
|
||||
ROOT equal-to.1 = pred[] equal-to(get-tuple-element.71, constant.6)
|
||||
ROOT equal-to.1 = pred[] compare(get-tuple-element.71, constant.6), direction=EQ
|
||||
}
|
||||
|
||||
_functionalize_body_1__.v28 {
|
||||
@ -2169,7 +2169,7 @@ _functionalize_body_1__.v28 {
|
||||
add.4 = s32[] add(get-tuple-element.72, constant.7)
|
||||
get-tuple-element.73 = s32[] get-tuple-element(arg_tuple.4), index=1
|
||||
get-tuple-element.74 = s32[] get-tuple-element(arg_tuple.4), index=2
|
||||
less-than-or-equal-to = pred[] less-than-or-equal-to(get-tuple-element.73, get-tuple-element.74)
|
||||
less-than-or-equal-to = pred[] compare(get-tuple-element.73, get-tuple-element.74), direction=LE
|
||||
constant.8 = s32[] constant(0)
|
||||
select = s32[] select(less-than-or-equal-to, constant.8, constant.7)
|
||||
get-tuple-element.75 = s32[] get-tuple-element(arg_tuple.4), index=3
|
||||
@ -2187,7 +2187,7 @@ cond_wrapper.v3.1 {
|
||||
inputs.1 = (s32[], s32[], s32[], s32[]) parameter(0)
|
||||
get-tuple-element.78 = s32[] get-tuple-element(inputs.1), index=0
|
||||
constant.11 = s32[] constant(7)
|
||||
ROOT less-than.2 = pred[] less-than(get-tuple-element.78, constant.11)
|
||||
ROOT less-than.2 = pred[] compare(get-tuple-element.78, constant.11), direction=LT
|
||||
}
|
||||
|
||||
_functionalize_body_2__.v25 {
|
||||
@ -2210,7 +2210,7 @@ cond_wrapper.v3.2 {
|
||||
inputs.2 = (s32[], s32[], s32[], s32[], s32[]) parameter(0)
|
||||
get-tuple-element.86 = s32[] get-tuple-element(inputs.2), index=1
|
||||
constant.13 = s32[] constant(5)
|
||||
ROOT less-than.3 = pred[] less-than(get-tuple-element.86, constant.13)
|
||||
ROOT less-than.3 = pred[] compare(get-tuple-element.86, constant.13), direction=LT
|
||||
}
|
||||
|
||||
ENTRY TestComputation {
|
||||
|
@ -29,7 +29,7 @@ HloModule KeyValueSort
|
||||
compare {
|
||||
p.0.lhs = f32[] parameter(0)
|
||||
p.0.rhs = f32[] parameter(1)
|
||||
ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs)
|
||||
ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
|
||||
}
|
||||
|
||||
ENTRY main {
|
||||
|
@ -75,7 +75,7 @@ ENTRY TestComputation {
|
||||
broadcast = f32[42] broadcast(add), dimensions={}
|
||||
slice = f32[1] slice(broadcast), slice={[1:2]}
|
||||
copy = f32[] copy(arg)
|
||||
eq = pred[] equal-to(arg, gte)
|
||||
eq = pred[] compare(arg, gte), direction=EQ
|
||||
neg = f32[] negate(arg)
|
||||
ROOT convert = f64[] convert(f32[] arg)
|
||||
})";
|
||||
|
@ -68,8 +68,8 @@ class DynamicDimensionInferenceTest : public HloTestBase {
|
||||
0, ShapeUtil::MakeShape(F32, {}), "lhs"));
|
||||
auto rhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter(
|
||||
1, ShapeUtil::MakeShape(F32, {}), "rhs"));
|
||||
embedded_builder.AddInstruction(HloInstruction::CreateBinary(
|
||||
ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGe, lhs, rhs));
|
||||
embedded_builder.AddInstruction(HloInstruction::CreateCompare(
|
||||
ShapeUtil::MakeShape(PRED, {}), lhs, rhs, ComparisonDirection::kGe));
|
||||
return module_->AddEmbeddedComputation(embedded_builder.Build());
|
||||
}
|
||||
|
||||
|
@ -160,9 +160,10 @@ StatusOr<bool> DynamicPadder::Run(HloModule* module) {
|
||||
HloInstruction* broadcasted_effective_size =
|
||||
computation->AddInstruction(HloInstruction::CreateBroadcast(
|
||||
mask_shape, dynamic_size, {}));
|
||||
HloInstruction* pred = computation->AddInstruction(
|
||||
HloInstruction::CreateBinary(pred_shape, HloOpcode::kLt, iota,
|
||||
broadcasted_effective_size));
|
||||
HloInstruction* pred =
|
||||
computation->AddInstruction(HloInstruction::CreateCompare(
|
||||
pred_shape, iota, broadcasted_effective_size,
|
||||
ComparisonDirection::kLt));
|
||||
|
||||
HloInstruction* broadcasted_identity_value =
|
||||
computation->AddInstruction(HloInstruction::CreateBroadcast(
|
||||
|
@ -719,25 +719,28 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatBinaryOp(
|
||||
// We use ordered comparisons for everything except kNe, where we use an
|
||||
// unordered comparison. This makes x != y equivalent to !(x == y), and
|
||||
// matches C++'s semantics.
|
||||
case HloOpcode::kEq:
|
||||
return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ, lhs_value,
|
||||
rhs_value, b_);
|
||||
case HloOpcode::kNe:
|
||||
return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE, lhs_value,
|
||||
rhs_value, b_);
|
||||
case HloOpcode::kLt:
|
||||
return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OLT, lhs_value,
|
||||
rhs_value, b_);
|
||||
case HloOpcode::kGt:
|
||||
return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OGT, lhs_value,
|
||||
rhs_value, b_);
|
||||
case HloOpcode::kLe:
|
||||
return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OLE, lhs_value,
|
||||
rhs_value, b_);
|
||||
case HloOpcode::kGe:
|
||||
return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OGE, lhs_value,
|
||||
rhs_value, b_);
|
||||
|
||||
case HloOpcode::kCompare: {
|
||||
switch (op->comparison_direction()) {
|
||||
case ComparisonDirection::kEq:
|
||||
return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ, lhs_value,
|
||||
rhs_value, b_);
|
||||
case ComparisonDirection::kNe:
|
||||
return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE, lhs_value,
|
||||
rhs_value, b_);
|
||||
case ComparisonDirection::kLt:
|
||||
return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OLT, lhs_value,
|
||||
rhs_value, b_);
|
||||
case ComparisonDirection::kGt:
|
||||
return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OGT, lhs_value,
|
||||
rhs_value, b_);
|
||||
case ComparisonDirection::kLe:
|
||||
return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OLE, lhs_value,
|
||||
rhs_value, b_);
|
||||
case ComparisonDirection::kGe:
|
||||
return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OGE, lhs_value,
|
||||
rhs_value, b_);
|
||||
}
|
||||
}
|
||||
case HloOpcode::kMaximum:
|
||||
return EmitFloatMax(lhs_value, rhs_value);
|
||||
case HloOpcode::kMinimum:
|
||||
@ -839,21 +842,28 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexBinaryOp(
|
||||
// We use ordered comparisons for everything except kNe, where we use an
|
||||
// unordered comparison. This makes x != y equivalent to !(x == y), and
|
||||
// matches C++'s semantics.
|
||||
case HloOpcode::kEq:
|
||||
return And(llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ,
|
||||
EmitExtractReal(lhs_value),
|
||||
EmitExtractReal(rhs_value), b_),
|
||||
llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ,
|
||||
EmitExtractImag(lhs_value),
|
||||
EmitExtractImag(rhs_value), b_));
|
||||
case HloOpcode::kNe:
|
||||
return Or(llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE,
|
||||
EmitExtractReal(lhs_value),
|
||||
EmitExtractReal(rhs_value), b_),
|
||||
llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE,
|
||||
EmitExtractImag(lhs_value),
|
||||
EmitExtractImag(rhs_value), b_));
|
||||
|
||||
case HloOpcode::kCompare: {
|
||||
switch (op->comparison_direction()) {
|
||||
case ComparisonDirection::kEq:
|
||||
return And(llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ,
|
||||
EmitExtractReal(lhs_value),
|
||||
EmitExtractReal(rhs_value), b_),
|
||||
llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ,
|
||||
EmitExtractImag(lhs_value),
|
||||
EmitExtractImag(rhs_value), b_));
|
||||
case ComparisonDirection::kNe:
|
||||
return Or(llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE,
|
||||
EmitExtractReal(lhs_value),
|
||||
EmitExtractReal(rhs_value), b_),
|
||||
llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE,
|
||||
EmitExtractImag(lhs_value),
|
||||
EmitExtractImag(rhs_value), b_));
|
||||
default:
|
||||
return Unimplemented(
|
||||
"complex comparison '%s'",
|
||||
ComparisonDirectionToString(op->comparison_direction()));
|
||||
}
|
||||
}
|
||||
case HloOpcode::kPower: {
|
||||
auto a = EmitExtractReal(lhs_value);
|
||||
auto b = EmitExtractImag(lhs_value);
|
||||
@ -1278,28 +1288,32 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerBinaryOp(
|
||||
return EmitIntegerDivide(lhs_value, rhs_value, is_signed);
|
||||
case HloOpcode::kRemainder:
|
||||
return EmitIntegerRemainder(lhs_value, rhs_value, is_signed);
|
||||
case HloOpcode::kEq:
|
||||
return llvm_ir::EmitComparison(llvm::CmpInst::ICMP_EQ, lhs_value,
|
||||
rhs_value, b_);
|
||||
case HloOpcode::kNe:
|
||||
return llvm_ir::EmitComparison(llvm::CmpInst::ICMP_NE, lhs_value,
|
||||
rhs_value, b_);
|
||||
case HloOpcode::kLt:
|
||||
return llvm_ir::EmitComparison(
|
||||
is_signed ? llvm::CmpInst::ICMP_SLT : llvm::CmpInst::ICMP_ULT,
|
||||
lhs_value, rhs_value, b_);
|
||||
case HloOpcode::kGt:
|
||||
return llvm_ir::EmitComparison(
|
||||
is_signed ? llvm::CmpInst::ICMP_SGT : llvm::CmpInst::ICMP_UGT,
|
||||
lhs_value, rhs_value, b_);
|
||||
case HloOpcode::kLe:
|
||||
return llvm_ir::EmitComparison(
|
||||
is_signed ? llvm::CmpInst::ICMP_SLE : llvm::CmpInst::ICMP_ULE,
|
||||
lhs_value, rhs_value, b_);
|
||||
case HloOpcode::kGe:
|
||||
return llvm_ir::EmitComparison(
|
||||
is_signed ? llvm::CmpInst::ICMP_SGE : llvm::CmpInst::ICMP_UGE,
|
||||
lhs_value, rhs_value, b_);
|
||||
case HloOpcode::kCompare: {
|
||||
switch (op->comparison_direction()) {
|
||||
case ComparisonDirection::kEq:
|
||||
return llvm_ir::EmitComparison(llvm::CmpInst::ICMP_EQ, lhs_value,
|
||||
rhs_value, b_);
|
||||
case ComparisonDirection::kNe:
|
||||
return llvm_ir::EmitComparison(llvm::CmpInst::ICMP_NE, lhs_value,
|
||||
rhs_value, b_);
|
||||
case ComparisonDirection::kLt:
|
||||
return llvm_ir::EmitComparison(
|
||||
is_signed ? llvm::CmpInst::ICMP_SLT : llvm::CmpInst::ICMP_ULT,
|
||||
lhs_value, rhs_value, b_);
|
||||
case ComparisonDirection::kGt:
|
||||
return llvm_ir::EmitComparison(
|
||||
is_signed ? llvm::CmpInst::ICMP_SGT : llvm::CmpInst::ICMP_UGT,
|
||||
lhs_value, rhs_value, b_);
|
||||
case ComparisonDirection::kLe:
|
||||
return llvm_ir::EmitComparison(
|
||||
is_signed ? llvm::CmpInst::ICMP_SLE : llvm::CmpInst::ICMP_ULE,
|
||||
lhs_value, rhs_value, b_);
|
||||
case ComparisonDirection::kGe:
|
||||
return llvm_ir::EmitComparison(
|
||||
is_signed ? llvm::CmpInst::ICMP_SGE : llvm::CmpInst::ICMP_UGE,
|
||||
lhs_value, rhs_value, b_);
|
||||
}
|
||||
}
|
||||
case HloOpcode::kMinimum:
|
||||
return EmitIntegralMin(lhs_value, rhs_value, is_signed);
|
||||
case HloOpcode::kMaximum:
|
||||
@ -2197,17 +2211,12 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
|
||||
case HloOpcode::kAdd:
|
||||
case HloOpcode::kAnd:
|
||||
case HloOpcode::kAtan2:
|
||||
case HloOpcode::kCompare:
|
||||
case HloOpcode::kComplex:
|
||||
case HloOpcode::kDivide:
|
||||
case HloOpcode::kEq:
|
||||
case HloOpcode::kGe:
|
||||
case HloOpcode::kGt:
|
||||
case HloOpcode::kLe:
|
||||
case HloOpcode::kLt:
|
||||
case HloOpcode::kMaximum:
|
||||
case HloOpcode::kMinimum:
|
||||
case HloOpcode::kMultiply:
|
||||
case HloOpcode::kNe:
|
||||
case HloOpcode::kOr:
|
||||
case HloOpcode::kXor:
|
||||
case HloOpcode::kPower:
|
||||
|
@ -81,8 +81,9 @@ class FlattenCallGraphTest : public HloTestBase {
|
||||
HloInstruction::CreateParameter(0, kScalarShape, "param0"));
|
||||
HloInstruction* zero = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
|
||||
builder.AddInstruction(HloInstruction::CreateBinary(
|
||||
ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt, param0, zero));
|
||||
builder.AddInstruction(
|
||||
HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), param0,
|
||||
zero, ComparisonDirection::kGt));
|
||||
return builder.Build();
|
||||
}
|
||||
|
||||
@ -158,9 +159,9 @@ TEST_F(FlattenCallGraphTest, SharedWhileConditionAndBody) {
|
||||
0, ShapeUtil::MakeShape(PRED, {}), "param0"));
|
||||
HloInstruction* false_constant = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
|
||||
builder.AddInstruction(
|
||||
HloInstruction::CreateBinary(ShapeUtil::MakeShape(PRED, {}),
|
||||
HloOpcode::kEq, param0, false_constant));
|
||||
builder.AddInstruction(HloInstruction::CreateCompare(
|
||||
ShapeUtil::MakeShape(PRED, {}), param0, false_constant,
|
||||
ComparisonDirection::kEq));
|
||||
cond_computation = module->AddEmbeddedComputation(builder.Build());
|
||||
}
|
||||
|
||||
|
@ -62,7 +62,7 @@ TEST_F(GpuFusibleTest,
|
||||
copy = f16[128,1024,32,32]{1,3,2,0} copy(p1.1)
|
||||
c0 = f16[] constant(0)
|
||||
broadcast = f16[128,1024,32,32]{1,3,2,0} broadcast(c0), dimensions={}
|
||||
greater-than = pred[128,1024,32,32]{1,3,2,0} greater-than(copy, broadcast)
|
||||
greater-than = pred[128,1024,32,32]{1,3,2,0} compare(copy, broadcast), direction=GT
|
||||
ROOT root = f16[128,1024,32,32]{1,3,2,0} select(greater-than, p0.1, broadcast)
|
||||
}
|
||||
fused_reduce {
|
||||
@ -122,7 +122,7 @@ TEST_F(GpuFusibleTest,
|
||||
p1.1 = f16[128,1024,32,32]{3,2,1,0} parameter(1)
|
||||
c0 = f16[] constant(0)
|
||||
broadcast = f16[128,1024,32,32]{3,2,1,0} broadcast(c0), dimensions={}
|
||||
greater-than = pred[128,1024,32,32]{3,2,1,0} greater-than(p1.1, broadcast)
|
||||
greater-than = pred[128,1024,32,32]{3,2,1,0} compare(p1.1, broadcast), direction=GT
|
||||
select = f16[128,1024,32,32]{3,2,1,0} select(greater-than, p0.1, broadcast)
|
||||
ROOT root = f16[128,1024,32,32]{1,3,2,0} copy(select)
|
||||
}
|
||||
@ -507,7 +507,7 @@ TEST_F(GpuFusibleTest,
|
||||
p1.1 = f32[2,2,2]{2,1,0} parameter(1)
|
||||
c0 = f32[] constant(0)
|
||||
broadcast = f32[2,2,2]{2,1,0} broadcast(f32[] c0), dimensions={}
|
||||
greater-than = pred[2,2,2]{2,1,0} greater-than(f32[2,2,2]{2,1,0} p1.1, f32[2,2,2]{2,1,0} broadcast)
|
||||
greater-than = pred[2,2,2]{2,1,0} compare(f32[2,2,2]{2,1,0} p1.1, f32[2,2,2]{2,1,0} broadcast), direction=GT
|
||||
p0.1 = f32[2,2,2]{2,1,0} parameter(0)
|
||||
ROOT select = f32[2,2,2]{2,1,0} select(pred[2,2,2]{2,1,0} greater-than, f32[2,2,2]{2,1,0} p0.1, f32[2,2,2]{2,1,0} broadcast)
|
||||
}
|
||||
|
@ -374,7 +374,7 @@ TEST_F(LayoutAssignmentTest, SortLayout) {
|
||||
p.0.rhs = f32[] parameter(1)
|
||||
p.1.lhs = f32[] parameter(2)
|
||||
p.1.rhs = f32[] parameter(3)
|
||||
ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs)
|
||||
ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
|
||||
}
|
||||
|
||||
ENTRY sort {
|
||||
|
@ -437,7 +437,7 @@ TEST_F(MultiOutputFusionTest, ProducerConsumerFusionLoopFusionAndReduceFusion) {
|
||||
p1.1 = f32[2,2,2]{2,1,0} parameter(1)
|
||||
c0 = f32[] constant(0)
|
||||
broadcast = f32[2,2,2]{2,1,0} broadcast(f32[] c0), dimensions={}
|
||||
greater-than = pred[2,2,2]{2,1,0} greater-than(f32[2,2,2]{2,1,0} p1.1, f32[2,2,2]{2,1,0} broadcast)
|
||||
greater-than = pred[2,2,2]{2,1,0} compare(f32[2,2,2]{2,1,0} p1.1, f32[2,2,2]{2,1,0} broadcast), direction=GT
|
||||
p0.1 = f32[2,2,2]{2,1,0} parameter(0)
|
||||
ROOT select = f32[2,2,2]{2,1,0} select(pred[2,2,2]{2,1,0} greater-than, f32[2,2,2]{2,1,0} p0.1, f32[2,2,2]{2,1,0} broadcast)
|
||||
}
|
||||
@ -505,7 +505,7 @@ TEST_F(MultiOutputFusionTest,
|
||||
p1.1 = f16[2,2,2]{2,1,0} parameter(1)
|
||||
c0 = f16[] constant(0)
|
||||
broadcast = f16[2,2,2]{2,1,0} broadcast(f16[] c0), dimensions={}
|
||||
greater-than = pred[2,2,2]{2,1,0} greater-than(f16[2,2,2]{2,1,0} p1.1, f16[2,2,2]{2,1,0} broadcast)
|
||||
greater-than = pred[2,2,2]{2,1,0} compare(f16[2,2,2]{2,1,0} p1.1, f16[2,2,2]{2,1,0} broadcast), direction=GT
|
||||
p0.1 = f16[2,2,2]{2,1,0} parameter(0)
|
||||
ROOT select = f16[2,2,2]{2,1,0} select(pred[2,2,2]{2,1,0} greater-than, f16[2,2,2]{2,1,0} p0.1, f16[2,2,2]{2,1,0} broadcast)
|
||||
}
|
||||
@ -548,7 +548,7 @@ TEST_F(MultiOutputFusionTest,
|
||||
copy = f16[128,1024,32,32]{1,3,2,0} copy(p1.1)
|
||||
c0 = f16[] constant(0)
|
||||
broadcast = f16[128,1024,32,32]{1,3,2,0} broadcast(c0), dimensions={}
|
||||
greater-than = pred[128,1024,32,32]{1,3,2,0} greater-than(copy, broadcast)
|
||||
greater-than = pred[128,1024,32,32]{1,3,2,0} compare(copy, broadcast), direction=GT
|
||||
ROOT root = f16[128,1024,32,32]{1,3,2,0} select(greater-than, p0.1, broadcast)
|
||||
}
|
||||
fused_reduce {
|
||||
|
@ -48,8 +48,9 @@ TEST_F(GpuIndexTest, CompatibleUseLinearIndex) {
|
||||
HloInstruction::CreateParameter(0, param_shape, "x"));
|
||||
HloInstruction* param_y = builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(1, param_shape, "y"));
|
||||
builder.AddInstruction(HloInstruction::CreateBinary(
|
||||
ShapeUtil::MakeShape(PRED, {5, 7, 2}), HloOpcode::kGe, param_x, param_y));
|
||||
builder.AddInstruction(HloInstruction::CreateCompare(
|
||||
ShapeUtil::MakeShape(PRED, {5, 7, 2}), param_x, param_y,
|
||||
ComparisonDirection::kGe));
|
||||
|
||||
auto hlo_module = CreateNewVerifiedModule();
|
||||
hlo_module->AddEntryComputation(builder.Build());
|
||||
@ -73,7 +74,7 @@ TEST_F(GpuIndexTest, CompatibleUseLinearIndexWithReshape) {
|
||||
x = f32[5,7,2]{2,1,0} parameter(0)
|
||||
y = f32[5,14]{1,0} parameter(1)
|
||||
reshape = f32[5,7,2]{2,1,0} reshape(y)
|
||||
ROOT gte = pred[5,7,2]{2,1,0} greater-than-or-equal-to(x, reshape)
|
||||
ROOT gte = pred[5,7,2]{2,1,0} compare(x, reshape), direction=GE
|
||||
})",
|
||||
config)
|
||||
.ValueOrDie();
|
||||
@ -98,7 +99,7 @@ TEST_F(GpuIndexTest, CompatibleUseLinearIndexWithReshapeAndBroadcast) {
|
||||
y = f32[14]{0} parameter(1)
|
||||
reshape = f32[7,2]{1,0} reshape(y)
|
||||
broadcast = f32[5,7,2]{2,1,0} broadcast(reshape), dimensions={1,2}
|
||||
ROOT gte = pred[5,7,2]{2,1,0} greater-than-or-equal-to(x, broadcast)
|
||||
ROOT gte = pred[5,7,2]{2,1,0} compare(x, broadcast), direction=GE
|
||||
})",
|
||||
config)
|
||||
.ValueOrDie();
|
||||
|
@ -44,9 +44,9 @@ class WhileTransformerTest : public HloTestBase {
|
||||
auto induction_variable =
|
||||
builder.AddInstruction(HloInstruction::CreateGetTupleElement(
|
||||
limit_const->shape(), loop_state, tuple_index));
|
||||
builder.AddInstruction(
|
||||
HloInstruction::CreateBinary(condition_result_shape_, HloOpcode::kLt,
|
||||
induction_variable, limit_const));
|
||||
builder.AddInstruction(HloInstruction::CreateCompare(
|
||||
condition_result_shape_, induction_variable, limit_const,
|
||||
ComparisonDirection::kLt));
|
||||
return builder.Build();
|
||||
}
|
||||
|
||||
|
@ -54,8 +54,8 @@ TEST_F(MinimumMemoryForSequenceTest, MultiComputation) {
|
||||
HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 1));
|
||||
// Free cond_param[] (16 bytes), Alloc PRED[] (1 byte)
|
||||
HloInstruction* cond_lt = cond_builder.AddInstruction(
|
||||
HloInstruction::CreateBinary(ShapeUtil::MakeShape(PRED, {}),
|
||||
HloOpcode::kLt, cond_iter, cond_data));
|
||||
HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), cond_iter,
|
||||
cond_data, ComparisonDirection::kLt));
|
||||
HloComputation* cond_computation =
|
||||
module->AddEmbeddedComputation(cond_builder.Build());
|
||||
|
||||
@ -113,7 +113,8 @@ TEST_F(MinimumMemoryForSequenceTest, SubcomputationAccounting) {
|
||||
// %slice = f32[1]{0} slice(f32[4]{0} %cond_param), slice={[0:1]}
|
||||
// %reshape = f32[] reshape(f32[1]{0} %slice)
|
||||
// %constant = f32[] constant(0)
|
||||
// ROOT %not-equal-to = pred[] not-equal-to(f32[] %reshape, f32[] %constant)
|
||||
// ROOT %not-equal-to = pred[] compare(f32[] %reshape, f32[] %constant),
|
||||
// direction=NE
|
||||
// }
|
||||
|
||||
// ENTRY %SubcomputationAccounting () -> f32[2,4] {
|
||||
@ -143,9 +144,9 @@ TEST_F(MinimumMemoryForSequenceTest, SubcomputationAccounting) {
|
||||
cond_builder.AddInstruction(HloInstruction::CreateReshape(r0f32, slice));
|
||||
HloInstruction* zero = cond_builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0)));
|
||||
HloInstruction* cond_comparison =
|
||||
cond_builder.AddInstruction(HloInstruction::CreateBinary(
|
||||
ShapeUtil::MakeShape(PRED, {}), HloOpcode::kNe, reshape, zero));
|
||||
HloInstruction* cond_comparison = cond_builder.AddInstruction(
|
||||
HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), reshape,
|
||||
zero, ComparisonDirection::kNe));
|
||||
auto cond_computation = module->AddEmbeddedComputation(cond_builder.Build());
|
||||
|
||||
// param - 1
|
||||
@ -703,8 +704,8 @@ TEST_F(HeapSimulatorTest, WholeModule) {
|
||||
HloInstruction* cond_data = cond_builder.AddInstruction(
|
||||
HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 1));
|
||||
HloInstruction* cond_lt = cond_builder.AddInstruction(
|
||||
HloInstruction::CreateBinary(ShapeUtil::MakeShape(PRED, {}),
|
||||
HloOpcode::kLt, cond_iter, cond_data));
|
||||
HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), cond_iter,
|
||||
cond_data, ComparisonDirection::kLt));
|
||||
HloComputation* cond_computation =
|
||||
tracker.module()->AddEmbeddedComputation(cond_builder.Build());
|
||||
|
||||
|
@ -35,7 +35,7 @@ import "tensorflow/compiler/xla/xla_data.proto";
|
||||
option cc_enable_arenas = true;
|
||||
|
||||
// Serialization of HloInstruction.
|
||||
// Next ID: 63
|
||||
// Next ID: 64
|
||||
message HloInstructionProto {
|
||||
reserved 10;
|
||||
reserved "parameter_name";
|
||||
@ -146,6 +146,9 @@ message HloInstructionProto {
|
||||
// FFT length.
|
||||
repeated int64 fft_length = 32;
|
||||
|
||||
// Comparison direction only used for kCompare.
|
||||
string comparison_direction = 63;
|
||||
|
||||
// Gather dimension numbers.
|
||||
xla.GatherDimensionNumbers gather_dimension_numbers = 33;
|
||||
repeated int64 gather_slice_sizes = 34;
|
||||
|
@ -509,8 +509,9 @@ TEST_F(HloComputationTest, CloneWithReplacements) {
|
||||
HloInstruction::CreateParameter(1, r0f32_, "p.0.rhs"));
|
||||
auto param2 =
|
||||
builder.AddInstruction(HloInstruction::CreateParameter(2, r0s64, "p.1"));
|
||||
auto lt = builder.AddInstruction(HloInstruction::CreateBinary(
|
||||
ShapeUtil::MakeShape(PRED, {}), HloOpcode::kLt, param0, param1));
|
||||
auto lt = builder.AddInstruction(
|
||||
HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), param0,
|
||||
param1, ComparisonDirection::kLt));
|
||||
auto module = CreateNewVerifiedModule();
|
||||
auto computation =
|
||||
module->AddEntryComputation(builder.Build(/*root_instruction=*/lt));
|
||||
|
@ -42,6 +42,18 @@ StatusOr<HloInstruction*> MakeBinaryHlo(HloOpcode opcode, HloInstruction* lhs,
|
||||
HloInstruction::CreateBinary(binary_op_shape, opcode, lhs, rhs));
|
||||
}
|
||||
|
||||
// Builds a kCompare instruction over `lhs` and `rhs` using `direction` and
// appends it to the computation that owns `lhs`. CHECK-fails when `rhs` lives
// in a different computation. The result shape comes from binary-op shape
// inference; a non-OK inference status is handed back to the caller through
// TF_ASSIGN_OR_RETURN.
StatusOr<HloInstruction*> MakeCompareHlo(ComparisonDirection direction,
                                         HloInstruction* lhs,
                                         HloInstruction* rhs) {
  HloComputation* owner = lhs->parent();
  // Both operands have to belong to the same computation.
  CHECK_EQ(owner, rhs->parent());

  TF_ASSIGN_OR_RETURN(
      Shape compare_shape,
      ShapeInference::InferBinaryOpShape(HloOpcode::kCompare, lhs, rhs));

  return owner->AddInstruction(
      HloInstruction::CreateCompare(compare_shape, lhs, rhs, direction));
}
|
||||
|
||||
StatusOr<HloInstruction*> MakePadHlo(HloInstruction* operand,
|
||||
HloInstruction* padding_value,
|
||||
const PaddingConfig& padding_config) {
|
||||
|
@ -32,6 +32,12 @@ namespace xla {
|
||||
StatusOr<HloInstruction*> MakeBinaryHlo(HloOpcode opcode, HloInstruction* lhs,
|
||||
HloInstruction* rhs);
|
||||
|
||||
// Creates a compare HLO instruction and adds it to the computation containing
|
||||
// `lhs` and `rhs` (`lhs` and `rhs` must be in the same computation).
|
||||
StatusOr<HloInstruction*> MakeCompareHlo(ComparisonDirection direction,
|
||||
HloInstruction* lhs,
|
||||
HloInstruction* rhs);
|
||||
|
||||
// Creates a pad HLO instruction and adds it to the computation containing
|
||||
// `operand` and `padding_value` (`operand` and `padding_value` must be in the
|
||||
// same computation).
|
||||
|
@ -2239,8 +2239,8 @@ TEST_F(CanShareOperandBufferWithUserTest, ElementWiseDifferentShape) {
|
||||
HloInstruction::CreateParameter(0, in_shape, "param0"));
|
||||
auto param1 = builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(1, in_shape, "param1"));
|
||||
auto result = builder.AddInstruction(
|
||||
HloInstruction::CreateBinary(out_shape, HloOpcode::kEq, param0, param1));
|
||||
auto result = builder.AddInstruction(HloInstruction::CreateCompare(
|
||||
out_shape, param0, param1, ComparisonDirection::kEq));
|
||||
|
||||
BuildModuleAndRunAnalysis(builder.Build());
|
||||
|
||||
@ -2563,8 +2563,8 @@ TEST_F(CanShareOperandBufferWithUserTest, WhileCanShare) {
|
||||
auto builder = HloComputation::Builder(TestName() + ".Cond");
|
||||
auto data = builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(0, data_shape, "data"));
|
||||
builder.AddInstruction(HloInstruction::CreateBinary(
|
||||
ShapeUtil::MakeShape(PRED, {}), HloOpcode::kEq, data, data));
|
||||
builder.AddInstruction(HloInstruction::CreateCompare(
|
||||
ShapeUtil::MakeShape(PRED, {}), data, data, ComparisonDirection::kEq));
|
||||
return builder.Build();
|
||||
};
|
||||
|
||||
|
@ -223,8 +223,9 @@ TEST_F(HloDceTest, CalledComputationWithSideEffect) {
|
||||
HloInstruction::CreateParameter(0, shape, "cond_param"));
|
||||
auto constant = cond_builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
|
||||
cond_builder.AddInstruction(HloInstruction::CreateBinary(
|
||||
ShapeUtil::MakeShape(PRED, {}), HloOpcode::kLt, param, constant));
|
||||
cond_builder.AddInstruction(
|
||||
HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), param,
|
||||
constant, ComparisonDirection::kLt));
|
||||
}
|
||||
auto cond_computation = module->AddEmbeddedComputation(cond_builder.Build());
|
||||
|
||||
|
@ -56,43 +56,40 @@ namespace xla {
|
||||
namespace {
|
||||
|
||||
template <typename OperandT>
|
||||
StatusOr<Literal> Compare(const Shape& shape, HloOpcode opcode,
|
||||
StatusOr<Literal> Compare(const Shape& shape, ComparisonDirection direction,
|
||||
LiteralSlice lhs_literal, LiteralSlice rhs_literal) {
|
||||
std::function<bool(OperandT, OperandT)> compare_op;
|
||||
switch (opcode) {
|
||||
case HloOpcode::kEq:
|
||||
switch (direction) {
|
||||
case ComparisonDirection::kEq:
|
||||
compare_op = [](OperandT lhs_el, OperandT rhs_el) {
|
||||
return lhs_el == rhs_el;
|
||||
};
|
||||
break;
|
||||
case HloOpcode::kNe:
|
||||
case ComparisonDirection::kNe:
|
||||
compare_op = [](OperandT lhs_el, OperandT rhs_el) {
|
||||
return lhs_el != rhs_el;
|
||||
};
|
||||
break;
|
||||
case HloOpcode::kGe:
|
||||
case ComparisonDirection::kGe:
|
||||
compare_op = [](OperandT lhs_el, OperandT rhs_el) {
|
||||
return lhs_el >= rhs_el;
|
||||
};
|
||||
break;
|
||||
case HloOpcode::kGt:
|
||||
case ComparisonDirection::kGt:
|
||||
compare_op = [](OperandT lhs_el, OperandT rhs_el) {
|
||||
return lhs_el > rhs_el;
|
||||
};
|
||||
break;
|
||||
case HloOpcode::kLe:
|
||||
case ComparisonDirection::kLe:
|
||||
compare_op = [](OperandT lhs_el, OperandT rhs_el) {
|
||||
return lhs_el <= rhs_el;
|
||||
};
|
||||
break;
|
||||
case HloOpcode::kLt:
|
||||
case ComparisonDirection::kLt:
|
||||
compare_op = [](OperandT lhs_el, OperandT rhs_el) {
|
||||
return lhs_el < rhs_el;
|
||||
};
|
||||
break;
|
||||
default:
|
||||
LOG(FATAL) << "unhandled HLO opcode for conversion to Comparison: "
|
||||
<< HloOpcodeString(opcode);
|
||||
}
|
||||
|
||||
Literal result(shape);
|
||||
@ -106,24 +103,25 @@ StatusOr<Literal> Compare(const Shape& shape, HloOpcode opcode,
|
||||
}
|
||||
|
||||
template <>
|
||||
StatusOr<Literal> Compare<complex64>(const Shape& shape, HloOpcode opcode,
|
||||
StatusOr<Literal> Compare<complex64>(const Shape& shape,
|
||||
ComparisonDirection direction,
|
||||
LiteralSlice lhs_literal,
|
||||
LiteralSlice rhs_literal) {
|
||||
std::function<bool(complex64, complex64)> compare_op;
|
||||
switch (opcode) {
|
||||
case HloOpcode::kEq:
|
||||
switch (direction) {
|
||||
case ComparisonDirection::kEq:
|
||||
compare_op = [](complex64 lhs_el, complex64 rhs_el) {
|
||||
return lhs_el == rhs_el;
|
||||
};
|
||||
break;
|
||||
case HloOpcode::kNe:
|
||||
case ComparisonDirection::kNe:
|
||||
compare_op = [](complex64 lhs_el, complex64 rhs_el) {
|
||||
return lhs_el != rhs_el;
|
||||
};
|
||||
break;
|
||||
default:
|
||||
LOG(FATAL) << "unhandled HLO opcode for conversion to Comparison: "
|
||||
<< HloOpcodeString(opcode);
|
||||
LOG(FATAL) << "unhandled direction for conversion to Comparison: "
|
||||
<< ComparisonDirectionToString(direction);
|
||||
}
|
||||
|
||||
Literal result(shape);
|
||||
@ -137,24 +135,25 @@ StatusOr<Literal> Compare<complex64>(const Shape& shape, HloOpcode opcode,
|
||||
}
|
||||
|
||||
template <>
|
||||
StatusOr<Literal> Compare<complex128>(const Shape& shape, HloOpcode opcode,
|
||||
StatusOr<Literal> Compare<complex128>(const Shape& shape,
|
||||
ComparisonDirection direction,
|
||||
LiteralSlice lhs_literal,
|
||||
LiteralSlice rhs_literal) {
|
||||
std::function<bool(complex128, complex128)> compare_op;
|
||||
switch (opcode) {
|
||||
case HloOpcode::kEq:
|
||||
switch (direction) {
|
||||
case ComparisonDirection::kEq:
|
||||
compare_op = [](complex128 lhs_el, complex128 rhs_el) {
|
||||
return lhs_el == rhs_el;
|
||||
};
|
||||
break;
|
||||
case HloOpcode::kNe:
|
||||
case ComparisonDirection::kNe:
|
||||
compare_op = [](complex128 lhs_el, complex128 rhs_el) {
|
||||
return lhs_el != rhs_el;
|
||||
};
|
||||
break;
|
||||
default:
|
||||
LOG(FATAL) << "unhandled HLO opcode for conversion to Comparison: "
|
||||
<< HloOpcodeString(opcode);
|
||||
LOG(FATAL) << "unhandled direction for conversion to Comparison: "
|
||||
<< ComparisonDirectionToString(direction);
|
||||
}
|
||||
|
||||
Literal result(shape);
|
||||
@ -671,7 +670,7 @@ Status HloEvaluator::HandleComplex(HloInstruction* complex) {
|
||||
}
|
||||
|
||||
Status HloEvaluator::HandleCompare(HloInstruction* compare) {
|
||||
HloOpcode opcode = compare->opcode();
|
||||
ComparisonDirection direction = compare->comparison_direction();
|
||||
auto lhs = compare->operand(0);
|
||||
auto rhs = compare->operand(1);
|
||||
DCHECK(ShapeUtil::SameDimensions(compare->shape(), rhs->shape()) &&
|
||||
@ -687,76 +686,76 @@ Status HloEvaluator::HandleCompare(HloInstruction* compare) {
|
||||
case PRED: {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
evaluated_[compare],
|
||||
Compare<bool>(compare->shape(), opcode, lhs_literal, rhs_literal));
|
||||
Compare<bool>(compare->shape(), direction, lhs_literal, rhs_literal));
|
||||
} break;
|
||||
case U8: {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
evaluated_[compare],
|
||||
Compare<uint8>(compare->shape(), opcode, lhs_literal, rhs_literal));
|
||||
TF_ASSIGN_OR_RETURN(evaluated_[compare],
|
||||
Compare<uint8>(compare->shape(), direction,
|
||||
lhs_literal, rhs_literal));
|
||||
} break;
|
||||
case U16: {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
evaluated_[compare],
|
||||
Compare<uint16>(compare->shape(), opcode, lhs_literal, rhs_literal));
|
||||
TF_ASSIGN_OR_RETURN(evaluated_[compare],
|
||||
Compare<uint16>(compare->shape(), direction,
|
||||
lhs_literal, rhs_literal));
|
||||
} break;
|
||||
case U32: {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
evaluated_[compare],
|
||||
Compare<uint32>(compare->shape(), opcode, lhs_literal, rhs_literal));
|
||||
TF_ASSIGN_OR_RETURN(evaluated_[compare],
|
||||
Compare<uint32>(compare->shape(), direction,
|
||||
lhs_literal, rhs_literal));
|
||||
} break;
|
||||
case U64: {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
evaluated_[compare],
|
||||
Compare<uint64>(compare->shape(), opcode, lhs_literal, rhs_literal));
|
||||
TF_ASSIGN_OR_RETURN(evaluated_[compare],
|
||||
Compare<uint64>(compare->shape(), direction,
|
||||
lhs_literal, rhs_literal));
|
||||
} break;
|
||||
case S8: {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
evaluated_[compare],
|
||||
Compare<int8>(compare->shape(), opcode, lhs_literal, rhs_literal));
|
||||
Compare<int8>(compare->shape(), direction, lhs_literal, rhs_literal));
|
||||
} break;
|
||||
case S16: {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
evaluated_[compare],
|
||||
Compare<int16>(compare->shape(), opcode, lhs_literal, rhs_literal));
|
||||
TF_ASSIGN_OR_RETURN(evaluated_[compare],
|
||||
Compare<int16>(compare->shape(), direction,
|
||||
lhs_literal, rhs_literal));
|
||||
} break;
|
||||
case S32: {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
evaluated_[compare],
|
||||
Compare<int32>(compare->shape(), opcode, lhs_literal, rhs_literal));
|
||||
TF_ASSIGN_OR_RETURN(evaluated_[compare],
|
||||
Compare<int32>(compare->shape(), direction,
|
||||
lhs_literal, rhs_literal));
|
||||
} break;
|
||||
case S64: {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
evaluated_[compare],
|
||||
Compare<int64>(compare->shape(), opcode, lhs_literal, rhs_literal));
|
||||
TF_ASSIGN_OR_RETURN(evaluated_[compare],
|
||||
Compare<int64>(compare->shape(), direction,
|
||||
lhs_literal, rhs_literal));
|
||||
} break;
|
||||
case F16: {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
evaluated_[compare],
|
||||
Compare<half>(compare->shape(), opcode, lhs_literal, rhs_literal));
|
||||
Compare<half>(compare->shape(), direction, lhs_literal, rhs_literal));
|
||||
} break;
|
||||
case BF16: {
|
||||
TF_ASSIGN_OR_RETURN(evaluated_[compare],
|
||||
Compare<bfloat16>(compare->shape(), opcode,
|
||||
Compare<bfloat16>(compare->shape(), direction,
|
||||
lhs_literal, rhs_literal));
|
||||
} break;
|
||||
case F32: {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
evaluated_[compare],
|
||||
Compare<float>(compare->shape(), opcode, lhs_literal, rhs_literal));
|
||||
TF_ASSIGN_OR_RETURN(evaluated_[compare],
|
||||
Compare<float>(compare->shape(), direction,
|
||||
lhs_literal, rhs_literal));
|
||||
} break;
|
||||
case F64: {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
evaluated_[compare],
|
||||
Compare<double>(compare->shape(), opcode, lhs_literal, rhs_literal));
|
||||
TF_ASSIGN_OR_RETURN(evaluated_[compare],
|
||||
Compare<double>(compare->shape(), direction,
|
||||
lhs_literal, rhs_literal));
|
||||
} break;
|
||||
case C64: {
|
||||
TF_ASSIGN_OR_RETURN(evaluated_[compare],
|
||||
Compare<complex64>(compare->shape(), opcode,
|
||||
Compare<complex64>(compare->shape(), direction,
|
||||
lhs_literal, rhs_literal));
|
||||
} break;
|
||||
case C128: {
|
||||
TF_ASSIGN_OR_RETURN(evaluated_[compare],
|
||||
Compare<complex128>(compare->shape(), opcode,
|
||||
Compare<complex128>(compare->shape(), direction,
|
||||
lhs_literal, rhs_literal));
|
||||
} break;
|
||||
default:
|
||||
|
@ -2848,8 +2848,15 @@ TEST_F(HloEvaluatorTest, DoesCompareBF16) {
|
||||
{bfloat16(0.25), bfloat16(-0.375), bfloat16(-0.127)}});
|
||||
auto expected =
|
||||
LiteralUtil::CreateR2<bool>({{false, true, true}, {false, true, true}});
|
||||
TestBinaryOp(HloOpcode::kGe, std::move(expected), std::move(lhs),
|
||||
std::move(rhs));
|
||||
|
||||
HloComputation::Builder b(TestName());
|
||||
auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs)));
|
||||
auto c2 = b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs)));
|
||||
b.AddInstruction(HloInstruction::CreateCompare(expected.shape(), c1, c2,
|
||||
ComparisonDirection::kGe));
|
||||
m_->AddEntryComputation(b.Build());
|
||||
|
||||
EXPECT_TRUE(LiteralTestUtil::Equal(expected, Evaluate()));
|
||||
}
|
||||
|
||||
TEST_P(HloEvaluatorBf16Test, Bf16Reduction) {
|
||||
|
@ -258,14 +258,16 @@ optional<string> MatchTrivialComputation(const HloComputation* computation) {
|
||||
// param0), check that the operation being performed is commutative.
|
||||
if (root->operand(0) == param1) {
|
||||
CHECK_EQ(root->operand(1), param0);
|
||||
switch (root->opcode()) {
|
||||
case HloOpcode::kLe:
|
||||
case HloOpcode::kGe:
|
||||
case HloOpcode::kGt:
|
||||
case HloOpcode::kLt:
|
||||
return nullopt;
|
||||
default:
|
||||
break;
|
||||
if (root->opcode() == HloOpcode()) {
|
||||
switch (root->comparison_direction()) {
|
||||
case ComparisonDirection::kLe:
|
||||
case ComparisonDirection::kGe:
|
||||
case ComparisonDirection::kGt:
|
||||
case ComparisonDirection::kLt:
|
||||
return nullopt;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -279,18 +281,22 @@ optional<string> MatchTrivialComputation(const HloComputation* computation) {
|
||||
return "min";
|
||||
case HloOpcode::kMaximum:
|
||||
return "max";
|
||||
case HloOpcode::kLe:
|
||||
return "less-or-equal";
|
||||
case HloOpcode::kGe:
|
||||
return "greater-or-equal";
|
||||
case HloOpcode::kGt:
|
||||
return "greater-than";
|
||||
case HloOpcode::kLt:
|
||||
return "less-than";
|
||||
case HloOpcode::kEq:
|
||||
return "equal-to";
|
||||
case HloOpcode::kNe:
|
||||
return "not-equal-to";
|
||||
case HloOpcode::kCompare: {
|
||||
switch (root->comparison_direction()) {
|
||||
case ComparisonDirection::kLe:
|
||||
return "less-or-equal";
|
||||
case ComparisonDirection::kGe:
|
||||
return "greater-or-equal";
|
||||
case ComparisonDirection::kGt:
|
||||
return "greater-than";
|
||||
case ComparisonDirection::kLt:
|
||||
return "less-than";
|
||||
case ComparisonDirection::kEq:
|
||||
return "equal-to";
|
||||
case ComparisonDirection::kNe:
|
||||
return "not-equal-to";
|
||||
}
|
||||
}
|
||||
default:
|
||||
return nullopt;
|
||||
}
|
||||
@ -922,27 +928,22 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
|
||||
case HloOpcode::kCeil:
|
||||
case HloOpcode::kClamp:
|
||||
case HloOpcode::kClz:
|
||||
case HloOpcode::kCompare:
|
||||
case HloOpcode::kComplex:
|
||||
case HloOpcode::kConvert:
|
||||
case HloOpcode::kCos:
|
||||
case HloOpcode::kDivide:
|
||||
case HloOpcode::kEq:
|
||||
case HloOpcode::kExp:
|
||||
case HloOpcode::kExpm1:
|
||||
case HloOpcode::kFloor:
|
||||
case HloOpcode::kGe:
|
||||
case HloOpcode::kGt:
|
||||
case HloOpcode::kImag:
|
||||
case HloOpcode::kIota:
|
||||
case HloOpcode::kIsFinite:
|
||||
case HloOpcode::kLe:
|
||||
case HloOpcode::kLog:
|
||||
case HloOpcode::kLog1p:
|
||||
case HloOpcode::kLt:
|
||||
case HloOpcode::kMaximum:
|
||||
case HloOpcode::kMinimum:
|
||||
case HloOpcode::kMultiply:
|
||||
case HloOpcode::kNe:
|
||||
case HloOpcode::kNegate:
|
||||
case HloOpcode::kNot:
|
||||
case HloOpcode::kOr:
|
||||
|
@ -22,6 +22,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
||||
#include "tensorflow/compiler/xla/test.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
||||
#include "tensorflow/compiler/xla/tests/test_utils.h"
|
||||
#include "tensorflow/compiler/xla/xla.pb.h"
|
||||
|
||||
@ -31,6 +32,8 @@ namespace {
|
||||
using absl::StrCat;
|
||||
using ::testing::HasSubstr;
|
||||
|
||||
using HloGraphDumperTest = HloTestBase;
|
||||
|
||||
// Returns the name of the test that gtest is currently running.
string TestName() {
  const auto* running_test =
      ::testing::UnitTest::GetInstance()->current_test_info();
  return running_test->name();
}
|
||||
@ -48,7 +51,7 @@ class DotRenderer : public hlo_graph_dumper::GraphRendererInterface {
|
||||
|
||||
XLA_REGISTER_GRAPH_RENDERER(DotRenderer);
|
||||
|
||||
TEST(HloGraphDumperTest, NestedFusion) {
|
||||
TEST_F(HloGraphDumperTest, NestedFusion) {
|
||||
HloComputation::Builder b("b");
|
||||
|
||||
// Build param0 + param1 + param2 + param3 + param4.
|
||||
@ -118,7 +121,7 @@ TEST(HloGraphDumperTest, NestedFusion) {
|
||||
HasSubstr(inner_sum->name()));
|
||||
}
|
||||
|
||||
TEST(HloGraphDumperTest, Constant) {
|
||||
TEST_F(HloGraphDumperTest, Constant) {
|
||||
HloComputation::Builder b("b");
|
||||
auto instruction = b.AddInstruction(
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(-42)));
|
||||
@ -132,7 +135,7 @@ TEST(HloGraphDumperTest, Constant) {
|
||||
EXPECT_THAT(graph, Not(HasSubstr("i_am_a_constant_root_instruction")));
|
||||
}
|
||||
|
||||
TEST(HloGraphDumperTest, TupleConstant) {
|
||||
TEST_F(HloGraphDumperTest, TupleConstant) {
|
||||
Shape tuple_shape = ShapeUtil::MakeTupleShape(
|
||||
{ShapeUtil::MakeShape(F32, {3, 2}), ShapeUtil::MakeShape(S32, {4, 5})});
|
||||
HloComputation::Builder b("b");
|
||||
@ -150,5 +153,21 @@ TEST(HloGraphDumperTest, TupleConstant) {
|
||||
EXPECT_THAT(graph, HasSubstr("constant (f32[3,2], s32[4,5])"));
|
||||
}
|
||||
|
||||
// Checks that the graph dumped for a module whose root is a `compare`
// instruction mentions the instruction's direction attribute.
TEST_F(HloGraphDumperTest, Compare) {
|
||||
// HLO text: two f32[10] parameters compared with direction=LT.
const char* hlo_string = R"(
|
||||
HloModule comp
|
||||
|
||||
ENTRY comp {
|
||||
param.0 = f32[10] parameter(0)
|
||||
param.1 = f32[10] parameter(1)
|
||||
ROOT lt = pred[10] compare(param.0, param.1), direction=LT
|
||||
})";
|
||||
// Parse the text above into a module; the test aborts if that fails.
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||
ParseAndReturnVerifiedModule(hlo_string));
|
||||
// Dump the entry computation under the label "comp" with default options.
string graph = hlo_graph_dumper::DumpGraph(*module->entry_computation(),
|
||||
/*label=*/"comp", DebugOptions());
|
||||
// The direction must be visible somewhere in the dumped text.
EXPECT_THAT(graph, HasSubstr("direction=LT"));
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
} // namespace xla
|
||||
|
@ -64,7 +64,35 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
|
||||
const absl::flat_hash_map<int64, HloInstruction*>& instruction_map,
|
||||
const absl::flat_hash_map<int64, HloComputation*>& computation_map) {
|
||||
TF_RET_CHECK(!proto.opcode().empty());
|
||||
TF_ASSIGN_OR_RETURN(HloOpcode opcode, StringToHloOpcode(proto.opcode()));
|
||||
HloOpcode opcode;
|
||||
auto opcode_or = StringToHloOpcode(proto.opcode());
|
||||
absl::optional<ComparisonDirection> comparison_direction;
|
||||
if (opcode_or.ok()) {
|
||||
opcode = opcode_or.ConsumeValueOrDie();
|
||||
} else {
|
||||
// Unknown opcode. Try auto-upgrading deprecated "less-than",
|
||||
// "greater-than", etc opcodes, which are now rolled into the kCompare
|
||||
// opcode.
|
||||
if (proto.opcode() == "equal-to") {
|
||||
comparison_direction = ComparisonDirection::kEq;
|
||||
} else if (proto.opcode() == "not-equal-to") {
|
||||
comparison_direction = ComparisonDirection::kNe;
|
||||
} else if (proto.opcode() == "greater-than-or-equal-to") {
|
||||
comparison_direction = ComparisonDirection::kGe;
|
||||
} else if (proto.opcode() == "greater-than") {
|
||||
comparison_direction = ComparisonDirection::kGt;
|
||||
} else if (proto.opcode() == "less-than-or-equal-to") {
|
||||
comparison_direction = ComparisonDirection::kLe;
|
||||
} else if (proto.opcode() == "less-than") {
|
||||
comparison_direction = ComparisonDirection::kLt;
|
||||
}
|
||||
if (comparison_direction) {
|
||||
opcode = HloOpcode::kCompare;
|
||||
} else {
|
||||
return InvalidArgument("Unknown opcode: %s", proto.opcode());
|
||||
}
|
||||
}
|
||||
|
||||
TF_RET_CHECK(proto.has_shape());
|
||||
|
||||
std::unique_ptr<HloInstruction> instruction;
|
||||
@ -136,6 +164,17 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
|
||||
absl::Span<const int64>(fft_length));
|
||||
break;
|
||||
}
|
||||
case HloOpcode::kCompare: {
|
||||
// Auto-upgraded from deprecated opcode skips the following.
|
||||
if (!comparison_direction) {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
comparison_direction,
|
||||
StringToComparisonDirection(proto.comparison_direction()));
|
||||
}
|
||||
instruction =
|
||||
CreateCompare(shape, operands(0), operands(1), *comparison_direction);
|
||||
break;
|
||||
}
|
||||
case HloOpcode::kTriangularSolve: {
|
||||
instruction = CreateTriangularSolve(shape, operands(0), operands(1),
|
||||
proto.triangular_solve_options());
|
||||
@ -688,15 +727,9 @@ HloInstruction::CreateGetTupleElement(const Shape& shape,
|
||||
case HloOpcode::kAtan2:
|
||||
case HloOpcode::kDivide:
|
||||
case HloOpcode::kComplex:
|
||||
case HloOpcode::kEq:
|
||||
case HloOpcode::kGe:
|
||||
case HloOpcode::kGt:
|
||||
case HloOpcode::kLe:
|
||||
case HloOpcode::kLt:
|
||||
case HloOpcode::kMaximum:
|
||||
case HloOpcode::kMinimum:
|
||||
case HloOpcode::kMultiply:
|
||||
case HloOpcode::kNe:
|
||||
case HloOpcode::kPower:
|
||||
case HloOpcode::kRemainder:
|
||||
case HloOpcode::kSubtract:
|
||||
@ -761,6 +794,12 @@ HloInstruction::CreateGetTupleElement(const Shape& shape,
|
||||
fft_length);
|
||||
}
|
||||
|
||||
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCompare(
|
||||
const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
|
||||
ComparisonDirection direction) {
|
||||
return absl::make_unique<HloCompareInstruction>(shape, lhs, rhs, direction);
|
||||
}
|
||||
|
||||
/* static */ std::unique_ptr<HloInstruction>
|
||||
HloInstruction::CreateTriangularSolve(const Shape& shape, HloInstruction* a,
|
||||
HloInstruction* b,
|
||||
@ -1311,6 +1350,7 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
|
||||
case HloOpcode::kBatchNormInference:
|
||||
case HloOpcode::kBatchNormGrad:
|
||||
case HloOpcode::kFft:
|
||||
case HloOpcode::kCompare:
|
||||
case HloOpcode::kSend:
|
||||
case HloOpcode::kSendDone:
|
||||
case HloOpcode::kRecv:
|
||||
@ -1384,12 +1424,6 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
|
||||
case HloOpcode::kDivide:
|
||||
case HloOpcode::kMultiply:
|
||||
case HloOpcode::kSubtract:
|
||||
case HloOpcode::kEq:
|
||||
case HloOpcode::kGe:
|
||||
case HloOpcode::kGt:
|
||||
case HloOpcode::kLe:
|
||||
case HloOpcode::kLt:
|
||||
case HloOpcode::kNe:
|
||||
case HloOpcode::kMaximum:
|
||||
case HloOpcode::kMinimum:
|
||||
case HloOpcode::kPower:
|
||||
@ -1705,26 +1739,20 @@ bool HloInstruction::IdenticalSlowPath(
|
||||
case HloOpcode::kCos:
|
||||
case HloOpcode::kDivide:
|
||||
case HloOpcode::kDynamicUpdateSlice:
|
||||
case HloOpcode::kEq:
|
||||
case HloOpcode::kExp:
|
||||
case HloOpcode::kExpm1:
|
||||
case HloOpcode::kFloor:
|
||||
case HloOpcode::kGe:
|
||||
case HloOpcode::kGt:
|
||||
case HloOpcode::kImag:
|
||||
case HloOpcode::kIsFinite:
|
||||
case HloOpcode::kLe:
|
||||
case HloOpcode::kLog:
|
||||
case HloOpcode::kLog1p:
|
||||
case HloOpcode::kAnd:
|
||||
case HloOpcode::kNot:
|
||||
case HloOpcode::kOr:
|
||||
case HloOpcode::kXor:
|
||||
case HloOpcode::kLt:
|
||||
case HloOpcode::kMaximum:
|
||||
case HloOpcode::kMinimum:
|
||||
case HloOpcode::kMultiply:
|
||||
case HloOpcode::kNe:
|
||||
case HloOpcode::kNegate:
|
||||
case HloOpcode::kPower:
|
||||
case HloOpcode::kReal:
|
||||
@ -1772,6 +1800,7 @@ bool HloInstruction::IdenticalSlowPath(
|
||||
case HloOpcode::kBatchNormInference:
|
||||
case HloOpcode::kBatchNormGrad:
|
||||
case HloOpcode::kFft:
|
||||
case HloOpcode::kCompare:
|
||||
case HloOpcode::kSend:
|
||||
case HloOpcode::kSendDone:
|
||||
case HloOpcode::kRecv:
|
||||
@ -2119,17 +2148,12 @@ bool HloInstruction::IsElementwiseImpl(
|
||||
// Binary elementwise operations, the same as in IsElementwiseBinary().
|
||||
case HloOpcode::kAdd:
|
||||
case HloOpcode::kAtan2:
|
||||
case HloOpcode::kCompare:
|
||||
case HloOpcode::kComplex:
|
||||
case HloOpcode::kDivide:
|
||||
case HloOpcode::kEq:
|
||||
case HloOpcode::kGe:
|
||||
case HloOpcode::kGt:
|
||||
case HloOpcode::kLe:
|
||||
case HloOpcode::kLt:
|
||||
case HloOpcode::kMaximum:
|
||||
case HloOpcode::kMinimum:
|
||||
case HloOpcode::kMultiply:
|
||||
case HloOpcode::kNe:
|
||||
case HloOpcode::kPower:
|
||||
case HloOpcode::kRemainder:
|
||||
case HloOpcode::kSubtract:
|
||||
@ -2472,12 +2496,7 @@ Status HloInstruction::Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor) {
|
||||
return visitor->HandleGetTupleElement(this);
|
||||
case HloOpcode::kParameter:
|
||||
return visitor->HandleParameter(this);
|
||||
case HloOpcode::kEq:
|
||||
case HloOpcode::kGe:
|
||||
case HloOpcode::kGt:
|
||||
case HloOpcode::kLe:
|
||||
case HloOpcode::kLt:
|
||||
case HloOpcode::kNe:
|
||||
case HloOpcode::kCompare:
|
||||
return visitor->HandleCompare(this);
|
||||
case HloOpcode::kComplex:
|
||||
return visitor->HandleComplex(this);
|
||||
@ -3519,6 +3538,10 @@ const DomainMetadata& HloInstruction::user_side_metadata() const {
|
||||
return Cast<HloDomainInstruction>(this)->user_side_metadata();
|
||||
}
|
||||
|
||||
// Forwards to HloCompareInstruction::direction() after casting `this` to the
// compare subclass.
ComparisonDirection HloInstruction::comparison_direction() const {
  const auto* compare = Cast<HloCompareInstruction>(this);
  return compare->direction();
}
|
||||
|
||||
const TriangularSolveOptions& HloInstruction::triangular_solve_options() const {
|
||||
return Cast<HloTriangularSolveInstruction>(this)->triangular_solve_options();
|
||||
}
|
||||
|
@ -37,6 +37,7 @@ limitations under the License.
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/compiler/xla/comparison_util.h"
|
||||
#include "tensorflow/compiler/xla/iterator_util.h"
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/compiler/xla/map_util.h"
|
||||
@ -444,6 +445,11 @@ class HloInstruction {
|
||||
const Shape& shape, HloInstruction* operand, FftType fft_type,
|
||||
absl::Span<const int64> fft_length);
|
||||
|
||||
// Creates a compare op, performing the comparison specified in direction.
|
||||
static std::unique_ptr<HloInstruction> CreateCompare(
|
||||
const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
|
||||
ComparisonDirection direction);
|
||||
|
||||
static std::unique_ptr<HloInstruction> CreateTriangularSolve(
|
||||
const Shape& shape, HloInstruction* a, HloInstruction* b,
|
||||
const TriangularSolveOptions& options);
|
||||
@ -1600,6 +1606,9 @@ class HloInstruction {
|
||||
// Delegates to HloDomainInstruction::user_side_metadata().
|
||||
const DomainMetadata& user_side_metadata() const;
|
||||
|
||||
// Delegates to HloCompareInstruction::direction().
|
||||
ComparisonDirection comparison_direction() const;
|
||||
|
||||
// Delegates to HloTriangularSolveInstruction::triangular_solve_options().
|
||||
const TriangularSolveOptions& triangular_solve_options() const;
|
||||
|
||||
|
@ -1655,7 +1655,7 @@ body (bparam: s32[]) -> s32[] {
|
||||
condition (cparam: s32[]) -> pred[] {
|
||||
xconstant = s32[] constant(5)
|
||||
cparam = s32[] parameter(0)
|
||||
ROOT greater-than = pred[] greater-than(xconstant, cparam)
|
||||
ROOT greater-than = pred[] compare(xconstant, cparam), direction=GT
|
||||
}
|
||||
|
||||
ENTRY entry (param: s32[]) -> s32[] {
|
||||
|
@ -202,6 +202,42 @@ std::unique_ptr<HloInstruction> HloFftInstruction::CloneWithNewOperandsImpl(
|
||||
fft_length_);
|
||||
}
|
||||
|
||||
// Constructs a kCompare instruction of shape `shape`, remembering `direction`
// and registering `lhs` then `rhs` as its operands (in that order).
HloCompareInstruction::HloCompareInstruction(const Shape& shape,
                                             HloInstruction* lhs,
                                             HloInstruction* rhs,
                                             ComparisonDirection direction)
    : HloInstruction(HloOpcode::kCompare, shape), direction_(direction) {
  HloInstruction* const ordered_operands[] = {lhs, rhs};
  for (HloInstruction* operand : ordered_operands) {
    AppendOperand(operand);
  }
}
|
||||
|
||||
// Serializes this instruction: starts from the base-class proto and adds the
// comparison direction in its string form.
HloInstructionProto HloCompareInstruction::ToProto() const {
  HloInstructionProto result = HloInstruction::ToProto();
  const auto direction_name = ComparisonDirectionToString(direction_);
  result.set_comparison_direction(direction_name);
  return result;
}
|
||||
|
||||
std::vector<string> HloCompareInstruction::ExtraAttributesToStringImpl(
|
||||
const HloPrintOptions& options) const {
|
||||
return {StrCat("direction=", ComparisonDirectionToString(direction()))};
|
||||
}
|
||||
|
||||
bool HloCompareInstruction::IdenticalSlowPath(
|
||||
const HloInstruction& other,
|
||||
const std::function<bool(const HloComputation*, const HloComputation*)>&
|
||||
eq_computations) const {
|
||||
const auto& casted_other = static_cast<const HloCompareInstruction&>(other);
|
||||
return direction() == casted_other.direction();
|
||||
}
|
||||
|
||||
std::unique_ptr<HloInstruction> HloCompareInstruction::CloneWithNewOperandsImpl(
|
||||
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
|
||||
HloCloneContext* context) const {
|
||||
CHECK_EQ(new_operands.size(), 2);
|
||||
return absl::make_unique<HloCompareInstruction>(shape, new_operands[0],
|
||||
new_operands[1], direction());
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
// Converts a protocol buffer message (e.g., TriangularSolveOptions) to a vector
|
||||
|
@ -131,6 +131,28 @@ class HloFftInstruction : public HloInstruction {
|
||||
std::vector<int64> fft_length_;
|
||||
};
|
||||
|
||||
class HloCompareInstruction : public HloInstruction {
|
||||
public:
|
||||
explicit HloCompareInstruction(const Shape& shape, HloInstruction* lhs,
|
||||
HloInstruction* rhs,
|
||||
ComparisonDirection direction);
|
||||
ComparisonDirection direction() const { return direction_; }
|
||||
HloInstructionProto ToProto() const override;
|
||||
|
||||
private:
|
||||
std::vector<string> ExtraAttributesToStringImpl(
|
||||
const HloPrintOptions& options) const override;
|
||||
bool IdenticalSlowPath(
|
||||
const HloInstruction& other,
|
||||
const std::function<bool(const HloComputation*, const HloComputation*)>&
|
||||
eq_computations) const override;
|
||||
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
|
||||
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
|
||||
HloCloneContext* context) const override;
|
||||
|
||||
ComparisonDirection direction_;
|
||||
};
|
||||
|
||||
class HloTriangularSolveInstruction : public HloInstruction {
|
||||
public:
|
||||
explicit HloTriangularSolveInstruction(const Shape& shape, HloInstruction* a,
|
||||
|
@ -255,7 +255,7 @@ TEST_F(HloLivenessAnalysisTest, WhileWithDeadTupleElement) {
|
||||
loop_var.2 = (s32[], s32[3]{0}) parameter(0)
|
||||
get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0
|
||||
constant.2 = s32[] constant(5)
|
||||
ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2)
|
||||
ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT
|
||||
}
|
||||
ENTRY SimpleLoop {
|
||||
constant.3 = s32[] constant(0)
|
||||
@ -308,7 +308,7 @@ TEST_F(HloLivenessAnalysisTest, WhileCondPropagatesLiveness) {
|
||||
get-tuple-element.4 = s32[] get-tuple-element(loop_var.2), index=1
|
||||
add.1 = s32[] add(get-tuple-element.3, get-tuple-element.4)
|
||||
constant.2 = s32[] constant(5)
|
||||
ROOT less-than = pred[] less-than(add.1, constant.2)
|
||||
ROOT less-than = pred[] compare(add.1, constant.2), direction=LT
|
||||
}
|
||||
ENTRY SimpleLoop {
|
||||
constant.3 = s32[] constant(0)
|
||||
@ -360,7 +360,7 @@ TEST_F(HloLivenessAnalysisTest, WhileWithLiveTupleElements) {
|
||||
loop_var.2 = (s32[], s32[], s32[]) parameter(0)
|
||||
get-tuple-element.4 = s32[] get-tuple-element(loop_var.2), index=0
|
||||
constant.1 = s32[] constant(5)
|
||||
ROOT less-than = pred[] less-than(get-tuple-element.4, constant.1)
|
||||
ROOT less-than = pred[] compare(get-tuple-element.4, constant.1), direction=LT
|
||||
}
|
||||
ENTRY SimpleLoop {
|
||||
constant.2 = s32[] constant(0)
|
||||
@ -415,7 +415,7 @@ TEST_F(HloLivenessAnalysisTest, WhileWithOutfeed) {
|
||||
cond_param = (s32[]) parameter(0)
|
||||
get-tuple-element.3 = s32[] get-tuple-element(cond_param), index=0
|
||||
constant.2 = s32[] constant(10)
|
||||
ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2)
|
||||
ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT
|
||||
}
|
||||
ENTRY SimpleLoop {
|
||||
constant.3 = s32[] constant(0)
|
||||
@ -448,13 +448,13 @@ TEST_F(HloLivenessAnalysisTest, NestedWhileWithOutfeed) {
|
||||
cond_param = (s32[]) parameter(0)
|
||||
get-tuple-element.3 = s32[] get-tuple-element(cond_param), index=0
|
||||
constant.2 = s32[] constant(10)
|
||||
ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2)
|
||||
ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT
|
||||
}
|
||||
OuterWhileCondition {
|
||||
cond_param.2 = (s32[]) parameter(0)
|
||||
get-tuple-element.5 = s32[] get-tuple-element(cond_param.2), index=0
|
||||
constant.5 = s32[] constant(5)
|
||||
ROOT less-than.2 = pred[] less-than(get-tuple-element.5, constant.5)
|
||||
ROOT less-than.2 = pred[] compare(get-tuple-element.5, constant.5), direction=LT
|
||||
}
|
||||
OuterWhileBody {
|
||||
body_param.2 = (s32[]) parameter(0)
|
||||
|
@ -89,6 +89,22 @@ bool HloParameterMatcher::MatchAndExplain(
|
||||
return true;
|
||||
}
|
||||
|
||||
bool HloComparisonMatcher::MatchAndExplain(
|
||||
const HloInstruction* instruction,
|
||||
::testing::MatchResultListener* listener) const {
|
||||
if (!HloMatcher::MatchAndExplain(instruction, listener)) {
|
||||
return false;
|
||||
}
|
||||
if (instruction->comparison_direction() != direction_) {
|
||||
*listener << "has wrong comparison direction (got "
|
||||
<< ComparisonDirectionToString(
|
||||
instruction->comparison_direction())
|
||||
<< ", want " << ComparisonDirectionToString(direction_) << ")";
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool HloGetTupleElementMatcher::MatchAndExplain(
|
||||
const HloInstruction* instruction,
|
||||
::testing::MatchResultListener* listener) const {
|
||||
|
@ -54,6 +54,21 @@ class HloParameterMatcher : public HloMatcher {
|
||||
int64 parameter_number_;
|
||||
};
|
||||
|
||||
// Custom matcher for comparisons, which accepts a comparison direction.
|
||||
class HloComparisonMatcher : public HloMatcher {
|
||||
public:
|
||||
explicit HloComparisonMatcher(
|
||||
ComparisonDirection direction,
|
||||
std::vector<::testing::Matcher<const HloInstruction*>> operands)
|
||||
: HloMatcher(HloOpcode::kCompare, operands), direction_(direction) {}
|
||||
|
||||
bool MatchAndExplain(const HloInstruction* instruction,
|
||||
::testing::MatchResultListener* listener) const override;
|
||||
|
||||
private:
|
||||
ComparisonDirection direction_;
|
||||
};
|
||||
|
||||
// Custom matcher for get-tuple-element instructions, which accepts a tuple
|
||||
// index to match.
|
||||
class HloGetTupleElementMatcher : public HloMatcher {
|
||||
@ -172,6 +187,7 @@ HLO_MATCHER(BatchNormGrad);
|
||||
HLO_MATCHER(Call);
|
||||
HLO_MATCHER(Ceil);
|
||||
HLO_MATCHER(Clamp);
|
||||
HLO_MATCHER(Compare);
|
||||
HLO_MATCHER(Concatenate);
|
||||
HLO_MATCHER(Conditional);
|
||||
HLO_MATCHER(Constant);
|
||||
@ -184,28 +200,22 @@ HLO_MATCHER(Divide);
|
||||
HLO_MATCHER(Domain);
|
||||
HLO_MATCHER(DynamicSlice);
|
||||
HLO_MATCHER(DynamicUpdateSlice);
|
||||
HLO_MATCHER(Eq);
|
||||
HLO_MATCHER(Exp);
|
||||
HLO_MATCHER(Floor);
|
||||
HLO_MATCHER(Fusion);
|
||||
HLO_MATCHER(Ge);
|
||||
HLO_MATCHER(AfterAll);
|
||||
HLO_MATCHER(Gt);
|
||||
HLO_MATCHER(Iota);
|
||||
HLO_MATCHER(Infeed);
|
||||
HLO_MATCHER(IsFinite);
|
||||
HLO_MATCHER(Le);
|
||||
HLO_MATCHER(Log);
|
||||
HLO_MATCHER(And);
|
||||
HLO_MATCHER(Not);
|
||||
HLO_MATCHER(Or);
|
||||
HLO_MATCHER(Xor);
|
||||
HLO_MATCHER(Lt);
|
||||
HLO_MATCHER(Map);
|
||||
HLO_MATCHER(Maximum);
|
||||
HLO_MATCHER(Minimum);
|
||||
HLO_MATCHER(Multiply);
|
||||
HLO_MATCHER(Ne);
|
||||
HLO_MATCHER(Negate);
|
||||
HLO_MATCHER(Outfeed);
|
||||
HLO_MATCHER(Pad);
|
||||
@ -256,6 +266,38 @@ inline ::testing::Matcher<const ::xla::HloInstruction*> Parameter() {
|
||||
new ::xla::testing::HloMatcher(HloOpcode::kParameter, {}));
|
||||
}
|
||||
|
||||
// Comparison matchers below do not require any additional arguments.
|
||||
template <typename... M>
|
||||
inline ::testing::Matcher<const ::xla::HloInstruction*> Eq(M... operands) {
|
||||
return ::testing::MakeMatcher(new ::xla::testing::HloComparisonMatcher(
|
||||
ComparisonDirection::kEq, {operands...}));
|
||||
}
|
||||
template <typename... M>
|
||||
inline ::testing::Matcher<const ::xla::HloInstruction*> Ne(M... operands) {
|
||||
return ::testing::MakeMatcher(new ::xla::testing::HloComparisonMatcher(
|
||||
ComparisonDirection::kNe, {operands...}));
|
||||
}
|
||||
template <typename... M>
|
||||
inline ::testing::Matcher<const ::xla::HloInstruction*> Ge(M... operands) {
|
||||
return ::testing::MakeMatcher(new ::xla::testing::HloComparisonMatcher(
|
||||
ComparisonDirection::kGe, {operands...}));
|
||||
}
|
||||
template <typename... M>
|
||||
inline ::testing::Matcher<const ::xla::HloInstruction*> Gt(M... operands) {
|
||||
return ::testing::MakeMatcher(new ::xla::testing::HloComparisonMatcher(
|
||||
ComparisonDirection::kGt, {operands...}));
|
||||
}
|
||||
template <typename... M>
|
||||
inline ::testing::Matcher<const ::xla::HloInstruction*> Le(M... operands) {
|
||||
return ::testing::MakeMatcher(new ::xla::testing::HloComparisonMatcher(
|
||||
ComparisonDirection::kLe, {operands...}));
|
||||
}
|
||||
template <typename... M>
|
||||
inline ::testing::Matcher<const ::xla::HloInstruction*> Lt(M... operands) {
|
||||
return ::testing::MakeMatcher(new ::xla::testing::HloComparisonMatcher(
|
||||
ComparisonDirection::kLt, {operands...}));
|
||||
}
|
||||
|
||||
// GetTupleElement(operand, N) matches a GTE instruction which gets the N'th
|
||||
// tuple element of operand, while GetTupleElement(operand) matches any GTE
|
||||
// operation on operand, and GetTupleElement() matches any GTE operation at all.
|
||||
|
@ -220,5 +220,33 @@ ENTRY DotOperationFusion_TransposeFusion {
|
||||
"rhs_contracting_dimensions (got {0} want {1})");
|
||||
}
|
||||
|
||||
// Exercises the comparison matchers: the generic op::Compare matcher, the
// direction-specific op::Eq / op::Ne / op::Le matchers, and the explanation
// text produced when the direction does not match.
TEST(HloMatchersTest, ComparisonMatcher) {
  const auto shape = ShapeUtil::MakeShape(F32, {1});

  // Two parameters serve as the operands of every instruction below.
  auto param0 = HloInstruction::CreateParameter(0, shape, "param.0");
  auto param1 = HloInstruction::CreateParameter(1, shape, "param.1");
  HloInstruction* const lhs = param0.get();
  HloInstruction* const rhs = param1.get();

  auto equal =
      HloInstruction::CreateCompare(shape, lhs, rhs, ComparisonDirection::kEq);
  auto not_equal =
      HloInstruction::CreateCompare(shape, lhs, rhs, ComparisonDirection::kNe);
  auto sum = HloInstruction::CreateBinary(shape, HloOpcode::kAdd, lhs, rhs);
  auto less_equal = HloInstruction::CreateCompare(shape, lhs, sum.get(),
                                                  ComparisonDirection::kLe);

  // Each compare matches both the generic matcher and its own direction.
  EXPECT_THAT(equal.get(), op::Compare());
  EXPECT_THAT(equal.get(), op::Eq());
  EXPECT_THAT(not_equal.get(), op::Compare());
  EXPECT_THAT(not_equal.get(), op::Ne());

  // Operand matchers are honoured by both forms.
  EXPECT_THAT(less_equal.get(),
              op::Compare(op::Parameter(0),
                          op::Add(op::Parameter(0), op::Parameter(1))));
  EXPECT_THAT(less_equal.get(),
              op::Le(op::Parameter(0),
                     op::Add(op::Parameter(0), op::Parameter(1))));

  // Explanations: empty against op::Add, a direction message against op::Ne.
  EXPECT_THAT(Explain(equal.get(), op::Add()), Eq(""));
  EXPECT_THAT(Explain(equal.get(), op::Ne()),
              Eq("has wrong comparison direction (got EQ, want NE)"));
}
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
||||
|
@ -254,8 +254,9 @@ TEST_F(HloSchedulingTest, HeapSimulatorAccountsForSubcomputations) {
|
||||
HloInstruction* zero_vector =
|
||||
cond_builder.AddInstruction(HloInstruction::CreateConstant(
|
||||
LiteralUtil::CreateR1<float>({0, 0, 0, 0})));
|
||||
cond_builder.AddInstruction(HloInstruction::CreateBinary(
|
||||
ShapeUtil::MakeShape(PRED, {}), HloOpcode::kNe, cond_param, zero_vector));
|
||||
cond_builder.AddInstruction(
|
||||
HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), cond_param,
|
||||
zero_vector, ComparisonDirection::kNe));
|
||||
auto cond_computation = module->AddEmbeddedComputation(cond_builder.Build());
|
||||
|
||||
// param - 1
|
||||
|
@ -86,7 +86,7 @@ TEST_F(HloModuleDceTest, WhileWithLiveOutputs) {
|
||||
loop_var.2 = (s32[], s32[3]{0}) parameter(0)
|
||||
get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0
|
||||
constant.2 = s32[] constant(5)
|
||||
ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2)
|
||||
ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT
|
||||
}
|
||||
ENTRY SimpleLoop {
|
||||
constant.3 = s32[] constant(0)
|
||||
@ -125,7 +125,7 @@ TEST_F(HloModuleDceTest, WhileWithUnusedSideEffectingTupleElement) {
|
||||
loop_var.2 = (s32[], f32[]) parameter(0)
|
||||
get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0
|
||||
constant.3 = s32[] constant(5)
|
||||
ROOT less-than = pred[] less-than(get-tuple-element.3, constant.3)
|
||||
ROOT less-than = pred[] compare(get-tuple-element.3, constant.3), direction=LT
|
||||
}
|
||||
ENTRY SimpleLoop {
|
||||
constant.4 = s32[] constant(0)
|
||||
@ -163,7 +163,7 @@ TEST_F(HloModuleDceTest, OneWhileWithDeadTupleElement) {
|
||||
loop_var.2 = (s32[], s32[3]{0}) parameter(0)
|
||||
get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0
|
||||
constant.2 = s32[] constant(5)
|
||||
ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2)
|
||||
ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT
|
||||
}
|
||||
ENTRY SimpleLoop {
|
||||
constant.3 = s32[] constant(0)
|
||||
@ -206,7 +206,7 @@ TEST_F(HloModuleDceTest, OneWhileWithTupleElementUsedByCond) {
|
||||
loop_var.2 = (s32[], s32[]) parameter(0)
|
||||
get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=1
|
||||
constant.2 = s32[] constant(5)
|
||||
ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2)
|
||||
ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT
|
||||
}
|
||||
ENTRY SimpleLoop {
|
||||
constant.3 = s32[] constant(0)
|
||||
@ -248,7 +248,7 @@ TEST_F(HloModuleDceTest, TwoWhilesWithDeadTupleElement) {
|
||||
loop_var.2 = (s32[], s32[3]{0}) parameter(0)
|
||||
get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0
|
||||
constant.2 = s32[] constant(5)
|
||||
ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2)
|
||||
ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT
|
||||
}
|
||||
SimpleLoop.body1 {
|
||||
loop_var.3 = (s32[], s32[3]{0}) parameter(0)
|
||||
@ -263,7 +263,7 @@ TEST_F(HloModuleDceTest, TwoWhilesWithDeadTupleElement) {
|
||||
loop_var.4 = (s32[], s32[3]{0}) parameter(0)
|
||||
get-tuple-element.6 = s32[] get-tuple-element(loop_var.4), index=0
|
||||
constant.4 = s32[] constant(5)
|
||||
ROOT less-than.1 = pred[] less-than(get-tuple-element.6, constant.4)
|
||||
ROOT less-than.1 = pred[] compare(get-tuple-element.6, constant.4), direction=LT
|
||||
}
|
||||
ENTRY SimpleLoop {
|
||||
constant.5 = s32[] constant(0)
|
||||
@ -316,7 +316,7 @@ TEST_F(HloModuleDceTest, TwoWhilesWithDeadTupleElementSwizzled) {
|
||||
loop_var.2 = (s32[3]{0}, s32[]) parameter(0)
|
||||
get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=1
|
||||
constant.2 = s32[] constant(5)
|
||||
ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2)
|
||||
ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT
|
||||
}
|
||||
SimpleLoop.body1 {
|
||||
loop_var.3 = (s32[], s32[3]{0}) parameter(0)
|
||||
@ -331,7 +331,7 @@ TEST_F(HloModuleDceTest, TwoWhilesWithDeadTupleElementSwizzled) {
|
||||
loop_var.4 = (s32[], s32[3]{0}) parameter(0)
|
||||
get-tuple-element.6 = s32[] get-tuple-element(loop_var.4), index=0
|
||||
constant.4 = s32[] constant(5)
|
||||
ROOT less-than.1 = pred[] less-than(get-tuple-element.6, constant.4)
|
||||
ROOT less-than.1 = pred[] compare(get-tuple-element.6, constant.4), direction=LT
|
||||
}
|
||||
ENTRY SimpleLoop {
|
||||
constant.5 = s32[] constant(0)
|
||||
@ -383,7 +383,7 @@ TEST_F(HloModuleDceTest, WhileWithOutfeed) {
|
||||
cond_param = (s32[]) parameter(0)
|
||||
get-tuple-element.3 = s32[] get-tuple-element(cond_param), index=0
|
||||
constant.2 = s32[] constant(10)
|
||||
ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2)
|
||||
ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT
|
||||
}
|
||||
ENTRY SimpleLoop {
|
||||
constant.3 = s32[] constant(0)
|
||||
@ -418,7 +418,7 @@ TEST_F(HloModuleDceTest, WhileWithOnlyLoopVariableBumping) {
|
||||
cond_param = (s32[], s32[]) parameter(0)
|
||||
get-tuple-element.3 = s32[] get-tuple-element(cond_param), index=0
|
||||
constant.2 = s32[] constant(10)
|
||||
ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2)
|
||||
ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT
|
||||
}
|
||||
ENTRY SimpleLoop {
|
||||
p0 = (s32[]) parameter(0)
|
||||
|
@ -44,21 +44,8 @@ StatusOr<HloOpcode> StringToHloOpcode(const string& opcode_name) {
|
||||
return it->second;
|
||||
}
|
||||
|
||||
#define CHECK_DEFAULT(property_name, opcode_name) false
|
||||
#define CHECK_PROPERTY(property_name, opcode_name, value) \
|
||||
(value & property_name)
|
||||
#define RESOLVE(_1, _2, target, ...) target
|
||||
#define HAS_PROPERTY(property, ...) \
|
||||
RESOLVE(__VA_ARGS__, CHECK_PROPERTY, CHECK_DEFAULT)(property, __VA_ARGS__)
|
||||
|
||||
// Reports whether `opcode` is a comparison opcode.
// NOTE(review): this span comes from a diff hunk ("-44,21 +44,8"), so it shows
// two implementations back to back: the HLO_OPCODE_LIST / HAS_PROPERTY switch
// and the direct `opcode == HloOpcode::kCompare` test. Presumably the switch is
// the removed version and the equality test is its replacement -- confirm
// against the real file before relying on either.
bool HloOpcodeIsComparison(HloOpcode opcode) {
|
||||
switch (opcode) {
|
||||
#define CASE_IS_COMPARISON(enum_name, opcode_name, ...) \
|
||||
case HloOpcode::enum_name: \
|
||||
return HAS_PROPERTY(kHloOpcodeIsComparison, __VA_ARGS__);
|
||||
HLO_OPCODE_LIST(CASE_IS_COMPARISON)
|
||||
#undef CASE_IS_COMPARISON
|
||||
}
|
||||
return opcode == HloOpcode::kCompare;
|
||||
}
|
||||
|
||||
bool HloOpcodeIsVariadic(HloOpcode opcode) {
|
||||
@ -82,9 +69,4 @@ absl::optional<int> HloOpcodeArity(HloOpcode opcode) {
|
||||
}
|
||||
}
|
||||
|
||||
#undef HAS_PROPERTY
|
||||
#undef RESOLVE
|
||||
#undef CHECK_DEFAULT
|
||||
#undef CHECK_PROPERTY
|
||||
|
||||
} // namespace xla
|
||||
|
@ -19,8 +19,10 @@ limitations under the License.
|
||||
#include <iosfwd>
|
||||
#include <string>
|
||||
#include "absl/types/optional.h"
|
||||
#include "tensorflow/compiler/xla/comparison_util.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
@ -65,6 +67,7 @@ namespace xla {
|
||||
V(kClamp, "clamp", 3) \
|
||||
V(kCollectivePermute, "collective-permute", 1) \
|
||||
V(kClz, "count-leading-zeros", 1) \
|
||||
V(kCompare, "compare", 2) \
|
||||
V(kComplex, "complex", 2) \
|
||||
V(kConcatenate, "concatenate", kHloOpcodeIsVariadic) \
|
||||
V(kConditional, "conditional", kHloOpcodeIsVariadic) \
|
||||
@ -79,34 +82,28 @@ namespace xla {
|
||||
V(kDot, "dot", 2) \
|
||||
V(kDynamicSlice, "dynamic-slice", kHloOpcodeIsVariadic) \
|
||||
V(kDynamicUpdateSlice, "dynamic-update-slice", kHloOpcodeIsVariadic) \
|
||||
V(kEq, "equal-to", 2, kHloOpcodeIsComparison) \
|
||||
V(kExp, "exponential", 1) \
|
||||
V(kExpm1, "exponential-minus-one", 1) \
|
||||
V(kFft, "fft", 1) \
|
||||
V(kFloor, "floor", 1) \
|
||||
V(kFusion, "fusion", kHloOpcodeIsVariadic) \
|
||||
V(kGather, "gather", 2) \
|
||||
V(kGe, "greater-than-or-equal-to", 2, kHloOpcodeIsComparison) \
|
||||
V(kGetDimensionSize, "get-dimension-size", 1) \
|
||||
V(kGetTupleElement, "get-tuple-element", 1) \
|
||||
V(kGt, "greater-than", 2, kHloOpcodeIsComparison) \
|
||||
V(kImag, "imag", 1) \
|
||||
V(kInfeed, "infeed", 1) \
|
||||
V(kIota, "iota", 0) \
|
||||
V(kIsFinite, "is-finite", 1) \
|
||||
V(kLe, "less-than-or-equal-to", 2, kHloOpcodeIsComparison) \
|
||||
V(kLog, "log", 1) \
|
||||
V(kLog1p, "log-plus-one", 1) \
|
||||
V(kAnd, "and", 2) \
|
||||
V(kNot, "not", 1) \
|
||||
V(kOr, "or", 2) \
|
||||
V(kXor, "xor", 2) \
|
||||
V(kLt, "less-than", 2, kHloOpcodeIsComparison) \
|
||||
V(kMap, "map", kHloOpcodeIsVariadic) \
|
||||
V(kMaximum, "maximum", 2) \
|
||||
V(kMinimum, "minimum", 2) \
|
||||
V(kMultiply, "multiply", 2) \
|
||||
V(kNe, "not-equal-to", 2, kHloOpcodeIsComparison) \
|
||||
V(kNegate, "negate", 1) \
|
||||
V(kOutfeed, "outfeed", 2) \
|
||||
V(kPad, "pad", 2) \
|
||||
|
@ -42,12 +42,7 @@ TEST(HloOpcodeTest, OpcodeProperties) {
|
||||
|
||||
// Test some properties.
|
||||
switch (opcode) {
|
||||
case HloOpcode::kEq:
|
||||
case HloOpcode::kNe:
|
||||
case HloOpcode::kGt:
|
||||
case HloOpcode::kLt:
|
||||
case HloOpcode::kGe:
|
||||
case HloOpcode::kLe:
|
||||
case HloOpcode::kCompare:
|
||||
EXPECT_TRUE(HloOpcodeIsComparison(opcode));
|
||||
break;
|
||||
default:
|
||||
|
@ -306,7 +306,7 @@ condition.v4 {
|
||||
constant.2 = s32[] constant(2)
|
||||
prev.2 = (s32[], f32[3]{0}, f32[3]{0}, f32[3]{0}) parameter(0)
|
||||
get-tuple-element.8 = s32[] get-tuple-element(prev.2), index=0
|
||||
ROOT greater-than = pred[] greater-than(constant.2, get-tuple-element.8)
|
||||
ROOT greater-than = pred[] compare(constant.2, get-tuple-element.8), direction=GT
|
||||
}
|
||||
|
||||
fused_computation {
|
||||
|
@ -183,6 +183,7 @@ class HloParser {
|
||||
kHloComputation,
|
||||
kBracedHloComputationList,
|
||||
kFftType,
|
||||
kComparisonDirection,
|
||||
kWindow,
|
||||
kConvolutionDimensionNumbers,
|
||||
kSharding,
|
||||
@ -300,6 +301,7 @@ class HloParser {
|
||||
bool ParseTiles(std::vector<Tile>* tiles);
|
||||
bool ParseOpcode(HloOpcode* result);
|
||||
bool ParseFftType(FftType* result);
|
||||
bool ParseComparisonDirection(ComparisonDirection* result);
|
||||
bool ParseFusionKind(HloInstruction::FusionKind* result);
|
||||
bool ParseRandomDistribution(RandomDistribution* result);
|
||||
bool ParsePrecision(PrecisionConfig::Precision* result);
|
||||
@ -763,12 +765,6 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder,
|
||||
case HloOpcode::kSubtract:
|
||||
case HloOpcode::kAtan2:
|
||||
case HloOpcode::kComplex:
|
||||
case HloOpcode::kEq:
|
||||
case HloOpcode::kGe:
|
||||
case HloOpcode::kGt:
|
||||
case HloOpcode::kLe:
|
||||
case HloOpcode::kLt:
|
||||
case HloOpcode::kNe:
|
||||
case HloOpcode::kMaximum:
|
||||
case HloOpcode::kMinimum:
|
||||
case HloOpcode::kPower:
|
||||
@ -1133,6 +1129,18 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder,
|
||||
shape, operands[0], operands[1], options));
|
||||
break;
|
||||
}
|
||||
case HloOpcode::kCompare: {
|
||||
optional<ComparisonDirection> direction;
|
||||
attrs["direction"] = {/*required=*/true, AttrTy::kComparisonDirection,
|
||||
&direction};
|
||||
if (!ParseOperands(&operands, /*expected_size=*/2) ||
|
||||
!ParseAttributes(attrs)) {
|
||||
return false;
|
||||
}
|
||||
instruction = builder->AddInstruction(HloInstruction::CreateCompare(
|
||||
shape, operands[0], operands[1], *direction));
|
||||
break;
|
||||
}
|
||||
case HloOpcode::kCholesky: {
|
||||
CholeskyOptions options;
|
||||
if (!ParseOperands(&operands, /*expected_size=*/1) ||
|
||||
@ -2728,6 +2736,15 @@ bool HloParser::ParseAttributeHelper(
|
||||
static_cast<optional<FftType>*>(attr_out_ptr)->emplace(result);
|
||||
return true;
|
||||
}
|
||||
case AttrTy::kComparisonDirection: {
|
||||
ComparisonDirection result;
|
||||
if (!ParseComparisonDirection(&result)) {
|
||||
return false;
|
||||
}
|
||||
static_cast<optional<ComparisonDirection>*>(attr_out_ptr)
|
||||
->emplace(result);
|
||||
return true;
|
||||
}
|
||||
case AttrTy::kWindow: {
|
||||
Window result;
|
||||
if (!ParseWindow(&result, /*expect_outer_curlies=*/true)) {
|
||||
@ -3756,6 +3773,22 @@ bool HloParser::ParseFftType(FftType* result) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Parses a comparison direction from the lexer's current token.
//
// On success, stores the direction in `*result`, calls lexer_.Lex(), and
// returns true. If the current token is not an identifier, or
// StringToComparisonDirection() rejects its text, returns the result of
// TokenError() and leaves `*result` unassigned.
bool HloParser::ParseComparisonDirection(ComparisonDirection* result) {
|
||||
VLOG(1) << "ParseComparisonDirection";
|
||||
// Only an identifier token is accepted as a direction.
if (lexer_.GetKind() != TokKind::kIdent) {
|
||||
return TokenError("expects comparison direction");
|
||||
}
|
||||
string val = lexer_.GetStrVal();
|
||||
// The string-to-enum mapping is delegated to StringToComparisonDirection().
auto status_or_result = StringToComparisonDirection(val);
|
||||
if (!status_or_result.ok()) {
|
||||
return TokenError(
|
||||
StrFormat("expects comparison direction but sees: %s", val));
|
||||
}
|
||||
// ok() was checked above, so ValueOrDie() is not expected to die here.
*result = status_or_result.ValueOrDie();
|
||||
// Lex() is called only after the token has been accepted.
lexer_.Lex();
|
||||
return true;
|
||||
}
|
||||
|
||||
bool HloParser::ParseFusionKind(HloInstruction::FusionKind* result) {
|
||||
VLOG(1) << "ParseFusionKind";
|
||||
if (lexer_.GetKind() != TokKind::kIdent) {
|
||||
|
@ -222,7 +222,7 @@ R"(HloModule SelectR1F32WithCmpR1F32sFromParamsSmall_module
|
||||
ENTRY %SelectR1F32WithCmpR1F32sFromParamsSmall.v4 (v1: f32[4], v2: f32[4]) -> f32[4] {
|
||||
%v1 = f32[4]{0} parameter(0), sharding={maximal device=1}
|
||||
%v2 = f32[4]{0} parameter(1), sharding={maximal device=1}
|
||||
%greater-than = pred[4]{0} greater-than(f32[4]{0} %v1, f32[4]{0} %v2), sharding={replicated}
|
||||
%greater-than = pred[4]{0} compare(f32[4]{0} %v1, f32[4]{0} %v2), direction=GT, sharding={replicated}
|
||||
ROOT %select = f32[4]{0} select(pred[4]{0} %greater-than, f32[4]{0} %v1, f32[4]{0} %v2), sharding={}
|
||||
}
|
||||
|
||||
@ -292,7 +292,7 @@ R"(HloModule WhileWithScalarS32Result_module
|
||||
%condition.v3 (prev.2: s32[]) -> pred[] {
|
||||
%constant.1 = s32[] constant(5)
|
||||
%prev.2 = s32[] parameter(0)
|
||||
ROOT %greater-than = pred[] greater-than(s32[] %constant.1, s32[] %prev.2)
|
||||
ROOT %greater-than = pred[] compare(s32[] %constant.1, s32[] %prev.2), direction=GT
|
||||
}
|
||||
|
||||
ENTRY %WhileWithScalarS32Result.v2 () -> s32[] {
|
||||
@ -474,7 +474,7 @@ R"(HloModule R4F32OverlapSmall_module
|
||||
%ge_F32.v3 (lhs: f32[], rhs: f32[]) -> pred[] {
|
||||
%lhs = f32[] parameter(0)
|
||||
%rhs = f32[] parameter(1)
|
||||
ROOT %greater-than-or-equal-to = pred[] greater-than-or-equal-to(f32[] %lhs, f32[] %rhs)
|
||||
ROOT %greater-than-or-equal-to = pred[] compare(f32[] %lhs, f32[] %rhs), direction=GE
|
||||
}
|
||||
|
||||
%add_F32.v3 (lhs.1: f32[], rhs.1: f32[]) -> f32[] {
|
||||
@ -500,7 +500,7 @@ R"(HloModule select_and_scatter_scalar
|
||||
%ge_F32.v3 (lhs: f32[], rhs: f32[]) -> pred[] {
|
||||
%lhs = f32[] parameter(0)
|
||||
%rhs = f32[] parameter(1)
|
||||
ROOT %greater-than-or-equal-to = pred[] greater-than-or-equal-to(f32[] %lhs, f32[] %rhs)
|
||||
ROOT %greater-than-or-equal-to = pred[] compare(f32[] %lhs, f32[] %rhs), direction=GE
|
||||
}
|
||||
|
||||
%add_F32.v3 (lhs.1: f32[], rhs.1: f32[]) -> f32[] {
|
||||
@ -1037,7 +1037,7 @@ R"(HloModule TupleReduce
|
||||
max_argmax {
|
||||
value = f32[] parameter(2)
|
||||
prev_max = f32[] parameter(0)
|
||||
is_next_larger = pred[] greater-than-or-equal-to(value, prev_max)
|
||||
is_next_larger = pred[] compare(value, prev_max), direction=GE
|
||||
max = f32[] select(is_next_larger, value, prev_max)
|
||||
index = s32[] parameter(3)
|
||||
prev_argmax = s32[] parameter(1)
|
||||
@ -1106,7 +1106,7 @@ R"(HloModule sort
|
||||
compare {
|
||||
p.0.lhs = f32[] parameter(0)
|
||||
p.0.rhs = f32[] parameter(1)
|
||||
ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs)
|
||||
ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
|
||||
}
|
||||
|
||||
ENTRY Sort {
|
||||
@ -1126,7 +1126,7 @@ compare {
|
||||
p.1.rhs = s32[] parameter(3)
|
||||
p.0.lhs = f32[] parameter(0)
|
||||
p.0.rhs = f32[] parameter(1)
|
||||
ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs)
|
||||
ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
|
||||
}
|
||||
|
||||
ENTRY Sort {
|
||||
@ -1145,7 +1145,7 @@ R"(HloModule sort
|
||||
compare {
|
||||
p.0.lhs = f32[] parameter(0)
|
||||
p.0.rhs = f32[] parameter(1)
|
||||
ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs)
|
||||
ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
|
||||
}
|
||||
|
||||
ENTRY Sort {
|
||||
@ -1165,7 +1165,7 @@ compare {
|
||||
p.1.rhs = s32[] parameter(3)
|
||||
p.0.lhs = f32[] parameter(0)
|
||||
p.0.rhs = f32[] parameter(1)
|
||||
ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs)
|
||||
ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
|
||||
}
|
||||
|
||||
ENTRY Sort {
|
||||
@ -1190,7 +1190,7 @@ compare {
|
||||
p.3.rhs = f32[] parameter(7)
|
||||
p.0.lhs = f32[] parameter(0)
|
||||
p.0.rhs = f32[] parameter(1)
|
||||
ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs)
|
||||
ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
|
||||
}
|
||||
|
||||
ENTRY Sort {
|
||||
@ -1211,7 +1211,7 @@ R"(HloModule sort
|
||||
compare {
|
||||
p.0.lhs = f32[] parameter(0)
|
||||
p.0.rhs = f32[] parameter(1)
|
||||
ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs)
|
||||
ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
|
||||
}
|
||||
|
||||
ENTRY Sort {
|
||||
@ -1469,7 +1469,7 @@ compare {
|
||||
p.1.rhs = s32[] parameter(3)
|
||||
p.0.lhs = f32[] parameter(0)
|
||||
p.0.rhs = f32[] parameter(1)
|
||||
ROOT lhs = pred[] less-than(p.0.lhs, p.0.rhs)
|
||||
ROOT lhs = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
|
||||
}
|
||||
|
||||
ENTRY Sort {
|
||||
@ -1656,7 +1656,7 @@ TEST_F(HloParserTest, WrongOperandsSize) {
|
||||
|
||||
ENTRY %blabla (x: f32[]) -> pred[] {
|
||||
%x = f32[]{} parameter(0)
|
||||
%eq = pred[]{} equal-to(f32[]{} %x)
|
||||
%eq = pred[]{} compare(f32[]{} %x), direction=EQ
|
||||
}
|
||||
|
||||
)";
|
||||
@ -1668,7 +1668,7 @@ TEST_F(HloParserTest, OperandNotFound) {
|
||||
const string original = R"(HloModule operand_not_found:
|
||||
ENTRY %blabla (x: f32[]) -> pred[] {
|
||||
%x = f32[]{} parameter(0)
|
||||
%eq = pred[]{} equal-to(f32[]{} %x, f32[]{} %y)
|
||||
%eq = pred[]{} compare(f32[]{} %x, f32[]{} %y), direction=EQ
|
||||
}
|
||||
)";
|
||||
auto result = ParseHloString(original);
|
||||
|
@ -228,7 +228,7 @@ HloModule UpdateScheduleWithMultipleComputations
|
||||
%param = (s32[], token[]) parameter(0)
|
||||
%get-tuple-element = s32[] get-tuple-element((s32[], token[]) %param), index=0
|
||||
%constant = s32[] constant(42)
|
||||
ROOT %less-than = pred[] less-than(s32[] %get-tuple-element, s32[] %constant)
|
||||
ROOT %less-than = pred[] compare(s32[] %get-tuple-element, s32[] %constant), direction=LT
|
||||
}
|
||||
|
||||
ENTRY %WhileLoop () -> s32[] {
|
||||
@ -297,7 +297,7 @@ HloModule UpdateScheduleWithMultipleComputations
|
||||
%param = (s32[], token[]) parameter(0)
|
||||
%get-tuple-element = s32[] get-tuple-element((s32[], token[]) %param), index=0
|
||||
%constant = s32[] constant(42)
|
||||
ROOT %less-than = pred[] less-than(s32[] %get-tuple-element, s32[] %constant)
|
||||
ROOT %less-than = pred[] compare(s32[] %get-tuple-element, s32[] %constant), direction=LT
|
||||
}
|
||||
|
||||
ENTRY %WhileLoop () -> s32[] {
|
||||
|
@ -65,6 +65,7 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) {
|
||||
case HloOpcode::kCeil:
|
||||
case HloOpcode::kClamp:
|
||||
case HloOpcode::kClz:
|
||||
case HloOpcode::kCompare:
|
||||
case HloOpcode::kComplex:
|
||||
case HloOpcode::kConcatenate:
|
||||
case HloOpcode::kConstant:
|
||||
@ -72,21 +73,15 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) {
|
||||
case HloOpcode::kCopy:
|
||||
case HloOpcode::kDynamicSlice:
|
||||
case HloOpcode::kDynamicUpdateSlice:
|
||||
case HloOpcode::kEq:
|
||||
case HloOpcode::kFloor:
|
||||
case HloOpcode::kGe:
|
||||
case HloOpcode::kGetTupleElement:
|
||||
case HloOpcode::kGt:
|
||||
case HloOpcode::kImag:
|
||||
case HloOpcode::kInfeed:
|
||||
case HloOpcode::kIota:
|
||||
case HloOpcode::kIsFinite:
|
||||
case HloOpcode::kLe:
|
||||
case HloOpcode::kLt:
|
||||
case HloOpcode::kMaximum:
|
||||
case HloOpcode::kMinimum:
|
||||
case HloOpcode::kMultiply:
|
||||
case HloOpcode::kNe:
|
||||
case HloOpcode::kNegate:
|
||||
case HloOpcode::kNot:
|
||||
case HloOpcode::kOr:
|
||||
|
@ -2001,6 +2001,7 @@ bool LayoutAssignment::InstructionCanChangeLayout(
|
||||
case HloOpcode::kCeil:
|
||||
case HloOpcode::kClamp:
|
||||
case HloOpcode::kClz:
|
||||
case HloOpcode::kCompare:
|
||||
case HloOpcode::kComplex:
|
||||
case HloOpcode::kConcatenate:
|
||||
case HloOpcode::kConditional:
|
||||
@ -2012,24 +2013,18 @@ bool LayoutAssignment::InstructionCanChangeLayout(
|
||||
case HloOpcode::kDivide:
|
||||
case HloOpcode::kDynamicSlice:
|
||||
case HloOpcode::kDynamicUpdateSlice:
|
||||
case HloOpcode::kEq:
|
||||
case HloOpcode::kExp:
|
||||
case HloOpcode::kExpm1:
|
||||
case HloOpcode::kFft:
|
||||
case HloOpcode::kFloor:
|
||||
case HloOpcode::kGe:
|
||||
case HloOpcode::kGt:
|
||||
case HloOpcode::kImag:
|
||||
case HloOpcode::kIsFinite:
|
||||
case HloOpcode::kLe:
|
||||
case HloOpcode::kLog:
|
||||
case HloOpcode::kLog1p:
|
||||
case HloOpcode::kLt:
|
||||
case HloOpcode::kMap:
|
||||
case HloOpcode::kMaximum:
|
||||
case HloOpcode::kMinimum:
|
||||
case HloOpcode::kMultiply:
|
||||
case HloOpcode::kNe:
|
||||
case HloOpcode::kNegate:
|
||||
case HloOpcode::kNot:
|
||||
case HloOpcode::kOr:
|
||||
|
@ -1084,7 +1084,7 @@ TEST_F(LayoutAssignmentTest, TupleCopyOnLayoutMismatch) {
|
||||
tup.1 = (s32[], token[], f32[512,1024]{0,1}) parameter(0)
|
||||
counter.1 = s32[] get-tuple-element(tup.1), index=0
|
||||
five = s32[] constant(5)
|
||||
ROOT lt = pred[] less-than(counter.1, five)
|
||||
ROOT lt = pred[] compare(counter.1, five), direction=LT
|
||||
}
|
||||
|
||||
body.2 (tup: (s32[], token[], f32[512,1024]{0,1})) -> (s32[], token[], f32[512,1024]{0,1}) {
|
||||
|
@ -46,7 +46,7 @@ condition {
|
||||
condition.state = f32[] parameter(0)
|
||||
addend = f32[] custom-call(condition.state), custom_call_target="FakeCustomCallTarget"
|
||||
add = f32[] add(addend, condition.state)
|
||||
ROOT greater-than = pred[] greater-than(const.100, add)
|
||||
ROOT greater-than = pred[] compare(const.100, add), direction=GT
|
||||
}
|
||||
|
||||
ENTRY while3 {
|
||||
|
@ -67,6 +67,7 @@ namespace xla {
|
||||
// - WithOneUse: Instruction is used as an operand exactly once.
|
||||
// - WithOneUser: Instruction is used by exactly one other instruction, but
|
||||
// is possibly used more than once as an operand (e.g. multiply(x,x)).
|
||||
// - WithComparisonDirection: instr has the given direction
|
||||
//
|
||||
// Shape():
|
||||
// - EqualTo
|
||||
@ -1671,6 +1672,40 @@ class HloInstructionPatternOneUserImpl
|
||||
}
|
||||
};
|
||||
|
||||
// Pattern implementation that matches an HloInstruction only when its opcode
// is HloOpcode::kCompare and its comparison_direction() equals the direction
// supplied at construction.
class HloInstructionPatternComparisonDirectionImpl {
|
||||
public:
|
||||
// `direction` is the comparison direction a matching instruction must have.
explicit constexpr HloInstructionPatternComparisonDirectionImpl(
|
||||
ComparisonDirection direction)
|
||||
: direction_(direction) {}
|
||||
|
||||
// Const and non-const instruction overloads; both forward to MatchImpl().
bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
|
||||
return MatchImpl(inst, option);
|
||||
}
|
||||
|
||||
bool Match(::xla::HloInstruction* inst, MatchOption option) const {
|
||||
return MatchImpl(inst, option);
|
||||
}
|
||||
|
||||
// Writes "which has comparison direction <DIRECTION>" to `*os`.
// `indent` is not used by this implementation.
void DescribeTo(std::ostream* os, int64 indent = 0) const {
|
||||
*os << "which has comparison direction "
|
||||
<< ComparisonDirectionToString(direction_);
|
||||
}
|
||||
|
||||
private:
|
||||
// Shared implementation of both Match() overloads. Returns true iff `inst`
// is a kCompare with direction `direction_`; on a mismatch it streams a
// message through EXPLAIN (defined elsewhere in this file) and returns false.
// The opcode is tested first, so comparison_direction() is only evaluated
// for kCompare instructions.
template <typename HloInstructionType>
|
||||
bool MatchImpl(HloInstructionType* inst, MatchOption option) const {
|
||||
if (inst->opcode() != HloOpcode::kCompare ||
|
||||
inst->comparison_direction() != direction_) {
|
||||
EXPLAIN << "HloInstruction is not comparison "
|
||||
<< ComparisonDirectionToString(direction_);
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// Direction an instruction must have in order to match.
ComparisonDirection direction_;
|
||||
};
|
||||
|
||||
// Matches a constant scalar or effective scalar, optionally with a given value.
|
||||
template <typename ScalarTy>
|
||||
class HloConstantScalarImpl {
|
||||
@ -1956,6 +1991,14 @@ class HloInstructionPattern {
|
||||
return AppendImpl(HloInstructionPatternOneUserImpl());
|
||||
}
|
||||
|
||||
// Modifies the pattern to match only if the instruction has the given
|
||||
// comparison direction.
// The method is const: it returns the pattern produced by AppendImpl() rather
// than changing the receiver. The appended check
// (HloInstructionPatternComparisonDirectionImpl) also rejects any instruction
// whose opcode is not HloOpcode::kCompare.
|
||||
auto WithComparisonDirection(ComparisonDirection direction) const
|
||||
-> decltype(this->AppendImpl(
|
||||
HloInstructionPatternComparisonDirectionImpl(direction))) {
|
||||
return AppendImpl(HloInstructionPatternComparisonDirectionImpl(direction));
|
||||
}
|
||||
|
||||
void DescribeTo(std::ostream* os, int64 indent = 0) const {
|
||||
impl_.DescribeTo(os, indent);
|
||||
}
|
||||
@ -2118,18 +2161,13 @@ XLA_COMMUTATIVE_BINOP_PATTERN(Add)
|
||||
XLA_BINOP_PATTERN(Atan2)
|
||||
XLA_BINOP_PATTERN(Divide)
|
||||
XLA_BINOP_PATTERN(Complex)
|
||||
XLA_BINOP_PATTERN(Compare)
|
||||
XLA_BINOP_PATTERN(Convolution)
|
||||
XLA_BINOP_PATTERN(Dot)
|
||||
XLA_COMMUTATIVE_BINOP_PATTERN(Eq)
|
||||
XLA_BINOP_PATTERN(Gather)
|
||||
XLA_BINOP_PATTERN(Ge)
|
||||
XLA_BINOP_PATTERN(Gt)
|
||||
XLA_BINOP_PATTERN(Le)
|
||||
XLA_BINOP_PATTERN(Lt)
|
||||
XLA_COMMUTATIVE_BINOP_PATTERN(Maximum)
|
||||
XLA_COMMUTATIVE_BINOP_PATTERN(Minimum)
|
||||
XLA_COMMUTATIVE_BINOP_PATTERN(Multiply)
|
||||
XLA_COMMUTATIVE_BINOP_PATTERN(Ne)
|
||||
XLA_BINOP_PATTERN(Outfeed)
|
||||
XLA_BINOP_PATTERN(Pad)
|
||||
XLA_BINOP_PATTERN(Power)
|
||||
@ -2242,6 +2280,73 @@ XLA_VARIADIC_OP_PATTERN(Reduce);
|
||||
XLA_VARIADIC_OP_PATTERN(Sort);
|
||||
XLA_VARIADIC_OP_PATTERN(Tuple);
|
||||
|
||||
// Helpers for comparison instructions.
// XLA_COMPARE_PATTERN(NAME) defines three overloads of a matcher function
// called NAME, each requiring opcode kCompare and comparison direction
// ComparisonDirection::k##NAME:
//   NAME()                        - no constraint on the operands.
//   NAME(lhs, rhs)                - operand 0 must match `lhs` and operand 1
//                                   must match `rhs`.
//   NAME(matched_inst, lhs, rhs)  - as above, with the pattern built from
//                                   Op(matched_inst) instead of Op().
|
||||
#define XLA_COMPARE_PATTERN(NAME) \
|
||||
inline auto NAME()->decltype( \
|
||||
Op().WithOpcode(HloOpcode::kCompare) \
|
||||
.WithComparisonDirection(ComparisonDirection::k##NAME)) { \
|
||||
return Op() \
|
||||
.WithOpcode(HloOpcode::kCompare) \
|
||||
.WithComparisonDirection(ComparisonDirection::k##NAME); \
|
||||
} \
|
||||
\
|
||||
template <typename Lhs, typename Rhs> \
|
||||
inline auto NAME(Lhs&& lhs, Rhs&& rhs) \
|
||||
->decltype(Op().WithOpcode(HloOpcode::kCompare) \
|
||||
.WithOperand(0, std::forward<Lhs>(lhs)) \
|
||||
.WithOperand(1, std::forward<Rhs>(rhs)) \
|
||||
.WithComparisonDirection(ComparisonDirection::k##NAME)) { \
|
||||
return Op() \
|
||||
.WithOpcode(HloOpcode::kCompare) \
|
||||
.WithOperand(0, std::forward<Lhs>(lhs)) \
|
||||
.WithOperand(1, std::forward<Rhs>(rhs)) \
|
||||
.WithComparisonDirection(ComparisonDirection::k##NAME); \
|
||||
} \
|
||||
\
|
||||
template <typename HloInstructionType, typename Lhs, typename Rhs> \
|
||||
inline auto NAME(HloInstructionType** matched_inst, Lhs&& lhs, Rhs&& rhs) \
|
||||
->decltype(Op(matched_inst) \
|
||||
.WithOpcode(HloOpcode::kCompare) \
|
||||
.WithOperand(0, std::forward<Lhs>(lhs)) \
|
||||
.WithOperand(1, std::forward<Rhs>(rhs)) \
|
||||
.WithComparisonDirection(ComparisonDirection::k##NAME)) { \
|
||||
return Op(matched_inst) \
|
||||
.WithOpcode(HloOpcode::kCompare) \
|
||||
.WithOperand(0, std::forward<Lhs>(lhs)) \
|
||||
.WithOperand(1, std::forward<Rhs>(rhs)) \
|
||||
.WithComparisonDirection(ComparisonDirection::k##NAME); \
|
||||
}
|
||||
|
||||
// XLA_COMMUTATIVE_COMPARE_PATTERN(NAME) defines everything that
// XLA_COMPARE_PATTERN(NAME) defines, plus two NAME##AnyOrder overloads for
// comparisons whose result does not depend on operand order:
//   NAME##AnyOrder(matched_inst, lhs, rhs)
//   NAME##AnyOrder(lhs, rhs)   - forwards to the overload above with a null
//                                `const HloInstruction**`.
// Both require opcode kCompare, operands matching `lhs` and `rhs` in either
// order, and comparison direction ComparisonDirection::k##NAME. The direction
// check is essential: without it, e.g. EqAnyOrder would also accept an LT or
// GE compare, disagreeing with the ordered NAME(lhs, rhs) matcher.
#define XLA_COMMUTATIVE_COMPARE_PATTERN(NAME) \
  XLA_COMPARE_PATTERN(NAME) \
  \
  template <typename HloInstructionType, typename Lhs, typename Rhs> \
  inline auto NAME##AnyOrder(HloInstructionType** matched_inst, Lhs&& lhs, \
                             Rhs&& rhs) \
      ->decltype(Op(matched_inst) \
                     .WithOpcode(HloOpcode::kCompare) \
                     .WithBinaryOperandsAnyOrder(std::forward<Lhs>(lhs), \
                                                 std::forward<Rhs>(rhs)) \
                     .WithComparisonDirection(ComparisonDirection::k##NAME)) { \
    return Op(matched_inst) \
        .WithOpcode(HloOpcode::kCompare) \
        .WithBinaryOperandsAnyOrder(std::forward<Lhs>(lhs), \
                                    std::forward<Rhs>(rhs)) \
        .WithComparisonDirection(ComparisonDirection::k##NAME); \
  } \
  template <typename Lhs, typename Rhs> \
  inline auto NAME##AnyOrder(Lhs&& lhs, Rhs&& rhs) \
      ->decltype(NAME##AnyOrder<const HloInstruction>( \
          nullptr, std::forward<Lhs>(lhs), std::forward<Rhs>(rhs))) { \
    return NAME##AnyOrder<const HloInstruction>( \
        nullptr, std::forward<Lhs>(lhs), std::forward<Rhs>(rhs)); \
  }
|
||||
|
||||
XLA_COMMUTATIVE_COMPARE_PATTERN(Eq);
|
||||
XLA_COMMUTATIVE_COMPARE_PATTERN(Ne);
|
||||
XLA_COMPARE_PATTERN(Ge);
|
||||
XLA_COMPARE_PATTERN(Gt);
|
||||
XLA_COMPARE_PATTERN(Le);
|
||||
XLA_COMPARE_PATTERN(Lt);
|
||||
|
||||
// Helpers for matching non-constant instructions.
|
||||
inline auto NonConstant() -> decltype(Op().IsNonConstant()) {
|
||||
return Op().IsNonConstant();
|
||||
|
@ -931,5 +931,48 @@ TEST(PatternMatcherTest, OneUseAndOneUser) {
|
||||
"in p0 = f32[] parameter(0)");
|
||||
}
|
||||
|
||||
// Exercises the comparison matchers (m::Compare, m::Eq, m::Ne, m::Le and
// m::EqAnyOrder) against compare instructions with directions EQ, NE and LE.
// NOTE(review): the neighbouring test visible in this file is registered under
// PatternMatcherTest, while this one uses HloMatchersTest -- confirm that the
// suite name is intentional.
TEST(HloMatchersTest, Comparison) {
|
||||
// Fixtures: two parameters, eq/ne compares of them, and an le compare whose
// right-hand operand is add(p0, p1).
// NOTE(review): the compare instructions reuse the F32 operand shape as their
// own shape; the matchers under test do not inspect it.
auto shape = ShapeUtil::MakeShape(F32, {1});
|
||||
auto p0 = HloInstruction::CreateParameter(0, shape, "param.0");
|
||||
auto p1 = HloInstruction::CreateParameter(1, shape, "param.1");
|
||||
auto eq = HloInstruction::CreateCompare(shape, p0.get(), p1.get(),
|
||||
ComparisonDirection::kEq);
|
||||
auto ne = HloInstruction::CreateCompare(shape, p0.get(), p1.get(),
|
||||
ComparisonDirection::kNe);
|
||||
auto add =
|
||||
HloInstruction::CreateBinary(shape, HloOpcode::kAdd, p0.get(), p1.get());
|
||||
auto le = HloInstruction::CreateCompare(shape, p0.get(), add.get(),
|
||||
ComparisonDirection::kLe);
|
||||
|
||||
// Positive cases: the generic Compare matcher, the direction-specific
// matchers, and EqAnyOrder with the operands listed in reverse.
EXPECT_TRUE(Match(eq.get(), m::Compare()));
|
||||
EXPECT_TRUE(Match(eq.get(), m::Eq()));
|
||||
EXPECT_TRUE(Match(eq.get(), m::Eq(m::Parameter(0), m::Parameter(1))));
|
||||
EXPECT_TRUE(Match(eq.get(), m::EqAnyOrder(m::Parameter(1), m::Parameter(0))));
|
||||
EXPECT_TRUE(Match(ne.get(), m::Compare()));
|
||||
EXPECT_TRUE(Match(ne.get(), m::Ne()));
|
||||
EXPECT_TRUE(Match(
|
||||
le.get(),
|
||||
m::Compare(m::Parameter(0), m::Add(m::Parameter(0), m::Parameter(1)))));
|
||||
EXPECT_TRUE(Match(le.get(), m::Le(m::Parameter(0),
|
||||
m::Add(m::Parameter(0), m::Parameter(1)))));
|
||||
|
||||
// Negative cases: wrong opcode, wrong direction, and (for the ordered Eq
// matcher) operands given in the wrong order.
EXPECT_FALSE(Match(eq.get(), m::Add()));
|
||||
EXPECT_FALSE(Match(eq.get(), m::Ne()));
|
||||
EXPECT_FALSE(
|
||||
Match(le.get(),
|
||||
m::Eq(m::Parameter(0), m::Add(m::Parameter(0), m::Parameter(1)))));
|
||||
EXPECT_FALSE(Match(eq.get(), m::Eq(m::Parameter(1), m::Parameter(0))));
|
||||
// Checks both the pattern description and the mismatch explanation that are
// produced when an NE pattern is applied to the EQ compare.
EXPECT_DESC_AND_EXPLANATION(
|
||||
eq, m::Ne().WithOneUser(),
|
||||
"an HloInstruction:\n"
|
||||
" * with opcode compare AND\n"
|
||||
" * which has comparison direction NE AND\n"
|
||||
" * which has exactly one user (but possibly is used "
|
||||
"multiple times by that instruction)",
|
||||
"HloInstruction is not comparison NE\n"
|
||||
"in compare = f32[1]{0} compare(f32[1]{0} param.0, f32[1]{0} param.1), "
|
||||
"direction=EQ");
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
||||
|
@ -181,8 +181,9 @@ static StatusOr<HloInstruction*> CheckIndexValidity(
|
||||
HloInstruction* zero_index =
|
||||
BroadcastZeros(computation, index->shape().element_type(),
|
||||
AsInt64Slice(index->shape().dimensions()));
|
||||
TF_ASSIGN_OR_RETURN(HloInstruction * negative_index_check,
|
||||
MakeBinaryHlo(HloOpcode::kLe, zero_index, index));
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
HloInstruction * negative_index_check,
|
||||
MakeCompareHlo(ComparisonDirection::kLe, zero_index, index));
|
||||
|
||||
// Check if the index is OOB w.r.t. the operand dimensions and window sizes.
|
||||
std::vector<int64> max_valid_index(operand_dims.size());
|
||||
@ -193,9 +194,9 @@ static StatusOr<HloInstruction*> CheckIndexValidity(
|
||||
HloInstruction * max_valid_index_constant,
|
||||
MakeR1ConstantHlo<int64>(computation, index->shape().element_type(),
|
||||
max_valid_index));
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
HloInstruction * oob_index_check,
|
||||
MakeBinaryHlo(HloOpcode::kGe, max_valid_index_constant, index));
|
||||
TF_ASSIGN_OR_RETURN(HloInstruction * oob_index_check,
|
||||
MakeCompareHlo(ComparisonDirection::kGe,
|
||||
max_valid_index_constant, index));
|
||||
|
||||
// Combine the results of the two checks above.
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
|
@ -988,12 +988,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
|
||||
}
|
||||
return InferElementwiseBinaryOpShape(opcode, lhs, rhs,
|
||||
broadcast_dimensions);
|
||||
case HloOpcode::kEq:
|
||||
case HloOpcode::kGe:
|
||||
case HloOpcode::kGt:
|
||||
case HloOpcode::kLe:
|
||||
case HloOpcode::kLt:
|
||||
case HloOpcode::kNe: {
|
||||
case HloOpcode::kCompare: {
|
||||
TF_ASSIGN_OR_RETURN(const Shape& shape,
|
||||
InferElementwiseBinaryOpShape(opcode, lhs, rhs,
|
||||
broadcast_dimensions));
|
||||
|
@ -918,55 +918,10 @@ TEST_F(ShapeInferenceTest, InferPowShape) {
|
||||
ASSERT_TRUE(ShapeUtil::Equal(ten_floats, inferred_status.ValueOrDie()));
|
||||
}
|
||||
|
||||
TEST_F(ShapeInferenceTest, InferCompareShapeEq) {
|
||||
TEST_F(ShapeInferenceTest, InferCompareShape) {
|
||||
auto ten_floats = ShapeUtil::MakeShape(F32, {10});
|
||||
auto inferred_status =
|
||||
ShapeInference::InferBinaryOpShape(HloOpcode::kEq, ten_floats, f32_, {});
|
||||
ASSERT_IS_OK(inferred_status.status());
|
||||
ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}),
|
||||
inferred_status.ValueOrDie()));
|
||||
}
|
||||
|
||||
TEST_F(ShapeInferenceTest, InferCompareShapeGe) {
|
||||
auto ten_floats = ShapeUtil::MakeShape(F32, {10});
|
||||
auto inferred_status =
|
||||
ShapeInference::InferBinaryOpShape(HloOpcode::kGe, ten_floats, f32_, {});
|
||||
ASSERT_IS_OK(inferred_status.status());
|
||||
ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}),
|
||||
inferred_status.ValueOrDie()));
|
||||
}
|
||||
|
||||
TEST_F(ShapeInferenceTest, InferCompareShapeGt) {
|
||||
auto ten_floats = ShapeUtil::MakeShape(F32, {10});
|
||||
auto inferred_status =
|
||||
ShapeInference::InferBinaryOpShape(HloOpcode::kGt, ten_floats, f32_, {});
|
||||
ASSERT_IS_OK(inferred_status.status());
|
||||
ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}),
|
||||
inferred_status.ValueOrDie()));
|
||||
}
|
||||
|
||||
TEST_F(ShapeInferenceTest, InferCompareShapeLe) {
|
||||
auto ten_floats = ShapeUtil::MakeShape(F32, {10});
|
||||
auto inferred_status =
|
||||
ShapeInference::InferBinaryOpShape(HloOpcode::kLe, ten_floats, f32_, {});
|
||||
ASSERT_IS_OK(inferred_status.status());
|
||||
ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}),
|
||||
inferred_status.ValueOrDie()));
|
||||
}
|
||||
|
||||
TEST_F(ShapeInferenceTest, InferCompareShapeLt) {
|
||||
auto ten_floats = ShapeUtil::MakeShape(F32, {10});
|
||||
auto inferred_status =
|
||||
ShapeInference::InferBinaryOpShape(HloOpcode::kLt, ten_floats, f32_, {});
|
||||
ASSERT_IS_OK(inferred_status.status());
|
||||
ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}),
|
||||
inferred_status.ValueOrDie()));
|
||||
}
|
||||
|
||||
TEST_F(ShapeInferenceTest, InferCompareShapeNe) {
|
||||
auto ten_floats = ShapeUtil::MakeShape(F32, {10});
|
||||
auto inferred_status =
|
||||
ShapeInference::InferBinaryOpShape(HloOpcode::kNe, ten_floats, f32_, {});
|
||||
auto inferred_status = ShapeInference::InferBinaryOpShape(
|
||||
HloOpcode::kCompare, ten_floats, f32_, {});
|
||||
ASSERT_IS_OK(inferred_status.status());
|
||||
ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}),
|
||||
inferred_status.ValueOrDie()));
|
||||
|
@ -39,7 +39,7 @@ TEST_F(SortSimplifierTest, RemoveUnusedSortOperandArrayResult) {
|
||||
p.0.rhs = f32[] parameter(1)
|
||||
p.1.lhs = s32[] parameter(2)
|
||||
p.1.rhs = s32[] parameter(3)
|
||||
ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs)
|
||||
ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
|
||||
}
|
||||
|
||||
ENTRY sort_computation {
|
||||
@ -73,7 +73,7 @@ TEST_F(SortSimplifierTest, RemoveUnusedSortOperandTuple) {
|
||||
p.1.rhs = s32[] parameter(3)
|
||||
p.2.lhs = u32[] parameter(4)
|
||||
p.2.rhs = u32[] parameter(5)
|
||||
ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs)
|
||||
ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
|
||||
}
|
||||
|
||||
ENTRY sort_computation {
|
||||
@ -109,7 +109,7 @@ TEST_F(SortSimplifierTest, DontRemoveUnusedSortKey) {
|
||||
p.0.rhs = f32[] parameter(1)
|
||||
p.1.lhs = s32[] parameter(2)
|
||||
p.1.rhs = s32[] parameter(3)
|
||||
ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs)
|
||||
ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
|
||||
}
|
||||
|
||||
ENTRY sort_computation {
|
||||
@ -134,7 +134,7 @@ TEST_F(SortSimplifierTest, RemoveUnusedFirstOperand) {
|
||||
p.0.rhs = f32[] parameter(1)
|
||||
p.1.lhs = s32[] parameter(2)
|
||||
p.1.rhs = s32[] parameter(3)
|
||||
ROOT lt = pred[] less-than(p.1.lhs, p.1.rhs)
|
||||
ROOT lt = pred[] compare(p.1.lhs, p.1.rhs), direction=LT
|
||||
}
|
||||
|
||||
ENTRY sort_computation {
|
||||
|
@ -180,13 +180,13 @@ StatusOr<HloInstruction*> StableSortExpander::ExpandInstruction(
|
||||
CHECK_NE(cloned_root, nullptr);
|
||||
Shape scalar_pred = ShapeUtil::MakeShape(PRED, {});
|
||||
HloInstruction* same =
|
||||
comparator->AddInstruction(HloInstruction::CreateBinary(
|
||||
scalar_pred, HloOpcode::kEq, old_root, cloned_root));
|
||||
comparator->AddInstruction(HloInstruction::CreateCompare(
|
||||
scalar_pred, old_root, cloned_root, ComparisonDirection::kEq));
|
||||
HloInstruction* tie_breaker =
|
||||
comparator->AddInstruction(HloInstruction::CreateBinary(
|
||||
scalar_pred, HloOpcode::kLt,
|
||||
comparator->parameter_instruction(2 * iota_index),
|
||||
comparator->parameter_instruction(2 * iota_index + 1)));
|
||||
comparator->AddInstruction(HloInstruction::CreateCompare(
|
||||
scalar_pred, comparator->parameter_instruction(2 * iota_index),
|
||||
comparator->parameter_instruction(2 * iota_index + 1),
|
||||
ComparisonDirection::kLt));
|
||||
HloInstruction* new_root =
|
||||
comparator->AddInstruction(HloInstruction::CreateTernary(
|
||||
ShapeUtil::MakeShape(PRED, {}), HloOpcode::kSelect, same, tie_breaker,
|
||||
|
@ -65,7 +65,8 @@ void CheckComputationHasTieBreaker(const HloInstruction* root,
|
||||
// the copied comparison function where the parameters are reversed. Lt() is
|
||||
// the tie breaker comparison using the Iota operand.
|
||||
ASSERT_EQ(root->opcode(), HloOpcode::kSelect);
|
||||
ASSERT_EQ(root->operand(0)->opcode(), HloOpcode::kEq);
|
||||
ASSERT_EQ(root->operand(0)->opcode(), HloOpcode::kCompare);
|
||||
ASSERT_EQ(root->operand(0)->comparison_direction(), ComparisonDirection::kEq);
|
||||
|
||||
// Check that the tie breaker instruction is correct.
|
||||
EXPECT_THAT(root->operand(1),
|
||||
@ -88,7 +89,7 @@ TEST_F(StableSortExpanderTest, StabilizeSortReuseIotaOperand) {
|
||||
p.0.rhs = f32[] parameter(1)
|
||||
p.1.lhs = s32[] parameter(2)
|
||||
p.1.rhs = s32[] parameter(3)
|
||||
ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs)
|
||||
ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
|
||||
}
|
||||
|
||||
ENTRY sort_computation {
|
||||
@ -126,15 +127,15 @@ TEST_F(StableSortExpanderTest,
|
||||
lhs.unsigned = u32[] bitcast-convert(p.0.lhs)
|
||||
lhs.flipped = u32[] subtract(max, lhs.unsigned)
|
||||
lhs.flipped.signed = s32[] bitcast-convert(lhs.flipped)
|
||||
lhs.is_negative = pred[] less-than(lhs.flipped.signed, zero)
|
||||
lhs.is_negative = pred[] compare(lhs.flipped.signed, zero), direction=LT
|
||||
lhs.converted = s32[] select(lhs.is_negative, lhs.flipped.signed, lhs.signed)
|
||||
rhs.signed = s32[] bitcast-convert(p.0.rhs)
|
||||
rhs.unsigned = u32[] bitcast-convert(p.0.rhs)
|
||||
rhs.flipped = u32[] subtract(max, rhs.unsigned)
|
||||
rhs.flipped.signed = s32[] bitcast-convert(rhs.flipped)
|
||||
rhs.is_negative = pred[] less-than(rhs.flipped.signed, zero)
|
||||
rhs.is_negative = pred[] compare(rhs.flipped.signed, zero), direction=LT
|
||||
rhs.converted = s32[] select(rhs.is_negative, rhs.flipped.signed, rhs.signed)
|
||||
ROOT lt = pred[] less-than(lhs.converted, rhs.converted)
|
||||
ROOT lt = pred[] compare(lhs.converted, rhs.converted), direction=LT
|
||||
}
|
||||
|
||||
ENTRY sort_computation {
|
||||
@ -165,7 +166,7 @@ TEST_F(StableSortExpanderTest, StabilizeSortAddIotaOperandAndChangeRoot) {
|
||||
p.0.rhs = f32[] parameter(1)
|
||||
p.1.lhs = s32[] parameter(2)
|
||||
p.1.rhs = s32[] parameter(3)
|
||||
ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs)
|
||||
ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
|
||||
}
|
||||
|
||||
ENTRY sort_computation {
|
||||
@ -200,7 +201,7 @@ TEST_F(StableSortExpanderTest, HonorIsStableFlag) {
|
||||
p.0.rhs = f32[] parameter(1)
|
||||
p.1.lhs = s32[] parameter(2)
|
||||
p.1.rhs = s32[] parameter(3)
|
||||
ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs)
|
||||
ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
|
||||
}
|
||||
|
||||
ENTRY sort_computation {
|
||||
@ -227,7 +228,7 @@ TEST_F(StableSortExpanderTest,
|
||||
p.0.rhs = f32[] parameter(1)
|
||||
p.1.lhs = s32[] parameter(2)
|
||||
p.1.rhs = s32[] parameter(3)
|
||||
ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs)
|
||||
ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
|
||||
}
|
||||
|
||||
ENTRY sort_computation {
|
||||
@ -264,7 +265,7 @@ TEST_F(StableSortExpanderTest, StabilizeSortDontReuseIotaOperandWrongType) {
|
||||
p.0.rhs = f32[] parameter(1)
|
||||
p.1.lhs = f32[] parameter(2)
|
||||
p.1.rhs = f32[] parameter(3)
|
||||
ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs)
|
||||
ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
|
||||
}
|
||||
|
||||
ENTRY sort_computation {
|
||||
@ -302,7 +303,7 @@ TEST_F(StableSortExpanderTest, StabilizeSortR1) {
|
||||
mask = s32[] constant(65535)
|
||||
lhs = s32[] and(p.0.lhs, mask)
|
||||
rhs = s32[] and(p.0.rhs, mask)
|
||||
ROOT lt = pred[] less-than(lhs, rhs)
|
||||
ROOT lt = pred[] compare(lhs, rhs), direction=LT
|
||||
}
|
||||
|
||||
ENTRY sort_computation {
|
||||
@ -332,7 +333,7 @@ TEST_F(StableSortExpanderTest, StabilizeSortR1NoRoot) {
|
||||
mask = s32[] constant(65535)
|
||||
lhs = s32[] and(p.0.lhs, mask)
|
||||
rhs = s32[] and(p.0.rhs, mask)
|
||||
ROOT lt = pred[] less-than(lhs, rhs)
|
||||
ROOT lt = pred[] compare(lhs, rhs), direction=LT
|
||||
}
|
||||
|
||||
ENTRY sort_computation {
|
||||
|
@ -934,8 +934,8 @@ TEST_F(CanShareOperandBufferWithUserTest, ElementWiseDifferentShape) {
|
||||
HloInstruction::CreateParameter(0, in_shape, "param0"));
|
||||
auto param1 = builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(1, in_shape, "param1"));
|
||||
auto result = builder.AddInstruction(
|
||||
HloInstruction::CreateBinary(out_shape, HloOpcode::kEq, param0, param1));
|
||||
auto result = builder.AddInstruction(HloInstruction::CreateCompare(
|
||||
out_shape, param0, param1, ComparisonDirection::kEq));
|
||||
|
||||
BuildModuleAndRunAnalysis(builder.Build());
|
||||
|
||||
@ -1185,8 +1185,8 @@ TEST_F(CanShareOperandBufferWithUserTest, WhileCanShare) {
|
||||
auto builder = HloComputation::Builder(TestName() + ".Cond");
|
||||
auto data = builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(0, data_shape, "data"));
|
||||
builder.AddInstruction(HloInstruction::CreateBinary(
|
||||
ShapeUtil::MakeShape(PRED, {}), HloOpcode::kEq, data, data));
|
||||
builder.AddInstruction(HloInstruction::CreateCompare(
|
||||
ShapeUtil::MakeShape(PRED, {}), data, data, ComparisonDirection::kEq));
|
||||
return builder.Build();
|
||||
};
|
||||
|
||||
|
@ -286,7 +286,7 @@ static optional<int64> PatternMatchLoopTripCount(HloInstruction* while_op,
|
||||
// Handle `i = K; i < N; ++i`.
|
||||
if (Match(while_cond_root,
|
||||
m::Op()
|
||||
.WithOpcode(HloOpcode::kLt)
|
||||
.WithComparisonDirection(ComparisonDirection::kLt)
|
||||
.WithOperand(0, m::Op().Is(while_cond_indvar)))) {
|
||||
VLOG(2) << "Pattern-match succeeded: loop condition is i < N: "
|
||||
<< while_cond_root->ToString();
|
||||
@ -303,7 +303,7 @@ static optional<int64> PatternMatchLoopTripCount(HloInstruction* while_op,
|
||||
// Handle `i = K; i <= N; ++i`.
|
||||
if (Match(while_cond_root,
|
||||
m::Op()
|
||||
.WithOpcode(HloOpcode::kLe)
|
||||
.WithComparisonDirection(ComparisonDirection::kLe)
|
||||
.WithOperand(0, m::Op().Is(while_cond_indvar)))) {
|
||||
VLOG(2) << "Pattern-match succeeded: loop condition is i <= N: "
|
||||
<< while_cond_root->ToString();
|
||||
|
@ -40,7 +40,7 @@ TEST_F(WhileLoopAnalysisTest, SingleIterationUpperBound) {
|
||||
p_cond = (f32[2], s32[]) parameter(0)
|
||||
gte = s32[] get-tuple-element(p_cond), index=1
|
||||
const = s32[] constant(42)
|
||||
ROOT result = pred[] equal-to(gte, const)
|
||||
ROOT result = pred[] compare(gte, const), direction=EQ
|
||||
}
|
||||
|
||||
ENTRY entry {
|
||||
@ -71,7 +71,7 @@ TEST_F(WhileLoopAnalysisTest, NoUpperBound) {
|
||||
p_cond = (f32[2], s32[]) parameter(0)
|
||||
gte = s32[] get-tuple-element(p_cond), index=1
|
||||
const = s32[] constant(42)
|
||||
ROOT result = pred[] equal-to(gte, const)
|
||||
ROOT result = pred[] compare(gte, const), direction=EQ
|
||||
}
|
||||
|
||||
ENTRY entry {
|
||||
@ -104,7 +104,7 @@ TEST_F(WhileLoopAnalysisTest, ExactBound) {
|
||||
p_cond = (f32[2], s32[]) parameter(0)
|
||||
gte = s32[] get-tuple-element(p_cond), index=1
|
||||
const = s32[] constant(42)
|
||||
ROOT result = pred[] less-than(gte, const)
|
||||
ROOT result = pred[] compare(gte, const), direction=LT
|
||||
}
|
||||
|
||||
ENTRY entry {
|
||||
|
@ -260,7 +260,7 @@ condition {
|
||||
p_cond = (f32[],f32[]) parameter(0)
|
||||
p_cond.0 = f32[] get-tuple-element((f32[],f32[]) p_cond), index=0
|
||||
p_cond.1 = f32[] get-tuple-element((f32[],f32[]) p_cond), index=1
|
||||
ROOT result = pred[] less-than(p_cond.0, p_cond.1)
|
||||
ROOT result = pred[] compare(p_cond.0, p_cond.1), direction=LT
|
||||
}
|
||||
|
||||
ENTRY entry {
|
||||
@ -300,7 +300,7 @@ condition {
|
||||
p_c.0 = f32[] get-tuple-element((f32[],(f32[],f32[])) p_c), index=0
|
||||
p_c.1 = (f32[],f32[]) get-tuple-element((f32[],(f32[],f32[])) p_c), index=1
|
||||
p_c.1.1 = f32[] get-tuple-element((f32[],f32[]) p_c.1), index=1
|
||||
ROOT result = pred[] less-than(p_c.0, p_c.1.1)
|
||||
ROOT result = pred[] compare(p_c.0, p_c.1.1), direction=LT
|
||||
}
|
||||
|
||||
ENTRY entry {
|
||||
@ -342,7 +342,7 @@ condition {
|
||||
p_cond.0 = f32[] get-tuple-element((f32[],f32[],f32[]) p_cond), index=0
|
||||
p_cond.1 = f32[] get-tuple-element((f32[],f32[],f32[]) p_cond), index=1
|
||||
p_cond.2 = f32[] get-tuple-element((f32[],f32[],f32[]) p_cond), index=2
|
||||
ROOT result = pred[] less-than(p_cond.0, p_cond.1)
|
||||
ROOT result = pred[] compare(p_cond.0, p_cond.1), direction=LT
|
||||
}
|
||||
|
||||
ENTRY entry {
|
||||
@ -389,10 +389,10 @@ condition {
|
||||
p_cond = (f32[],f32[],f32[]) parameter(0)
|
||||
p_cond.0 = f32[] get-tuple-element((f32[],f32[],f32[]) p_cond), index=0
|
||||
p_cond.2 = f32[] get-tuple-element((f32[],f32[],f32[]) p_cond), index=2
|
||||
lt.0 = pred[] less-than(p_cond.0, p_cond.2)
|
||||
lt.0 = pred[] compare(p_cond.0, p_cond.2), direction=LT
|
||||
p_cond.1 = f32[] get-tuple-element((f32[],f32[],f32[]) p_cond), index=1
|
||||
p_cond.2.c = f32[] get-tuple-element((f32[],f32[],f32[]) p_cond), index=2
|
||||
lt.1 = pred[] less-than(p_cond.1, p_cond.2.c)
|
||||
lt.1 = pred[] compare(p_cond.1, p_cond.2.c), direction=LT
|
||||
ROOT result = pred[] and(lt.0, lt.1)
|
||||
}
|
||||
|
||||
|
@ -556,7 +556,7 @@ TEST_F(WhileLoopInvariantCodeMotionTest, DoNotHoistOutOfSingleIteration) {
|
||||
p_cond = (f32[2], f32[2], f32[2], s32[]) parameter(0)
|
||||
gte = s32[] get-tuple-element(p_cond), index=3
|
||||
const = s32[] constant(42)
|
||||
ROOT result = pred[] equal-to(gte, const)
|
||||
ROOT result = pred[] compare(gte, const), direction=EQ
|
||||
}
|
||||
|
||||
ENTRY entry {
|
||||
|
@ -72,7 +72,7 @@ WhileLoopSimplifierTest::MakeModuleWithSimpleLoop(int num_iters) {
|
||||
loop_var.2 = (s32[], s32[3]{0}) parameter(0)
|
||||
get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0
|
||||
constant.2 = s32[] constant({{LOOP_BOUND}})
|
||||
ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2)
|
||||
ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT
|
||||
}
|
||||
ENTRY SimpleLoop {
|
||||
constant.3 = s32[] constant(42)
|
||||
@ -107,7 +107,7 @@ WhileLoopSimplifierTest::MakeModuleWithSimpleLoopTupleElementLoopBound(
|
||||
loop_var.2 = (s32[], s32[3]{0}, s32[]) parameter(0)
|
||||
get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0
|
||||
get-tuple-element.4 = s32[] get-tuple-element(loop_var.2), index=2
|
||||
ROOT less-than = pred[] less-than(get-tuple-element.3, get-tuple-element.4)
|
||||
ROOT less-than = pred[] compare(get-tuple-element.3, get-tuple-element.4), direction=LT
|
||||
}
|
||||
ENTRY SimpleLoopWithIndirectLoopBound {
|
||||
constant.3 = s32[] constant(42)
|
||||
@ -237,7 +237,7 @@ TEST_F(WhileLoopSimplifierTest, NonTupleShapedLoopNotSimplified) {
|
||||
NonTupleShapedLoop.condition {
|
||||
loop_var = s32[] parameter(0)
|
||||
constant = s32[] constant(100)
|
||||
ROOT less-than = pred[] less-than(s32[] loop_var, s32[] constant)
|
||||
ROOT less-than = pred[] compare(s32[] loop_var, s32[] constant), direction=LT
|
||||
}
|
||||
ENTRY INonTupleShapedLoop {
|
||||
constant.2 = s32[] constant(42)
|
||||
@ -387,7 +387,7 @@ TEST_F(WhileLoopSimplifierTest, RemoveUnusedLoopOperands) {
|
||||
param0 = (s32[], s32[], s32[]) parameter(0)
|
||||
get-tuple-element = s32[] get-tuple-element((s32[], s32[], s32[]) param0),
|
||||
index=2
|
||||
ROOT equal-to = pred[] equal-to(s32[] constant.2, s32[] get-tuple-element)
|
||||
ROOT equal-to = pred[] compare(s32[] constant.2, s32[] get-tuple-element), direction=EQ
|
||||
}
|
||||
ENTRY RemoveUnusedOperands {
|
||||
x = s32[] parameter(0)
|
||||
@ -471,7 +471,7 @@ TEST_F(WhileLoopSimplifierTest,
|
||||
loop_var.2 = (s32[], s32[3]{0}) parameter(0)
|
||||
get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0
|
||||
constant.2 = s32[] constant(44)
|
||||
ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2)
|
||||
ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT
|
||||
}
|
||||
ENTRY SimpleLoop {
|
||||
constant.3 = s32[] constant(42)
|
||||
@ -503,7 +503,7 @@ TEST_F(WhileLoopSimplifierTest, LoopWithArrayConstantNotSimplified) {
|
||||
loop_var.2 = (s32[], s32[3]{0}, s32[3]{0}) parameter(0)
|
||||
get-tuple-element.4 = s32[] get-tuple-element(loop_var.2), index=0
|
||||
constant.2 = s32[] constant(47)
|
||||
ROOT less-than = pred[] less-than(get-tuple-element.4, constant.2)
|
||||
ROOT less-than = pred[] compare(get-tuple-element.4, constant.2), direction=LT
|
||||
}
|
||||
ENTRY SimpleLoop {
|
||||
constant.3 = s32[] constant(42)
|
||||
@ -679,7 +679,7 @@ const char* const kSimpleMergeInductionVariablesModule = R"(
|
||||
b = TYPE[] get-tuple-element(param), index=1
|
||||
sum = TYPE[] power(a, b)
|
||||
ten = TYPE[] constant(10)
|
||||
ROOT cond = pred[] less-than(sum, ten)
|
||||
ROOT cond = pred[] compare(sum, ten), direction=LT
|
||||
}
|
||||
ENTRY Loop {
|
||||
a = TYPE[] constant(10)
|
||||
|
@ -41,7 +41,7 @@ TEST_F(TripCountAnnotatorTest, KnownSmallTripCount) {
|
||||
param = (s32[]) parameter(0)
|
||||
i = s32[] get-tuple-element(param), index=0
|
||||
trip_count = s32[] constant(10)
|
||||
ROOT done = pred[] less-than(i, trip_count)
|
||||
ROOT done = pred[] compare(i, trip_count), direction=LT
|
||||
}
|
||||
|
||||
ENTRY test {
|
||||
@ -77,7 +77,7 @@ TEST_F(TripCountAnnotatorTest, KnownLargeTripCount) {
|
||||
param = (s32[]) parameter(0)
|
||||
i = s32[] get-tuple-element(param), index=0
|
||||
trip_count = s32[] constant(1000000)
|
||||
ROOT done = pred[] less-than(i, trip_count)
|
||||
ROOT done = pred[] compare(i, trip_count), direction=LT
|
||||
}
|
||||
|
||||
ENTRY test {
|
||||
@ -113,7 +113,7 @@ TEST_F(TripCountAnnotatorTest, NonzeroStart) {
|
||||
param = (s32[]) parameter(0)
|
||||
i = s32[] get-tuple-element(param), index=0
|
||||
trip_count = s32[] constant(1000000)
|
||||
ROOT done = pred[] less-than(i, trip_count)
|
||||
ROOT done = pred[] compare(i, trip_count), direction=LT
|
||||
}
|
||||
|
||||
ENTRY test {
|
||||
@ -149,7 +149,7 @@ TEST_F(TripCountAnnotatorTest, LessThanOrEqualTo) {
|
||||
param = (s32[]) parameter(0)
|
||||
i = s32[] get-tuple-element(param), index=0
|
||||
trip_count = s32[] constant(1000000)
|
||||
ROOT done = pred[] less-than-or-equal-to(i, trip_count)
|
||||
ROOT done = pred[] compare(i, trip_count), direction=LE
|
||||
}
|
||||
|
||||
ENTRY test {
|
||||
@ -188,7 +188,7 @@ TEST_F(TripCountAnnotatorTest, Int64Overflow) {
|
||||
param = (s64[]) parameter(0)
|
||||
i = s64[] get-tuple-element(param), index=0
|
||||
trip_count = s64[] constant(9223372036854775807) // 2^63-1
|
||||
ROOT done = pred[] less-than-or-equal-to(i, trip_count)
|
||||
ROOT done = pred[] compare(i, trip_count), direction=LE
|
||||
}
|
||||
|
||||
ENTRY test {
|
||||
|
@ -166,7 +166,7 @@ MakeCountedLoopConditionComputation(const Shape& loop_state_shape,
|
||||
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
HloInstruction * compare,
|
||||
MakeBinaryHlo(HloOpcode::kLt, indvar, trip_count_constant));
|
||||
MakeCompareHlo(ComparisonDirection::kLt, indvar, trip_count_constant));
|
||||
cond_computation->set_root_instruction(compare);
|
||||
return std::move(cond_computation);
|
||||
}
|
||||
|
@ -63,7 +63,11 @@ const float test_float_vals[3][test_width][test_height] = {
|
||||
class FusionTest : public HloTestBase {
|
||||
protected:
|
||||
template <typename T, int Arity>
|
||||
void TestElementwise2D(HloOpcode opcode) {
|
||||
void TestElementwise2D(
|
||||
HloOpcode opcode,
|
||||
absl::optional<ComparisonDirection> direction = absl::nullopt) {
|
||||
// Create a variable for comparisons since they require the direction.
|
||||
bool is_compare = std::is_same<T, bool>::value;
|
||||
Array2D<float> operand_data[Arity];
|
||||
for (int i = 0; i < Arity; ++i) {
|
||||
new (&operand_data[i]) Array2D<float>(test_width, test_height);
|
||||
@ -76,7 +80,11 @@ class FusionTest : public HloTestBase {
|
||||
xs[k] = test_float_vals[k][i][j];
|
||||
operand_data[k](i, j) = xs[k];
|
||||
}
|
||||
answer_data(i, j) = ComputeElementwiseAnswer<T>(opcode, xs);
|
||||
if (is_compare) {
|
||||
answer_data(i, j) = ComputeElementwiseAnswerCompare(*direction, xs);
|
||||
} else {
|
||||
answer_data(i, j) = ComputeElementwiseAnswerFloat(opcode, xs);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -98,8 +106,13 @@ class FusionTest : public HloTestBase {
|
||||
root_hlo = HloInstruction::CreateUnary(answer_shape, opcode, hlos[1]);
|
||||
break;
|
||||
case 2:
|
||||
root_hlo = HloInstruction::CreateBinary(answer_shape, opcode, hlos[1],
|
||||
hlos[2]);
|
||||
if (is_compare) {
|
||||
root_hlo = HloInstruction::CreateCompare(answer_shape, hlos[1],
|
||||
hlos[2], *direction);
|
||||
} else {
|
||||
root_hlo = HloInstruction::CreateBinary(answer_shape, opcode, hlos[1],
|
||||
hlos[2]);
|
||||
}
|
||||
break;
|
||||
case 3:
|
||||
root_hlo = HloInstruction::CreateTernary(answer_shape, opcode, hlos[1],
|
||||
@ -124,13 +137,14 @@ class FusionTest : public HloTestBase {
|
||||
}
|
||||
|
||||
private:
|
||||
template <typename T>
|
||||
T ComputeElementwiseAnswer(HloOpcode opcode, absl::Span<const float> xs);
|
||||
float ComputeElementwiseAnswerFloat(HloOpcode opcode,
|
||||
absl::Span<const float> xs);
|
||||
bool ComputeElementwiseAnswerCompare(ComparisonDirection direction,
|
||||
absl::Span<const float> xs);
|
||||
};
|
||||
|
||||
template <>
|
||||
float FusionTest::ComputeElementwiseAnswer<float>(HloOpcode opcode,
|
||||
absl::Span<const float> xs) {
|
||||
float FusionTest::ComputeElementwiseAnswerFloat(HloOpcode opcode,
|
||||
absl::Span<const float> xs) {
|
||||
switch (opcode) {
|
||||
case HloOpcode::kAdd:
|
||||
return xs[0] + xs[1];
|
||||
@ -153,24 +167,21 @@ float FusionTest::ComputeElementwiseAnswer<float>(HloOpcode opcode,
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
bool FusionTest::ComputeElementwiseAnswer<bool>(HloOpcode opcode,
|
||||
absl::Span<const float> xs) {
|
||||
switch (opcode) {
|
||||
case HloOpcode::kEq:
|
||||
bool FusionTest::ComputeElementwiseAnswerCompare(ComparisonDirection direction,
|
||||
absl::Span<const float> xs) {
|
||||
switch (direction) {
|
||||
case ComparisonDirection::kEq:
|
||||
return xs[0] == xs[1];
|
||||
case HloOpcode::kNe:
|
||||
case ComparisonDirection::kNe:
|
||||
return xs[0] != xs[1];
|
||||
case HloOpcode::kGt:
|
||||
case ComparisonDirection::kGt:
|
||||
return xs[0] > xs[1];
|
||||
case HloOpcode::kLt:
|
||||
case ComparisonDirection::kLt:
|
||||
return xs[0] < xs[1];
|
||||
case HloOpcode::kGe:
|
||||
case ComparisonDirection::kGe:
|
||||
return xs[0] >= xs[1];
|
||||
case HloOpcode::kLe:
|
||||
case ComparisonDirection::kLe:
|
||||
return xs[0] <= xs[1];
|
||||
default:
|
||||
LOG(FATAL) << "No comparatory opcode: " << opcode;
|
||||
}
|
||||
}
|
||||
|
||||
@ -740,24 +751,28 @@ XLA_TEST_F(FusionTest, Maximum2D) {
|
||||
TestElementwise2D<float, 2>(HloOpcode::kMaximum);
|
||||
}
|
||||
|
||||
XLA_TEST_F(FusionTest, Equal2D) { TestElementwise2D<bool, 2>(HloOpcode::kEq); }
|
||||
XLA_TEST_F(FusionTest, Equal2D) {
|
||||
TestElementwise2D<bool, 2>(HloOpcode::kCompare, ComparisonDirection::kEq);
|
||||
}
|
||||
|
||||
XLA_TEST_F(FusionTest, Inequal2D) {
|
||||
TestElementwise2D<bool, 2>(HloOpcode::kNe);
|
||||
TestElementwise2D<bool, 2>(HloOpcode::kCompare, ComparisonDirection::kNe);
|
||||
}
|
||||
|
||||
XLA_TEST_F(FusionTest, Greater2D) {
|
||||
TestElementwise2D<bool, 2>(HloOpcode::kGt);
|
||||
TestElementwise2D<bool, 2>(HloOpcode::kCompare, ComparisonDirection::kGt);
|
||||
}
|
||||
|
||||
XLA_TEST_F(FusionTest, Lesser2D) { TestElementwise2D<bool, 2>(HloOpcode::kLt); }
|
||||
XLA_TEST_F(FusionTest, Lesser2D) {
|
||||
TestElementwise2D<bool, 2>(HloOpcode::kCompare, ComparisonDirection::kLt);
|
||||
}
|
||||
|
||||
XLA_TEST_F(FusionTest, GreaterOrEqual2D) {
|
||||
TestElementwise2D<bool, 2>(HloOpcode::kGe);
|
||||
TestElementwise2D<bool, 2>(HloOpcode::kCompare, ComparisonDirection::kGe);
|
||||
}
|
||||
|
||||
XLA_TEST_F(FusionTest, LesserOrEqual2D) {
|
||||
TestElementwise2D<bool, 2>(HloOpcode::kLe);
|
||||
TestElementwise2D<bool, 2>(HloOpcode::kCompare, ComparisonDirection::kLe);
|
||||
}
|
||||
|
||||
XLA_TEST_F(FusionTest, Clamp2D) {
|
||||
|
@ -227,7 +227,7 @@ XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFusion) {
|
||||
fused_computation {
|
||||
p = f32[4] parameter(0)
|
||||
multiply = f32[4] multiply(p, p)
|
||||
less-than = pred[4] less-than(p, multiply)
|
||||
less-than = pred[4] compare(p, multiply), direction=LT
|
||||
ROOT tuple = (pred[4], f32[4]) tuple(less-than, multiply)
|
||||
}
|
||||
|
||||
@ -252,7 +252,7 @@ XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFeedingMap) {
|
||||
fused_computation {
|
||||
p = f32[] parameter(0)
|
||||
multiply = f32[] multiply(p, p)
|
||||
less-than = pred[] less-than(p, multiply)
|
||||
less-than = pred[] compare(p, multiply), direction=LT
|
||||
ROOT tuple = (pred[], f32[]) tuple(less-than, multiply)
|
||||
}
|
||||
|
||||
|
@ -143,7 +143,7 @@ compare {
|
||||
p.0.rhs = f32[] parameter(1)
|
||||
p.1.lhs = s32[] parameter(2)
|
||||
p.1.rhs = s32[] parameter(3)
|
||||
ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs)
|
||||
ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
|
||||
}
|
||||
|
||||
ENTRY %sort.148.1589 (parameter.0: f32[1048576], parameter.1: s32[1048576]) -> (f32[1048576], s32[1048576]) {
|
||||
@ -174,7 +174,7 @@ compare {
|
||||
p.0.rhs = s32[] parameter(1)
|
||||
p.1.lhs = s32[] parameter(2)
|
||||
p.1.rhs = s32[] parameter(3)
|
||||
ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs)
|
||||
ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
|
||||
}
|
||||
|
||||
ENTRY %sort.148.1589 (parameter.0: s32[1048576], parameter.1: s32[1048576]) -> (s32[1048576], s32[1048576]) {
|
||||
@ -205,7 +205,7 @@ compare {
|
||||
p.0.rhs = bf16[] parameter(1)
|
||||
p.1.lhs = s32[] parameter(2)
|
||||
p.1.rhs = s32[] parameter(3)
|
||||
ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs)
|
||||
ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
|
||||
}
|
||||
|
||||
ENTRY %sort. (parameter.0: bf16[2,1452], parameter.1: s32[2,1452]) -> (bf16[2,1452], s32[2,1452]) {
|
||||
|
@ -129,7 +129,7 @@ HloModule TokenInWhileLoop
|
||||
%param = (s32[], token[]) parameter(0)
|
||||
%get-tuple-element = s32[] get-tuple-element((s32[], token[]) %param), index=0
|
||||
%constant = s32[] constant(42)
|
||||
ROOT %less-than = pred[] less-than(s32[] %get-tuple-element, s32[] %constant)
|
||||
ROOT %less-than = pred[] compare(s32[] %get-tuple-element, s32[] %constant), direction=LT
|
||||
}
|
||||
|
||||
ENTRY %TokenInWhileLoop () -> s32[] {
|
||||
|
Loading…
Reference in New Issue
Block a user