[XLA] Move sharding propagation to third party

This also moves some utilities for interpreting convolutions as dots.

PiperOrigin-RevId: 312868839
Change-Id: I90bdc30217edf6dfb301a9c80b7155653391fa1a
Author: Yuanzhong Xu, 2020-05-22 18:14:56 -07:00; committed by TensorFlower Gardener
Parent: 18aaa18cf1
Commit: f460141434
8 changed files with 3166 additions and 0 deletions

tensorflow/compiler/xla/service/BUILD

@@ -491,6 +491,66 @@ tf_cc_test(
],
)
cc_library(
name = "sharding_propagation",
srcs = [
"sharding_propagation.cc",
],
hdrs = [
"sharding_propagation.h",
],
deps = [
":dot_as_convolution_util",
":hlo",
":hlo_graph_dumper",
":hlo_pass",
":hlo_sharding_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/core:lib",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
],
)
tf_cc_test(
name = "sharding_propagation_test",
srcs = [
"sharding_propagation_test.cc",
],
deps = [
"hlo_matchers",
":hlo_parser",
":sharding_propagation",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
],
)
cc_library(
name = "dot_as_convolution_util",
srcs = [
"dot_as_convolution_util.cc",
],
hdrs = [
"dot_as_convolution_util.h",
],
deps = [
":hlo",
":shape_inference",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:xla_data_proto_cc",
"@com_google_absl//absl/types:optional",
],
)
tf_cc_test(
name = "dynamic_parameter_binding_test",
srcs = ["dynamic_parameter_binding_test.cc"],

tensorflow/compiler/xla/service/dot_as_convolution_util.cc

@@ -0,0 +1,139 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/dot_as_convolution_util.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/status_macros.h"
namespace xla {
namespace dot_as_convolution_util {
/* static */ absl::optional<DotGeneralAsConvolutionDimsInfo>
ParseDotGeneralFromConvolution(const HloInstruction* conv) {
CHECK_EQ(conv->opcode(), HloOpcode::kConvolution);
if (conv->feature_group_count() != 1 || conv->batch_group_count() != 1) {
return absl::nullopt;
}
const auto& conv_dims = conv->convolution_dimension_numbers();
DotGeneralAsConvolutionDimsInfo dims;
dims.lhs_non_contracting_dims.push_back(
{conv_dims.input_batch_dimension(), -1,
conv_dims.output_batch_dimension(), -1});
dims.rhs_non_contracting_dims.push_back(
{-1, conv_dims.kernel_output_feature_dimension(),
conv_dims.output_feature_dimension(), -1});
dims.contracting_dims.push_back({conv_dims.input_feature_dimension(),
conv_dims.kernel_input_feature_dimension(),
-1, -1});
for (int64 i = 0; i < conv_dims.input_spatial_dimensions_size(); ++i) {
int64 lhs = conv_dims.input_spatial_dimensions(i);
int64 lhs_size = conv->operand(0)->shape().dimensions(lhs);
int64 rhs = conv_dims.kernel_spatial_dimensions(i);
int64 rhs_size = conv->operand(1)->shape().dimensions(rhs);
int64 output = conv_dims.output_spatial_dimensions(i);
const auto& wd = conv->window().dimensions(i);
if (lhs_size == wd.size() &&
std::max<int64>(1, lhs_size - 1) == wd.stride() &&
lhs_size == wd.base_dilation() && wd.window_dilation() == 1 &&
wd.padding_high() == 0 && wd.padding_low() == 0 &&
!wd.window_reversal()) {
// A batch dimension in DotGeneral is represented as a spatial dimension
// with window size B (batch dimension size), stride B - 1, and base
// dilation B.
dims.batch_dims.push_back({lhs, rhs, output, i});
} else if (lhs_size == wd.size() && wd.base_dilation() == 1 &&
wd.window_dilation() == 1 && wd.padding_high() == 0 &&
wd.padding_low() == 0 && !wd.window_reversal()) {
// A contracting dimension can be represented as a spatial dimension with
// window size C (contracting dimension size). Stride can be any size
// since there is only one window.
dims.contracting_dims.push_back({lhs, rhs, output, i});
} else if (wd.stride() == 1 && wd.window_dilation() == 1 &&
wd.base_dilation() == 1) {
if (rhs_size == 1 && wd.size() == 1 && wd.padding_high() == 0 &&
wd.padding_low() == 0 && !wd.window_reversal()) {
// A LHS non-contracting dimension can be represented as a spatial
// dimension with window size 1.
dims.lhs_non_contracting_dims.push_back({lhs, rhs, output, i});
} else if (lhs_size == 1 && wd.size() == rhs_size &&
wd.padding_high() == rhs_size - 1 &&
wd.padding_low() == rhs_size - 1 && wd.window_reversal()) {
// A RHS non-contracting dimension can be represented as a spatial
// dimension with window size N (non-contracting dimension size), low
// padding N - 1, high padding N - 1 and window reversal.
dims.rhs_non_contracting_dims.push_back({lhs, rhs, output, i});
} else {
return absl::nullopt;
}
} else {
return absl::nullopt;
}
}
return dims;
}
StatusOr<std::unique_ptr<HloInstruction>>
CreateShardedConvForDotGeneralConvolution(
const HloInstruction& conv,
const DotGeneralAsConvolutionDimsInfo& dot_dnums,
HloInstruction* sharded_lhs_hlo, HloInstruction* sharded_rhs_hlo) {
CHECK_EQ(conv.opcode(), HloOpcode::kConvolution);
const auto& conv_dnums = conv.convolution_dimension_numbers();
auto window = conv.window();
for (const auto& dim : dot_dnums.batch_dims) {
auto wd = window.mutable_dimensions(dim.spatial_dim);
wd->set_size(sharded_lhs_hlo->shape().dimensions(
conv_dnums.input_spatial_dimensions(dim.spatial_dim)));
wd->set_stride(std::max<int64>(1, wd->size() - 1));
wd->set_base_dilation(wd->size());
}
for (const auto& dim : dot_dnums.contracting_dims) {
if (dim.spatial_dim < 0) {
continue;
}
auto wd = window.mutable_dimensions(dim.spatial_dim);
wd->set_size(sharded_lhs_hlo->shape().dimensions(
conv_dnums.input_spatial_dimensions(dim.spatial_dim)));
}
for (const auto& dim : dot_dnums.rhs_non_contracting_dims) {
if (dim.spatial_dim < 0) {
continue;
}
auto wd = window.mutable_dimensions(dim.spatial_dim);
wd->set_size(sharded_rhs_hlo->shape().dimensions(
conv_dnums.kernel_spatial_dimensions(dim.spatial_dim)));
wd->set_padding_high(wd->size() - 1);
wd->set_padding_low(wd->size() - 1);
}
TF_ASSIGN_OR_RETURN(Shape sharded_conv_shape,
ShapeInference::InferConvolveShape(
sharded_lhs_hlo->shape(), sharded_rhs_hlo->shape(),
/*feature_group_count=*/1,
/*batch_group_count=*/1, window, conv_dnums));
*sharded_conv_shape.mutable_layout() = conv.shape().layout();
return HloInstruction::CreateConvolve(
sharded_conv_shape, sharded_lhs_hlo, sharded_rhs_hlo,
/*feature_group_count=*/1,
/*batch_group_count=*/1, window, conv_dnums, conv.precision_config());
}
} // namespace dot_as_convolution_util
} // namespace xla
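
To make the window encodings above concrete, here is a small standalone sketch (editorial illustration only, not part of the commit): it uses a hypothetical WindowDim struct in place of the real xla::WindowDimension proto and fills in the values that ParseDotGeneralFromConvolution() expects for each logical dimension kind.

#include <algorithm>
#include <cstdint>
#include <iostream>

// Hypothetical stand-in for the WindowDimension proto; illustrative only.
struct WindowDim {
  int64_t size = 1;
  int64_t stride = 1;
  int64_t base_dilation = 1;
  int64_t padding_low = 0;
  int64_t padding_high = 0;
  bool window_reversal = false;
};

int main() {
  // Batch dimension of size B = 8: window size B, stride max(1, B - 1),
  // base dilation B.
  const int64_t b = 8;
  WindowDim batch;
  batch.size = b;
  batch.stride = std::max<int64_t>(1, b - 1);
  batch.base_dilation = b;

  // Contracting dimension of size C = 16: window size C, no dilation and no
  // padding (stride is irrelevant because there is only one window position).
  const int64_t c = 16;
  WindowDim contracting;
  contracting.size = c;

  // LHS non-contracting dimension: kernel size 1, so window size 1 with
  // default stride, dilation, and padding.
  WindowDim lhs_non_contracting;

  // RHS non-contracting dimension of size N = 4: window size N, low and high
  // padding N - 1, and window reversal.
  const int64_t n = 4;
  WindowDim rhs_non_contracting;
  rhs_non_contracting.size = n;
  rhs_non_contracting.padding_low = n - 1;
  rhs_non_contracting.padding_high = n - 1;
  rhs_non_contracting.window_reversal = true;

  std::cout << "batch: size=" << batch.size << " stride=" << batch.stride
            << " base_dilation=" << batch.base_dilation << "\n"
            << "contracting: size=" << contracting.size << "\n"
            << "lhs_non_contracting: size=" << lhs_non_contracting.size << "\n"
            << "rhs_non_contracting: size=" << rhs_non_contracting.size
            << " padding=" << rhs_non_contracting.padding_low << "\n";
  return 0;
}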

tensorflow/compiler/xla/service/dot_as_convolution_util.h

@@ -0,0 +1,68 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_DOT_AS_CONVOLUTION_UTIL_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_DOT_AS_CONVOLUTION_UTIL_H_
#include <memory>
#include <vector>
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
namespace xla {
namespace dot_as_convolution_util {
// Describes the dimensions of a convolution that can be interpreted as a dot.
struct DotGeneralAsConvolutionDimsInfo {
// The dimension numbers for the operands and output corresponding to a
// logical dimension (e.g., batch, contracting, non-contracting). If an
// operand or the output doesn't have the logical dimension, it is set to
// -1.
struct DimNums {
int64 lhs;
int64 rhs;
int64 output;
// The corresponding spatial dimension in the convolution's config. Set to
// -1 if it's not mapped to a spatial dimension.
int64 spatial_dim;
};
std::vector<DimNums> batch_dims;
std::vector<DimNums> contracting_dims;
std::vector<DimNums> lhs_non_contracting_dims;
std::vector<DimNums> rhs_non_contracting_dims;
};
// Parses a convolution and returns a DotGeneralAsConvolutionDimsInfo if it can
// be interpreted as a dot, or absl::nullopt otherwise.
absl::optional<DotGeneralAsConvolutionDimsInfo> ParseDotGeneralFromConvolution(
const HloInstruction* conv);
// Creates a sharded convolution instruction that can be interpreted as a dot.
// This is a utility for per-op partitioners.
// - 'conv' is the original convolution instruction.
// - 'dot_dnums' is the result of ParseDotGeneralFromConvolution() for 'conv'.
// - 'sharded_lhs_hlo' and 'sharded_rhs_hlo' are sharded inputs for the result
// convolution instruction.
StatusOr<std::unique_ptr<HloInstruction>>
CreateShardedConvForDotGeneralConvolution(
const HloInstruction& conv,
const DotGeneralAsConvolutionDimsInfo& dot_dnums,
HloInstruction* sharded_lhs_hlo, HloInstruction* sharded_rhs_hlo);
} // namespace dot_as_convolution_util
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_DOT_AS_CONVOLUTION_UTIL_H_
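
As a usage sketch only (assumed caller-side code, not part of this change), a per-op partitioner would typically call the two functions above in sequence; the helper name PartitionConvAsDot and its error handling below are hypothetical.

#include "tensorflow/compiler/xla/service/dot_as_convolution_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"

namespace xla {

// Hypothetical helper: replaces a dot-like convolution with an equivalent
// convolution built on already-sharded operands.
StatusOr<HloInstruction*> PartitionConvAsDot(const HloInstruction* conv,
                                             HloInstruction* sharded_lhs,
                                             HloInstruction* sharded_rhs,
                                             HloComputation* computation) {
  auto dims_info =
      dot_as_convolution_util::ParseDotGeneralFromConvolution(conv);
  if (!dims_info) {
    return InvalidArgument("Convolution cannot be interpreted as a dot.");
  }
  // Window sizes and padding are recomputed from the sharded operand shapes
  // inside the utility.
  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<HloInstruction> sharded_conv,
      dot_as_convolution_util::CreateShardedConvForDotGeneralConvolution(
          *conv, *dims_info, sharded_lhs, sharded_rhs));
  return computation->AddInstruction(std::move(sharded_conv));
}

}  // namespace xla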

tensorflow/compiler/xla/service/sharding_propagation.cc (file diff suppressed because it is too large)

tensorflow/compiler/xla/service/sharding_propagation.h

@@ -0,0 +1,50 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_SHARDING_PROPAGATION_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_SHARDING_PROPAGATION_H_
#include <memory>
#include <vector>
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
#include "tensorflow/compiler/xla/statusor.h"
namespace xla {
// Propagates sharding information around the graph. HLOs that have shardings
// are kept as-is; those that do not have shardings are given shardings based on
// a simple local greedy heuristic.
class ShardingPropagation : public HloModulePass {
public:
explicit ShardingPropagation(bool is_spmd = false) : is_spmd_(is_spmd) {}
absl::string_view name() const override { return "sharding-propagation"; }
StatusOr<bool> Run(HloModule* module) override;
// Function which can be used to apply a spatially partitioned sharding onto a
// given domain. It will apply the sharding into the exit edges of the domain
// and then rely on the rest of sharding propagation to ensure that the
// intermediate nodes get the correct sharding.
static Status NormalizeDomain(const DomainMetadata::Domain& domain,
const DomainMetadata* metadata);
private:
bool is_spmd_;
};
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_SHARDING_PROPAGATION_H_
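
For context, a minimal sketch of how the moved pass would typically be run (assuming the standard HloPassPipeline API; the wrapper function below is hypothetical and not part of this change):

#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
#include "tensorflow/compiler/xla/service/sharding_propagation.h"

namespace xla {

// Hypothetical wrapper: runs sharding propagation over a module. Ops that
// already carry shardings keep them; the pass fills in the rest using its
// local greedy heuristic.
Status PropagateShardings(HloModule* module, bool is_spmd) {
  HloPassPipeline pipeline("sharding-propagation");
  pipeline.AddPass<ShardingPropagation>(is_spmd);
  return pipeline.Run(module).status();
}

}  // namespace xla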

tensorflow/compiler/xla/service/sharding_propagation_test.cc (file diff suppressed because it is too large)

tensorflow/compiler/xla/service/spmd/BUILD

@@ -33,6 +33,7 @@ cc_library(
"//tensorflow/compiler/xla:window_util",
"//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/compiler/xla/client/lib:comparators",
"//tensorflow/compiler/xla/service:dot_as_convolution_util",
"//tensorflow/compiler/xla/service:flatten_call_graph",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_casting_utils",

tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc

@@ -31,6 +31,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/protobuf_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/dot_as_convolution_util.h"
#include "tensorflow/compiler/xla/service/flatten_call_graph.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
@@ -2905,6 +2906,46 @@ Status SpmdPartitioningVisitor::HandleConvolutionTiledLhsAndRhs(
}
Status SpmdPartitioningVisitor::HandleConvolution(HloInstruction* hlo) {
auto dot_dnums = dot_as_convolution_util::ParseDotGeneralFromConvolution(hlo);
if (dot_dnums) {
// Use HandleDotHelper() for convs that are actually einsums.
spmd::DotGeneralDimsMapping mapping;
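// Translate the parsed dimension groups into the partitioner's dot dimension
// mapping; only the lhs/rhs/output dimension numbers are copied, the
// convolution-specific spatial_dim is not needed here.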
for (const auto& dims : dot_dnums->batch_dims) {
mapping.batch_dims.emplace_back();
mapping.batch_dims.back().lhs = dims.lhs;
mapping.batch_dims.back().rhs = dims.rhs;
mapping.batch_dims.back().output = dims.output;
}
for (const auto& dims : dot_dnums->contracting_dims) {
mapping.contracting_dims.emplace_back();
mapping.contracting_dims.back().lhs = dims.lhs;
mapping.contracting_dims.back().rhs = dims.rhs;
mapping.contracting_dims.back().output = dims.output;
}
for (const auto& dims : dot_dnums->lhs_non_contracting_dims) {
mapping.lhs_non_contracting_dims.emplace_back();
mapping.lhs_non_contracting_dims.back().lhs = dims.lhs;
mapping.lhs_non_contracting_dims.back().rhs = dims.rhs;
mapping.lhs_non_contracting_dims.back().output = dims.output;
}
for (const auto& dims : dot_dnums->rhs_non_contracting_dims) {
mapping.rhs_non_contracting_dims.emplace_back();
mapping.rhs_non_contracting_dims.back().lhs = dims.lhs;
mapping.rhs_non_contracting_dims.back().rhs = dims.rhs;
mapping.rhs_non_contracting_dims.back().output = dims.output;
}
auto create_sharded_conv =
[&](HloInstruction* lhs_hlo, HloInstruction* rhs_hlo,
spmd::SpmdBuilder* b) -> StatusOr<HloInstruction*> {
TF_ASSIGN_OR_RETURN(
auto sharded_conv,
dot_as_convolution_util::CreateShardedConvForDotGeneralConvolution(
*hlo, *dot_dnums, lhs_hlo, rhs_hlo));
return b->AddInstruction(std::move(sharded_conv));
};
return HandleDotHelper(hlo, mapping, create_sharded_conv);
}
auto lhs = GetPartitionedHlo(hlo->operand(0));
auto rhs = GetPartitionedHlo(hlo->operand(1));
const HloSharding& sharding = hlo->sharding();