Canonicalize dense array conditional into tuple conditional with one element.
It's annoying do deal with the fact that conditional can be either tuple or non-tuple. Canonicalize everything into tuple. PiperOrigin-RevId: 321823386 Change-Id: I8bfe798bd1b4af9c3ffd169fa6b497c8b2f92b4a
This commit is contained in:
parent
fdf1095dcd
commit
5ca8c1bcb9
@ -3932,6 +3932,39 @@ tf_cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "conditional_canonicalizer",
|
||||
srcs = ["conditional_canonicalizer.cc"],
|
||||
hdrs = ["conditional_canonicalizer.h"],
|
||||
deps = [
|
||||
":hlo",
|
||||
":hlo_pass",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "conditional_canonicalizer_test",
|
||||
srcs = ["conditional_canonicalizer_test.cc"],
|
||||
deps = [
|
||||
":conditional_canonicalizer",
|
||||
":hlo",
|
||||
":hlo_matchers",
|
||||
":hlo_parser",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:types",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||
"//tensorflow/compiler/xla/tests:literal_test_util",
|
||||
"//tensorflow/compiler/xla/tests:test_utils",
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:test",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "hlo_get_dimension_size_rewriter",
|
||||
srcs = ["hlo_get_dimension_size_rewriter.cc"],
|
||||
|
||||
60
tensorflow/compiler/xla/service/conditional_canonicalizer.cc
Normal file
60
tensorflow/compiler/xla/service/conditional_canonicalizer.cc
Normal file
@ -0,0 +1,60 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/service/conditional_canonicalizer.h"
|
||||
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
|
||||
namespace xla {
|
||||
namespace {
|
||||
Status CanonicalizeNonTupleConditional(HloInstruction* conditional) {
|
||||
TF_RET_CHECK(conditional->opcode() == HloOpcode::kConditional);
|
||||
for (auto* branch : conditional->called_computations()) {
|
||||
HloInstruction* root = branch->root_instruction();
|
||||
TF_RET_CHECK(!root->shape().IsTuple());
|
||||
|
||||
HloInstruction* tuple =
|
||||
branch->AddInstruction(HloInstruction::CreateTuple({root}));
|
||||
branch->set_root_instruction(tuple, /*accept_different_shape=*/true);
|
||||
}
|
||||
auto root_shape = conditional->shape();
|
||||
*conditional->mutable_shape() = ShapeUtil::MakeTupleShape({root_shape});
|
||||
auto gte = conditional->parent()->AddInstruction(
|
||||
HloInstruction::CreateGetTupleElement(root_shape, conditional, 0));
|
||||
TF_RETURN_IF_ERROR(conditional->ReplaceAllUsesWithDifferentShape(gte));
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace
|
||||
|
||||
StatusOr<bool> ConditionalCanonicalizer::Run(HloModule* module) {
|
||||
XLA_VLOG_LINES(
|
||||
2, "ConditionalCanonicalizer::Run(), before:\n" + module->ToString());
|
||||
bool changed = false;
|
||||
for (auto* comp : module->MakeNonfusionComputations()) {
|
||||
for (auto* inst : comp->MakeInstructionPostOrder()) {
|
||||
if (inst->opcode() == HloOpcode::kConditional &&
|
||||
!inst->shape().IsTuple()) {
|
||||
TF_RETURN_IF_ERROR(CanonicalizeNonTupleConditional(inst));
|
||||
changed = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
XLA_VLOG_LINES(
|
||||
2, "ConditionalCanonicalizer::Run(), after:\n" + module->ToString());
|
||||
return changed;
|
||||
}
|
||||
} // namespace xla
|
||||
38
tensorflow/compiler/xla/service/conditional_canonicalizer.h
Normal file
38
tensorflow/compiler/xla/service/conditional_canonicalizer.h
Normal file
@ -0,0 +1,38 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CONDITIONAL_CANONICALIZER_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_SERVICE_CONDITIONAL_CANONICALIZER_H_
|
||||
|
||||
#include <utility>
|
||||
|
||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
// Canonicalize output of conditionals, make non-tuple outputs into tuple with
|
||||
// single element output. After this pass, all conditional instructions have
|
||||
// tuple outputs.
|
||||
class ConditionalCanonicalizer : public HloModulePass {
|
||||
public:
|
||||
absl::string_view name() const override {
|
||||
return "conditional canonicalizer";
|
||||
}
|
||||
|
||||
StatusOr<bool> Run(HloModule* module) override;
|
||||
};
|
||||
} // namespace xla
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CONDITIONAL_CANONICALIZER_H_
|
||||
@ -0,0 +1,72 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/service/conditional_canonicalizer.h"
|
||||
|
||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_parser.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
||||
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
|
||||
#include "tensorflow/compiler/xla/tests/test_utils.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace xla {
|
||||
namespace {
|
||||
|
||||
namespace op = xla::testing::opcode_matchers;
|
||||
|
||||
class ConditionalCanonicalizerTest : public HloTestBase {
|
||||
protected:
|
||||
ConditionalCanonicalizerTest() {}
|
||||
};
|
||||
|
||||
TEST_F(ConditionalCanonicalizerTest, DenseArrayConditionalRewrite) {
|
||||
auto module = ParseAndReturnVerifiedModule(R"(
|
||||
HloModule _
|
||||
true_branch {
|
||||
true_param = (s32[3,2]) parameter(0)
|
||||
ROOT root = s32[] constant(0)
|
||||
}
|
||||
|
||||
false_branch {
|
||||
false_param = (s32[3,2]) parameter(0)
|
||||
ROOT root = s32[] constant(1)
|
||||
}
|
||||
|
||||
ENTRY entry {
|
||||
param0 = s32[3,2] parameter(0)
|
||||
branch = pred[] constant(false)
|
||||
param_tuple = (s32[3 ,2]) tuple(param0)
|
||||
ROOT conditional = s32[] conditional(branch, param_tuple, param_tuple),
|
||||
true_computation=true_branch, false_computation=false_branch
|
||||
}
|
||||
)")
|
||||
.ValueOrDie();
|
||||
ConditionalCanonicalizer pass;
|
||||
EXPECT_TRUE(pass.Run(module.get()).ValueOrDie());
|
||||
EXPECT_THAT(module->entry_computation()->root_instruction(),
|
||||
op::GetTupleElement(op::Conditional()));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
||||
@ -138,6 +138,7 @@ cc_library(
|
||||
"//tensorflow/compiler/xla/service:rng_bit_generator_expander",
|
||||
"//tensorflow/compiler/xla/service:tree_reduction_rewriter",
|
||||
"//tensorflow/compiler/xla/service:hlo_get_dimension_size_rewriter",
|
||||
"//tensorflow/compiler/xla/service:conditional_canonicalizer",
|
||||
"//tensorflow/compiler/xla/service:conditional_to_select",
|
||||
"//tensorflow/compiler/xla/service:slow_operation_alarm",
|
||||
"//tensorflow/compiler/xla/service:scatter_expander",
|
||||
|
||||
@ -54,6 +54,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
|
||||
#include "tensorflow/compiler/xla/service/call_inliner.h"
|
||||
#include "tensorflow/compiler/xla/service/cholesky_expander.h"
|
||||
#include "tensorflow/compiler/xla/service/conditional_canonicalizer.h"
|
||||
#include "tensorflow/compiler/xla/service/conditional_simplifier.h"
|
||||
#include "tensorflow/compiler/xla/service/conditional_to_select.h"
|
||||
#include "tensorflow/compiler/xla/service/convolution_group_converter.h"
|
||||
@ -284,6 +285,7 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn(
|
||||
/*rewrite_grad_op=*/true);
|
||||
pipeline.AddPass<LogisticExpander>(
|
||||
/*expansion_type=*/LogisticExpansionType::kExp);
|
||||
pipeline.AddPass<ConditionalCanonicalizer>();
|
||||
pipeline.AddPass<DynamicPadder>();
|
||||
pipeline.AddPass<ScatterExpander>();
|
||||
pipeline.AddPass<HloGetDimensionSizeRewriter>();
|
||||
|
||||
@ -1168,6 +1168,7 @@ cc_library(
|
||||
"//tensorflow/compiler/xla/service:batchnorm_expander",
|
||||
"//tensorflow/compiler/xla/service:buffer_assignment",
|
||||
"//tensorflow/compiler/xla/service:call_inliner",
|
||||
"//tensorflow/compiler/xla/service:conditional_canonicalizer",
|
||||
"//tensorflow/compiler/xla/service:conditional_simplifier",
|
||||
"//tensorflow/compiler/xla/service:convolution_4d_expander",
|
||||
"//tensorflow/compiler/xla/service:dot_decomposer",
|
||||
|
||||
@ -35,6 +35,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/batchnorm_expander.h"
|
||||
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
|
||||
#include "tensorflow/compiler/xla/service/call_inliner.h"
|
||||
#include "tensorflow/compiler/xla/service/conditional_canonicalizer.h"
|
||||
#include "tensorflow/compiler/xla/service/conditional_simplifier.h"
|
||||
#include "tensorflow/compiler/xla/service/convolution_4d_expander.h"
|
||||
#include "tensorflow/compiler/xla/service/dot_decomposer.h"
|
||||
@ -179,7 +180,7 @@ Status GpuCompiler::OptimizeHloModule(
|
||||
|
||||
pipeline.AddPass<LogisticExpander>(
|
||||
/*expansion_type=*/LogisticExpansionType::kExp);
|
||||
|
||||
pipeline.AddPass<ConditionalCanonicalizer>();
|
||||
pipeline.AddPass<DynamicPadder>();
|
||||
|
||||
{
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user