[XLA] Add token support to conditional_simplifier

We tried to use kSelect to choose between two tokens. kSelect does not support TOKEN element types.

Instead, use kAfterAll to select between the tokens. This is OK because tokens don't actually have a value associated with them, they just try to enforce an ordering.

PiperOrigin-RevId: 314360456
Change-Id: Ia13cc66113c6cf6f1989848eb39d1b90c8674988
This commit is contained in:
David Majnemer 2020-06-02 10:44:06 -07:00 committed by TensorFlower Gardener
parent 41aea74ea2
commit f653ab8bb3
3 changed files with 40 additions and 0 deletions

View File

@ -2212,6 +2212,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/stream_executor/lib",
],
)

View File

@ -142,6 +142,10 @@ StatusOr<bool> TryRemoveConditional(HloInstruction* conditional) {
};
std::function<HloInstruction*(HloInstruction*, HloInstruction*)> select =
[&](HloInstruction* t, HloInstruction* f) {
if (f->shape().IsToken()) {
return computation->AddInstruction(
HloInstruction::CreateAfterAll({t, f}));
}
if (f->shape().IsArray()) {
return computation->AddInstruction(HloInstruction::CreateTernary(
f->shape(), HloOpcode::kSelect, condition_broadcast(f->shape()),

View File

@ -31,6 +31,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/stream_executor/lib/statusor.h"
namespace xla {
namespace {
@ -482,6 +483,40 @@ ENTRY main {
EXPECT_EQ(gte_1->tuple_index(), 0);
}
// Since select can only be used on arrays, use after-all for token types.
TEST_F(ConditionalSimplifierTest, SimplifyConditionalWithTokens) {
absl::string_view hlo_string =
R"(
HloModule SimplifyConditionalWithTokens
true_comp {
ROOT parameter.13 = (token[]) parameter(0)
}
false_comp {
ROOT parameter.21 = (token[]) parameter(0)
}
ENTRY entry {
parameter.29 = pred[] parameter(0)
token.1 = token[] after-all()
token.2 = token[] after-all()
tuple.3 = (token[]) tuple(token.1)
tuple.4 = (token[]) tuple(token.2)
ROOT conditional.5 = (token[]) conditional(parameter.29, tuple.3, tuple.4), true_computation=true_comp, false_computation=false_comp
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
ParseAndReturnVerifiedModule(hlo_string));
HloVerifier v(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false);
TF_ASSERT_OK(v.Run(module.get()).status());
EXPECT_TRUE(ConditionalSimplifier().Run(module.get()).ValueOrDie());
EXPECT_THAT(module->entry_computation()->root_instruction(),
op::Tuple(op::AfterAll(
op::GetTupleElement(op::Tuple(op::AfterAll()), 0),
op::GetTupleElement(op::Tuple(op::AfterAll()), 0))));
}
} // namespace
} // namespace xla