[XLA:GPU] Split reduce ops with large but non-consecutive reduction dimensions.

PiperOrigin-RevId: 317330616
Change-Id: Icdcf320b233479c2f74c5b40ee4c8d9a73a6088a
This commit is contained in:
Thomas Joerg 2020-06-19 10:18:32 -07:00 committed by TensorFlower Gardener
parent 57f9d638c0
commit ef1cabc7a8
7 changed files with 338 additions and 0 deletions

View File

@ -1174,6 +1174,7 @@ cc_library(
":reduction_degenerate_dim_remover",
":reduction_dimension_grouper",
":reduction_layout_normalizer",
":reduction_splitter",
":stream_assignment",
":stream_executor_util",
":target_constants",
@ -1819,6 +1820,33 @@ cc_library(
],
)
cc_library(
name = "reduction_splitter",
srcs = ["reduction_splitter.cc"],
hdrs = ["reduction_splitter.h"],
deps = [
":ir_emission_utils",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_pass",
],
)
tf_cc_test(
name = "reduction_splitter_test",
srcs = ["reduction_splitter_test.cc"],
deps = [
":reduction_splitter",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla/service:hlo_matchers",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
],
)
cc_library(
name = "reduction_layout_normalizer",
srcs = ["reduction_layout_normalizer.cc"],

View File

@ -65,6 +65,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/reduction_degenerate_dim_remover.h"
#include "tensorflow/compiler/xla/service/gpu/reduction_dimension_grouper.h"
#include "tensorflow/compiler/xla/service/gpu/reduction_layout_normalizer.h"
#include "tensorflow/compiler/xla/service/gpu/reduction_splitter.h"
#include "tensorflow/compiler/xla/service/gpu/stream_assignment.h"
#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"
#include "tensorflow/compiler/xla/service/gpu/target_constants.h"
@ -371,6 +372,7 @@ Status GpuCompiler::OptimizeHloPostLayoutAssignment(
pipeline.AddPass<ReductionDegenerateDimRemover>();
pipeline.AddPass<ReductionLayoutNormalizer>();
pipeline.AddPass<ReductionDimensionGrouper>();
pipeline.AddPass<HloPassFix<ReductionSplitter>>();
// The LayoutAssignment pass may leave behind kCopy instructions which are
// duplicate or NOPs, so remove them with algebraic simplification and CSE.

View File

@ -0,0 +1,117 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/gpu/reduction_splitter.h"
#include <algorithm>
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/shape_util.h"
namespace xla {
namespace gpu {
class ReductionSplitterVisitor : public DfsHloRewriteVisitor {
public:
Status HandleReduce(HloInstruction *reduce) override {
VLOG(4) << "Input: " << reduce->ToString();
// Reductions with contiguous dimensions are lowered to efficient code. No
// need to split such ops.
if (IsReductionFromOrToContiguousDimensions(*reduce)) {
return Status::OK();
}
if (reduce->dimensions().size() < 2) {
return Status::OK();
}
if (!reduce->shape().IsArray()) {
// TODO(cheshire): Handle variadic reduction.
return Status::OK();
}
HloInstruction *operand = reduce->mutable_operand(0);
const Shape &shape = operand->shape();
CHECK(shape == LayoutUtil::GetWithDefaultLayout(shape))
<< "Default layout should be enforced on reduction operand";
// Verify that contiguous dimensions have been grouped by the
// ReductionDimensionGrouper pass.
for (int64 i = 0; i < reduce->dimensions().size(); ++i) {
for (int64 j = i + 1; j < reduce->dimensions().size(); ++j) {
CHECK(abs(reduce->dimensions(i) - reduce->dimensions(j)) > 1)
<< "Reduction dimensions must not be consecutive";
}
}
// The reduce op has non-contiguous dimensions. Look for the dimension with
// the largest shape dimension. Reducing along this dimension first will
// reduce the output size most effectively.
int64 max_shape_dim = 0;
int64 max_reduce_dim = 0;
const auto &input_shape = reduce->operand(0)->shape();
for (int64 i = 0; i < reduce->dimensions().size(); ++i) {
if (input_shape.dimensions(reduce->dimensions(i)) > max_shape_dim) {
max_reduce_dim = reduce->dimensions(i);
max_shape_dim = input_shape.dimensions(max_reduce_dim);
}
}
// TODO(tjoerg): Run microbenchmarks to tune this threshold.
if (max_shape_dim < 128) {
return Status::OK();
}
// Split the reduction into a pre-reduction and a final reduction.
VLOG(3) << "Splitting reduction " << reduce->name() << " at dimension "
<< max_reduce_dim;
std::vector<int64> pre_reduce_dims;
pre_reduce_dims.push_back(max_reduce_dim);
std::vector<int64> pre_reduce_shape_dims(input_shape.dimensions().begin(),
input_shape.dimensions().end());
pre_reduce_shape_dims.erase(pre_reduce_shape_dims.begin() + max_reduce_dim);
Shape pre_reduce_shape = ShapeUtil::MakeShape(
reduce->shape().element_type(), pre_reduce_shape_dims);
std::unique_ptr<HloInstruction> pre_reduce = HloInstruction::CreateReduce(
pre_reduce_shape, reduce->mutable_operand(0),
reduce->mutable_operand(1), pre_reduce_dims, reduce->to_apply());
pre_reduce->set_metadata(reduce->metadata());
std::vector<int64> final_reduce_dims(reduce->dimensions().begin(),
reduce->dimensions().end());
final_reduce_dims.erase(
std::remove(final_reduce_dims.begin(), final_reduce_dims.end(),
max_reduce_dim),
final_reduce_dims.end());
for (int64 i = 0; i < final_reduce_dims.size(); ++i) {
if (final_reduce_dims[i] > max_reduce_dim) {
final_reduce_dims[i]--;
}
}
std::unique_ptr<HloInstruction> final_reduce = HloInstruction::CreateReduce(
reduce->shape(),
reduce->parent()->AddInstruction(std::move(pre_reduce)),
reduce->mutable_operand(1), final_reduce_dims, reduce->to_apply());
return ReplaceWithNewInstruction(reduce, std::move(final_reduce));
}
};
StatusOr<bool> ReductionSplitter::Run(HloModule *module) {
TF_ASSIGN_OR_RETURN(bool changed,
ReductionSplitterVisitor().RunOnModule(module));
return changed;
}
} // namespace gpu
} // namespace xla

View File

@ -0,0 +1,49 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_REDUCTION_SPLITTER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_REDUCTION_SPLITTER_H_
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
namespace xla {
namespace gpu {
// Splits a reduce op into two consecutive reduce ops if
// * the reduce dimensions are not contiguous and
// * at least one reduce dimension is large (i.e. corresponds to a large input
// shape dimension).
//
// Reductions with non-contiguous dimensions are emitted as simple element-wise
// loops. This is inefficient when reducing large input shape dimensions.
// Splitting such reductions allows using more efficient reduction emitters.
//
// This pass splits reduce ops into two consecutive reduce ops. Run it to a
// fixpoint to split reduce ops along multiple large dimensions.
//
// Precondition: ReductionDimensionGrouper has been run and adjacent reduce
// dimentsions have been grouped. Reduction layouts have been normalized.
class ReductionSplitter : public HloModulePass {
public:
absl::string_view name() const override { return "reduction-splitter"; }
StatusOr<bool> Run(HloModule* module) override;
};
} // namespace gpu
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_REDUCTION_SPLITTER_H_

View File

@ -0,0 +1,140 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/gpu/reduction_splitter.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
namespace xla {
namespace gpu {
namespace {
namespace op = xla::testing::opcode_matchers;
class ReductionSplitterTest : public HloTestBase {};
TEST_F(ReductionSplitterTest, SplitReductionAtDimensionTwo) {
auto module = ParseAndReturnVerifiedModule(R"(
HloModule test
add_computation {
x = f32[] parameter(0)
y = f32[] parameter(1)
ROOT add = f32[] add(x, y)
}
ENTRY entry_computation {
param_0 = f16[6,16,512,64]{3,2,1,0} parameter(0)
transpose.1781 = f16[6,512,16,64]{3,1,2,0} transpose(param_0), dimensions={0,2,1,3}
convert.6986 = f32[6,512,16,64]{3,1,2,0} convert(transpose.1781)
bitcast.2136 = f32[6,16,512,64]{3,2,1,0} bitcast(convert.6986)
constant_11111 = f32[] constant(0)
ROOT reduce.982 = f32[16,64]{1,0} reduce(bitcast.2136, constant_11111), dimensions={0,2}, to_apply=add_computation
}
)")
.ValueOrDie();
ASSERT_TRUE(ReductionSplitter().Run(module.get()).ValueOrDie());
SCOPED_TRACE(module->ToString());
const HloInstruction* root_reduction =
module->entry_computation()->root_instruction();
ASSERT_THAT(root_reduction, op::Reduce(op::Reduce(), op::Constant()));
auto* pre_reduction = root_reduction->operand(0);
EXPECT_THAT(pre_reduction->dimensions(), std::vector<int64>({2}));
EXPECT_THAT(pre_reduction->shape(), ShapeUtil::MakeShape(F32, {6, 16, 64}));
EXPECT_THAT(root_reduction->dimensions(), std::vector<int64>({0}));
EXPECT_THAT(root_reduction->shape(), ShapeUtil::MakeShape(F32, {16, 64}));
}
TEST_F(ReductionSplitterTest, SplitReductionAtDimensionZero) {
auto module = ParseAndReturnVerifiedModule(R"(
HloModule test
add_computation {
x = f32[] parameter(0)
y = f32[] parameter(1)
ROOT add = f32[] add(x, y)
}
ENTRY entry_computation {
param_0 = f32[1024,16,512,64,128]{4,3,2,1,0} parameter(0)
constant_11111 = f32[] constant(0)
ROOT reduce.982 = f32[16,64]{1,0} reduce(param_0, constant_11111), dimensions={2,0,4}, to_apply=add_computation
}
)")
.ValueOrDie();
ASSERT_TRUE(ReductionSplitter().Run(module.get()).ValueOrDie());
SCOPED_TRACE(module->ToString());
const HloInstruction* root_reduction =
module->entry_computation()->root_instruction();
ASSERT_THAT(root_reduction, op::Reduce(op::Reduce(), op::Constant()));
auto* pre_reduction = root_reduction->operand(0);
EXPECT_THAT(pre_reduction->dimensions(), std::vector<int64>({0}));
EXPECT_THAT(pre_reduction->shape(),
ShapeUtil::MakeShape(F32, {16, 512, 64, 128}));
EXPECT_THAT(root_reduction->dimensions(), std::vector<int64>({1, 3}));
EXPECT_THAT(root_reduction->shape(), ShapeUtil::MakeShape(F32, {16, 64}));
}
TEST_F(ReductionSplitterTest, DontSplitReductionWithSmallDimensions) {
auto module = ParseAndReturnVerifiedModule(R"(
HloModule test
add_computation {
x = f32[] parameter(0)
y = f32[] parameter(1)
ROOT add = f32[] add(x, y)
}
ENTRY entry_computation {
param_0 = f32[8,1024,8]{2,1,0} parameter(0)
constant_11111 = f32[] constant(0)
ROOT reduce.982 = f32[1024]{0} reduce(param_0, constant_11111), dimensions={2,0}, to_apply=add_computation
}
)")
.ValueOrDie();
EXPECT_FALSE(ReductionSplitter().Run(module.get()).ValueOrDie());
}
TEST_F(ReductionSplitterTest, DontSplitReductionsWithContiguousDimensions) {
auto module = ParseAndReturnVerifiedModule(R"(
HloModule test
add_computation {
x = f32[] parameter(0)
y = f32[] parameter(1)
ROOT add = f32[] add(x, y)
}
ENTRY entry_computation {
param_0 = f32[128,128,64,128]{3,2,1,0} parameter(0)
constant_11111 = f32[] constant(0)
// The dimenstions to keep (1 and 2) are contiguous.
ROOT reduce.982 = f32[128,64]{1,0} reduce(param_0, constant_11111), dimensions={3,0}, to_apply=add_computation
}
)")
.ValueOrDie();
EXPECT_FALSE(ReductionSplitter().Run(module.get()).ValueOrDie());
}
} // namespace
} // namespace gpu
} // namespace xla

View File

@ -37,6 +37,7 @@ class ReductionDegenerateDimRemoverTest : public GpuCodegenTest {
DebugOptions debug_options = GpuCodegenTest::GetDebugOptionsForTest();
debug_options.add_xla_disable_hlo_passes("reduction-layout-normalizer");
debug_options.add_xla_disable_hlo_passes("reduction-dimension-grouper");
debug_options.add_xla_disable_hlo_passes("reduction-splitter");
debug_options.add_xla_disable_hlo_passes("gpu-tree-reduction-rewriter");
return debug_options;
}

View File

@ -33,6 +33,7 @@ class ReductionLayoutNormalizerTest : public GpuCodegenTest {
DebugOptions GetDebugOptionsForTest() override {
DebugOptions debug_options = GpuCodegenTest::GetDebugOptionsForTest();
debug_options.add_xla_disable_hlo_passes("reduction-dimension-grouper");
debug_options.add_xla_disable_hlo_passes("reduction-splitter");
debug_options.add_xla_disable_hlo_passes("layout-assignment");
debug_options.add_xla_disable_hlo_passes("gpu-tree-reduction-rewriter");
return debug_options;