From 5ca8c1bcb9c1a0d6f411269fc2300cd7e2d467a5 Mon Sep 17 00:00:00 2001 From: Yunxing Dai Date: Fri, 17 Jul 2020 11:49:21 -0700 Subject: [PATCH] Canonicalize dense array conditional into tuple conditional with one element. It's annoying do deal with the fact that conditional can be either tuple or non-tuple. Canonicalize everything into tuple. PiperOrigin-RevId: 321823386 Change-Id: I8bfe798bd1b4af9c3ffd169fa6b497c8b2f92b4a --- tensorflow/compiler/xla/service/BUILD | 33 +++++++++ .../xla/service/conditional_canonicalizer.cc | 60 ++++++++++++++++ .../xla/service/conditional_canonicalizer.h | 38 ++++++++++ .../service/conditional_canonicalizer_test.cc | 72 +++++++++++++++++++ tensorflow/compiler/xla/service/cpu/BUILD | 1 + .../compiler/xla/service/cpu/cpu_compiler.cc | 2 + tensorflow/compiler/xla/service/gpu/BUILD | 1 + .../compiler/xla/service/gpu/gpu_compiler.cc | 3 +- 8 files changed, 209 insertions(+), 1 deletion(-) create mode 100644 tensorflow/compiler/xla/service/conditional_canonicalizer.cc create mode 100644 tensorflow/compiler/xla/service/conditional_canonicalizer.h create mode 100644 tensorflow/compiler/xla/service/conditional_canonicalizer_test.cc diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 52cba3837dd..8d267affdd9 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -3932,6 +3932,39 @@ tf_cc_test( ], ) +cc_library( + name = "conditional_canonicalizer", + srcs = ["conditional_canonicalizer.cc"], + hdrs = ["conditional_canonicalizer.h"], + deps = [ + ":hlo", + ":hlo_pass", + "//tensorflow/compiler/xla:status_macros", + ], +) + +tf_cc_test( + name = "conditional_canonicalizer_test", + srcs = ["conditional_canonicalizer_test.cc"], + deps = [ + ":conditional_canonicalizer", + ":hlo", + ":hlo_matchers", + ":hlo_parser", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto_cc", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/compiler/xla/tests:test_utils", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + cc_library( name = "hlo_get_dimension_size_rewriter", srcs = ["hlo_get_dimension_size_rewriter.cc"], diff --git a/tensorflow/compiler/xla/service/conditional_canonicalizer.cc b/tensorflow/compiler/xla/service/conditional_canonicalizer.cc new file mode 100644 index 00000000000..3d917eb39fe --- /dev/null +++ b/tensorflow/compiler/xla/service/conditional_canonicalizer.cc @@ -0,0 +1,60 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/conditional_canonicalizer.h" + +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/status_macros.h" + +namespace xla { +namespace { +Status CanonicalizeNonTupleConditional(HloInstruction* conditional) { + TF_RET_CHECK(conditional->opcode() == HloOpcode::kConditional); + for (auto* branch : conditional->called_computations()) { + HloInstruction* root = branch->root_instruction(); + TF_RET_CHECK(!root->shape().IsTuple()); + + HloInstruction* tuple = + branch->AddInstruction(HloInstruction::CreateTuple({root})); + branch->set_root_instruction(tuple, /*accept_different_shape=*/true); + } + auto root_shape = conditional->shape(); + *conditional->mutable_shape() = ShapeUtil::MakeTupleShape({root_shape}); + auto gte = conditional->parent()->AddInstruction( + HloInstruction::CreateGetTupleElement(root_shape, conditional, 0)); + TF_RETURN_IF_ERROR(conditional->ReplaceAllUsesWithDifferentShape(gte)); + return Status::OK(); +} +} // namespace + +StatusOr ConditionalCanonicalizer::Run(HloModule* module) { + XLA_VLOG_LINES( + 2, "ConditionalCanonicalizer::Run(), before:\n" + module->ToString()); + bool changed = false; + for (auto* comp : module->MakeNonfusionComputations()) { + for (auto* inst : comp->MakeInstructionPostOrder()) { + if (inst->opcode() == HloOpcode::kConditional && + !inst->shape().IsTuple()) { + TF_RETURN_IF_ERROR(CanonicalizeNonTupleConditional(inst)); + changed = true; + } + } + } + XLA_VLOG_LINES( + 2, "ConditionalCanonicalizer::Run(), after:\n" + module->ToString()); + return changed; +} +} // namespace xla diff --git a/tensorflow/compiler/xla/service/conditional_canonicalizer.h b/tensorflow/compiler/xla/service/conditional_canonicalizer.h new file mode 100644 index 00000000000..a390d87a007 --- /dev/null +++ b/tensorflow/compiler/xla/service/conditional_canonicalizer.h @@ -0,0 +1,38 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CONDITIONAL_CANONICALIZER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CONDITIONAL_CANONICALIZER_H_ + +#include + +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" + +namespace xla { + +// Canonicalize output of conditionals, make non-tuple outputs into tuple with +// single element output. After this pass, all conditional instructions have +// tuple outputs. +class ConditionalCanonicalizer : public HloModulePass { + public: + absl::string_view name() const override { + return "conditional canonicalizer"; + } + + StatusOr Run(HloModule* module) override; +}; +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CONDITIONAL_CANONICALIZER_H_ diff --git a/tensorflow/compiler/xla/service/conditional_canonicalizer_test.cc b/tensorflow/compiler/xla/service/conditional_canonicalizer_test.cc new file mode 100644 index 00000000000..498260cbabf --- /dev/null +++ b/tensorflow/compiler/xla/service/conditional_canonicalizer_test.cc @@ -0,0 +1,72 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/conditional_canonicalizer.h" + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_utils.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace { + +namespace op = xla::testing::opcode_matchers; + +class ConditionalCanonicalizerTest : public HloTestBase { + protected: + ConditionalCanonicalizerTest() {} +}; + +TEST_F(ConditionalCanonicalizerTest, DenseArrayConditionalRewrite) { + auto module = ParseAndReturnVerifiedModule(R"( +HloModule _ +true_branch { + true_param = (s32[3,2]) parameter(0) + ROOT root = s32[] constant(0) +} + +false_branch { + false_param = (s32[3,2]) parameter(0) + ROOT root = s32[] constant(1) +} + +ENTRY entry { + param0 = s32[3,2] parameter(0) + branch = pred[] constant(false) + param_tuple = (s32[3 ,2]) tuple(param0) + ROOT conditional = s32[] conditional(branch, param_tuple, param_tuple), + true_computation=true_branch, false_computation=false_branch +} +)") + .ValueOrDie(); + ConditionalCanonicalizer pass; + EXPECT_TRUE(pass.Run(module.get()).ValueOrDie()); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::GetTupleElement(op::Conditional())); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index ac167b00bb3..782d08296f0 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -138,6 +138,7 @@ cc_library( "//tensorflow/compiler/xla/service:rng_bit_generator_expander", "//tensorflow/compiler/xla/service:tree_reduction_rewriter", "//tensorflow/compiler/xla/service:hlo_get_dimension_size_rewriter", + "//tensorflow/compiler/xla/service:conditional_canonicalizer", "//tensorflow/compiler/xla/service:conditional_to_select", "//tensorflow/compiler/xla/service:slow_operation_alarm", "//tensorflow/compiler/xla/service:scatter_expander", diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index 5464cfee082..04d703fdd59 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -54,6 +54,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/call_inliner.h" #include "tensorflow/compiler/xla/service/cholesky_expander.h" +#include "tensorflow/compiler/xla/service/conditional_canonicalizer.h" #include "tensorflow/compiler/xla/service/conditional_simplifier.h" #include "tensorflow/compiler/xla/service/conditional_to_select.h" #include "tensorflow/compiler/xla/service/convolution_group_converter.h" @@ -284,6 +285,7 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn( /*rewrite_grad_op=*/true); pipeline.AddPass( /*expansion_type=*/LogisticExpansionType::kExp); + pipeline.AddPass(); pipeline.AddPass(); pipeline.AddPass(); pipeline.AddPass(); diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index b22f258bac6..7b1d3e213ce 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -1168,6 +1168,7 @@ cc_library( "//tensorflow/compiler/xla/service:batchnorm_expander", "//tensorflow/compiler/xla/service:buffer_assignment", "//tensorflow/compiler/xla/service:call_inliner", + "//tensorflow/compiler/xla/service:conditional_canonicalizer", "//tensorflow/compiler/xla/service:conditional_simplifier", "//tensorflow/compiler/xla/service:convolution_4d_expander", "//tensorflow/compiler/xla/service:dot_decomposer", diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc index 3050e794f10..f2d29b5d11f 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc @@ -35,6 +35,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/batchnorm_expander.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/call_inliner.h" +#include "tensorflow/compiler/xla/service/conditional_canonicalizer.h" #include "tensorflow/compiler/xla/service/conditional_simplifier.h" #include "tensorflow/compiler/xla/service/convolution_4d_expander.h" #include "tensorflow/compiler/xla/service/dot_decomposer.h" @@ -179,7 +180,7 @@ Status GpuCompiler::OptimizeHloModule( pipeline.AddPass( /*expansion_type=*/LogisticExpansionType::kExp); - + pipeline.AddPass(); pipeline.AddPass(); {