[XLA] Add token support to conditional_simplifier
We tried to use kSelect to choose between two tokens. kSelect does not support TOKEN element types. Instead, use kAfterAll to select between the tokens. This is OK because tokens don't actually have a value associated with them, they just try to enforce an ordering. PiperOrigin-RevId: 314360456 Change-Id: Ia13cc66113c6cf6f1989848eb39d1b90c8674988
This commit is contained in:
parent
41aea74ea2
commit
f653ab8bb3
@ -2212,6 +2212,7 @@ tf_cc_test(
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/stream_executor/lib",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -142,6 +142,10 @@ StatusOr<bool> TryRemoveConditional(HloInstruction* conditional) {
|
||||
};
|
||||
std::function<HloInstruction*(HloInstruction*, HloInstruction*)> select =
|
||||
[&](HloInstruction* t, HloInstruction* f) {
|
||||
if (f->shape().IsToken()) {
|
||||
return computation->AddInstruction(
|
||||
HloInstruction::CreateAfterAll({t, f}));
|
||||
}
|
||||
if (f->shape().IsArray()) {
|
||||
return computation->AddInstruction(HloInstruction::CreateTernary(
|
||||
f->shape(), HloOpcode::kSelect, condition_broadcast(f->shape()),
|
||||
|
@ -31,6 +31,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/stream_executor/lib/statusor.h"
|
||||
|
||||
namespace xla {
|
||||
namespace {
|
||||
@ -482,6 +483,40 @@ ENTRY main {
|
||||
EXPECT_EQ(gte_1->tuple_index(), 0);
|
||||
}
|
||||
|
||||
// Since select can only be used on arrays, use after-all for token types.
|
||||
TEST_F(ConditionalSimplifierTest, SimplifyConditionalWithTokens) {
|
||||
absl::string_view hlo_string =
|
||||
R"(
|
||||
HloModule SimplifyConditionalWithTokens
|
||||
|
||||
true_comp {
|
||||
ROOT parameter.13 = (token[]) parameter(0)
|
||||
}
|
||||
|
||||
false_comp {
|
||||
ROOT parameter.21 = (token[]) parameter(0)
|
||||
}
|
||||
|
||||
ENTRY entry {
|
||||
parameter.29 = pred[] parameter(0)
|
||||
token.1 = token[] after-all()
|
||||
token.2 = token[] after-all()
|
||||
tuple.3 = (token[]) tuple(token.1)
|
||||
tuple.4 = (token[]) tuple(token.2)
|
||||
ROOT conditional.5 = (token[]) conditional(parameter.29, tuple.3, tuple.4), true_computation=true_comp, false_computation=false_comp
|
||||
}
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
|
||||
ParseAndReturnVerifiedModule(hlo_string));
|
||||
HloVerifier v(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false);
|
||||
TF_ASSERT_OK(v.Run(module.get()).status());
|
||||
EXPECT_TRUE(ConditionalSimplifier().Run(module.get()).ValueOrDie());
|
||||
EXPECT_THAT(module->entry_computation()->root_instruction(),
|
||||
op::Tuple(op::AfterAll(
|
||||
op::GetTupleElement(op::Tuple(op::AfterAll()), 0),
|
||||
op::GetTupleElement(op::Tuple(op::AfterAll()), 0))));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
} // namespace xla
|
||||
|
Loading…
x
Reference in New Issue
Block a user