[XLA] Move sharding propagation to third party
This also moves some utilities for interpreting convolutions as dots.

PiperOrigin-RevId: 312868839
Change-Id: I90bdc30217edf6dfb301a9c80b7155653391fa1a
parent 18aaa18cf1
commit f460141434
@@ -491,6 +491,66 @@ tf_cc_test(
    ],
)

cc_library(
    name = "sharding_propagation",
    srcs = [
        "sharding_propagation.cc",
    ],
    hdrs = [
        "sharding_propagation.h",
    ],
    deps = [
        ":dot_as_convolution_util",
        ":hlo",
        ":hlo_graph_dumper",
        ":hlo_pass",
        ":hlo_sharding_util",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:status_macros",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/compiler/xla:util",
        "//tensorflow/compiler/xla:xla_data_proto_cc",
        "//tensorflow/core:lib",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:optional",
    ],
)

tf_cc_test(
    name = "sharding_propagation_test",
    srcs = [
        "sharding_propagation_test.cc",
    ],
    deps = [
"hlo_matchers",
|
||||
":hlo_parser",
|
||||
":sharding_propagation",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "dot_as_convolution_util",
|
||||
srcs = [
|
||||
"dot_as_convolution_util.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"dot_as_convolution_util.h",
|
||||
],
|
||||
deps = [
|
||||
":hlo",
|
||||
":shape_inference",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "dynamic_parameter_binding_test",
|
||||
srcs = ["dynamic_parameter_binding_test.cc"],
|
||||
|
||||
tensorflow/compiler/xla/service/dot_as_convolution_util.cc (new file, 139 lines)
@@ -0,0 +1,139 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/dot_as_convolution_util.h"

#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/status_macros.h"

namespace xla {
namespace dot_as_convolution_util {

/* static */ absl::optional<DotGeneralAsConvolutionDimsInfo>
ParseDotGeneralFromConvolution(const HloInstruction* conv) {
  CHECK_EQ(conv->opcode(), HloOpcode::kConvolution);
  if (conv->feature_group_count() != 1 || conv->batch_group_count() != 1) {
    return absl::nullopt;
  }
  const auto& conv_dims = conv->convolution_dimension_numbers();
  DotGeneralAsConvolutionDimsInfo dims;
  dims.lhs_non_contracting_dims.push_back(
      {conv_dims.input_batch_dimension(), -1,
       conv_dims.output_batch_dimension(), -1});
  dims.rhs_non_contracting_dims.push_back(
      {-1, conv_dims.kernel_output_feature_dimension(),
       conv_dims.output_feature_dimension(), -1});
  dims.contracting_dims.push_back({conv_dims.input_feature_dimension(),
                                   conv_dims.kernel_input_feature_dimension(),
                                   -1, -1});

  for (int64 i = 0; i < conv_dims.input_spatial_dimensions_size(); ++i) {
    int64 lhs = conv_dims.input_spatial_dimensions(i);
    int64 lhs_size = conv->operand(0)->shape().dimensions(lhs);
    int64 rhs = conv_dims.kernel_spatial_dimensions(i);
    int64 rhs_size = conv->operand(1)->shape().dimensions(rhs);
    int64 output = conv_dims.output_spatial_dimensions(i);
    const auto& wd = conv->window().dimensions(i);
    if (lhs_size == wd.size() &&
        std::max<int64>(1, lhs_size - 1) == wd.stride() &&
        lhs_size == wd.base_dilation() && wd.window_dilation() == 1 &&
        wd.padding_high() == 0 && wd.padding_low() == 0 &&
        !wd.window_reversal()) {
      // A batch dimension in DotGeneral is represented as a spatial dimension
      // with window size B (batch dimension size), stride B - 1, and base
      // dilation B.
      dims.batch_dims.push_back({lhs, rhs, output, i});
    } else if (lhs_size == wd.size() && wd.base_dilation() == 1 &&
               wd.window_dilation() == 1 && wd.padding_high() == 0 &&
               wd.padding_low() == 0 && !wd.window_reversal()) {
      // A contracting dimension can be represented as a spatial dimension with
      // window size C (contracting dimension size). Stride can be any size
      // since there is only one window.
      dims.contracting_dims.push_back({lhs, rhs, output, i});
    } else if (wd.stride() == 1 && wd.window_dilation() == 1 &&
               wd.base_dilation() == 1) {
      if (rhs_size == 1 && wd.size() == 1 && wd.padding_high() == 0 &&
          wd.padding_low() == 0 && !wd.window_reversal()) {
        // An LHS non-contracting dimension can be represented as a spatial
        // dimension with window size 1.
        dims.lhs_non_contracting_dims.push_back({lhs, rhs, output, i});
      } else if (lhs_size == 1 && wd.size() == rhs_size &&
                 wd.padding_high() == rhs_size - 1 &&
                 wd.padding_low() == rhs_size - 1 && wd.window_reversal()) {
        // An RHS non-contracting dimension can be represented as a spatial
        // dimension with window size N (non-contracting dimension size), low
        // padding N - 1, high padding N - 1 and window reversal.
        dims.rhs_non_contracting_dims.push_back({lhs, rhs, output, i});
      } else {
        return absl::nullopt;
      }
    } else {
      return absl::nullopt;
    }
  }

  return dims;
}

StatusOr<std::unique_ptr<HloInstruction>>
CreateShardedConvForDotGeneralConvolution(
    const HloInstruction& conv,
    const DotGeneralAsConvolutionDimsInfo& dot_dnums,
    HloInstruction* sharded_lhs_hlo, HloInstruction* sharded_rhs_hlo) {
  CHECK_EQ(conv.opcode(), HloOpcode::kConvolution);
  const auto& conv_dnums = conv.convolution_dimension_numbers();
  auto window = conv.window();
  for (const auto& dim : dot_dnums.batch_dims) {
    auto wd = window.mutable_dimensions(dim.spatial_dim);
    wd->set_size(sharded_lhs_hlo->shape().dimensions(
        conv_dnums.input_spatial_dimensions(dim.spatial_dim)));
    wd->set_stride(std::max<int64>(1, wd->size() - 1));
    wd->set_base_dilation(wd->size());
  }
  for (const auto& dim : dot_dnums.contracting_dims) {
    if (dim.spatial_dim < 0) {
      continue;
    }
    auto wd = window.mutable_dimensions(dim.spatial_dim);
    wd->set_size(sharded_lhs_hlo->shape().dimensions(
        conv_dnums.input_spatial_dimensions(dim.spatial_dim)));
  }
  for (const auto& dim : dot_dnums.rhs_non_contracting_dims) {
    if (dim.spatial_dim < 0) {
      continue;
    }
    auto wd = window.mutable_dimensions(dim.spatial_dim);
    wd->set_size(sharded_rhs_hlo->shape().dimensions(
        conv_dnums.kernel_spatial_dimensions(dim.spatial_dim)));
    wd->set_padding_high(wd->size() - 1);
    wd->set_padding_low(wd->size() - 1);
  }
  TF_ASSIGN_OR_RETURN(Shape sharded_conv_shape,
                      ShapeInference::InferConvolveShape(
                          sharded_lhs_hlo->shape(), sharded_rhs_hlo->shape(),
                          /*feature_group_count=*/1,
                          /*batch_group_count=*/1, window, conv_dnums));
  *sharded_conv_shape.mutable_layout() = conv.shape().layout();
  return HloInstruction::CreateConvolve(
      sharded_conv_shape, sharded_lhs_hlo, sharded_rhs_hlo,
      /*feature_group_count=*/1,
      /*batch_group_count=*/1, window, conv_dnums, conv.precision_config());
}

}  // namespace dot_as_convolution_util
}  // namespace xla
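As a quick illustration of the window configuration the parser above looks for, here is a minimal hedged sketch (not part of the commit); the helper name is hypothetical and it only assumes the WindowDimension proto from xla_data.proto:

// Hedged sketch, not part of the commit: the window configuration that the
// parser treats as a DotGeneral batch dimension of size B (window size B,
// stride max(1, B - 1), base dilation B, no padding, no kernel dilation, no
// reversal). MakeBatchLikeWindowDim is a hypothetical helper name.
#include <algorithm>
#include <cstdint>

#include "tensorflow/compiler/xla/xla_data.pb.h"

xla::WindowDimension MakeBatchLikeWindowDim(int64_t batch_size) {
  xla::WindowDimension wd;
  wd.set_size(batch_size);
  wd.set_stride(std::max<int64_t>(1, batch_size - 1));
  wd.set_base_dilation(batch_size);
  wd.set_window_dilation(1);
  wd.set_padding_low(0);
  wd.set_padding_high(0);
  wd.set_window_reversal(false);
  return wd;
}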
tensorflow/compiler/xla/service/dot_as_convolution_util.h (new file, 68 lines)
@@ -0,0 +1,68 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_DOT_AS_CONVOLUTION_UTIL_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_DOT_AS_CONVOLUTION_UTIL_H_

#include <memory>
#include <vector>

#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"

namespace xla {
namespace dot_as_convolution_util {

// Describes the dimensions of a convolution that can be interpreted as a dot.
struct DotGeneralAsConvolutionDimsInfo {
  // The dimension numbers for the operands and output corresponding to a
  // logical dimension (e.g., batch, contracting, non-contracting). If an
  // operand or the output doesn't have the logical dimension, it is set to
  // -1.
  struct DimNums {
    int64 lhs;
    int64 rhs;
    int64 output;
    // The corresponding spatial dimension in the convolution's config. Set to
    // -1 if it's not mapped to a spatial dimension.
    int64 spatial_dim;
  };
  std::vector<DimNums> batch_dims;
  std::vector<DimNums> contracting_dims;
  std::vector<DimNums> lhs_non_contracting_dims;
  std::vector<DimNums> rhs_non_contracting_dims;
};

// Parses a convolution and returns a DotGeneralAsConvolutionDimsInfo if it can
// be interpreted as a dot, or absl::nullopt otherwise.
absl::optional<DotGeneralAsConvolutionDimsInfo> ParseDotGeneralFromConvolution(
    const HloInstruction* conv);

// Creates a sharded convolution instruction that can be interpreted as a dot.
// This is a utility for per-op partitioners.
// - 'conv' is the original convolution instruction.
// - 'dot_dnums' is the result of ParseDotGeneralFromConvolution() for 'conv'.
// - 'sharded_lhs_hlo' and 'sharded_rhs_hlo' are sharded inputs for the result
//   convolution instruction.
StatusOr<std::unique_ptr<HloInstruction>>
CreateShardedConvForDotGeneralConvolution(
    const HloInstruction& conv,
    const DotGeneralAsConvolutionDimsInfo& dot_dnums,
    HloInstruction* sharded_lhs_hlo, HloInstruction* sharded_rhs_hlo);

}  // namespace dot_as_convolution_util
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_DOT_AS_CONVOLUTION_UTIL_H_
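A hedged usage sketch (not part of the commit) of the two declarations above; the wrapper function name is hypothetical and `conv` is assumed to be a valid convolution instruction obtained elsewhere:

// Hedged sketch, not part of the commit. InspectConvAsDot is a hypothetical
// helper that checks whether a convolution is really an einsum.
#include "tensorflow/compiler/xla/service/dot_as_convolution_util.h"

void InspectConvAsDot(const xla::HloInstruction* conv) {
  auto dims =
      xla::dot_as_convolution_util::ParseDotGeneralFromConvolution(conv);
  if (!dims.has_value()) {
    return;  // The convolution is not expressible as a DotGeneral.
  }
  // Each DimNums entry records the lhs/rhs/output dimension numbers of one
  // logical dot dimension, with -1 where an operand does not participate.
  for (const auto& d : dims->contracting_dims) {
    (void)d.lhs;  // LHS dimension number of this contracting dimension.
    (void)d.rhs;  // RHS dimension number.
  }
}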
tensorflow/compiler/xla/service/sharding_propagation.cc (new file, 1478 lines)
(File diff suppressed because it is too large.)
tensorflow/compiler/xla/service/sharding_propagation.h (new file, 50 lines)
@@ -0,0 +1,50 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_SHARDING_PROPAGATION_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_SHARDING_PROPAGATION_H_

#include <memory>
#include <vector>

#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
#include "tensorflow/compiler/xla/statusor.h"

namespace xla {

// Propagates sharding information around the graph. HLOs that have shardings
// are kept as-is; those that do not have shardings are given shardings based
// on a simple local greedy heuristic.
class ShardingPropagation : public HloModulePass {
 public:
  explicit ShardingPropagation(bool is_spmd = false) : is_spmd_(is_spmd) {}
  absl::string_view name() const override { return "sharding-propagation"; }
  StatusOr<bool> Run(HloModule* module) override;

  // Function which can be used to apply a spatially partitioned sharding onto
  // a given domain. It will apply the sharding into the exit edges of the
  // domain and then rely on the rest of sharding propagation to ensure that
  // the intermediate nodes get the correct sharding.
  static Status NormalizeDomain(const DomainMetadata::Domain& domain,
                                const DomainMetadata* metadata);

 private:
  bool is_spmd_;
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_SHARDING_PROPAGATION_H_
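A hedged sketch (not part of the commit) of how the pass declared above might be invoked; the wrapper name is hypothetical and the module is assumed to be built elsewhere:

// Hedged sketch, not part of the commit. RunShardingPropagation is a
// hypothetical wrapper; `module` is assumed to be a valid HloModule.
#include "tensorflow/compiler/xla/service/sharding_propagation.h"
#include "tensorflow/compiler/xla/statusor.h"

xla::StatusOr<bool> RunShardingPropagation(xla::HloModule* module) {
  // is_spmd = true selects SPMD-style propagation, per the constructor above.
  xla::ShardingPropagation pass(/*is_spmd=*/true);
  return pass.Run(module);  // True if any sharding annotation was changed.
}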
tensorflow/compiler/xla/service/sharding_propagation_test.cc (new file, 1329 lines)
(File diff suppressed because it is too large.)
@@ -33,6 +33,7 @@ cc_library(
        "//tensorflow/compiler/xla:window_util",
        "//tensorflow/compiler/xla:xla_data_proto_cc",
        "//tensorflow/compiler/xla/client/lib:comparators",
        "//tensorflow/compiler/xla/service:dot_as_convolution_util",
        "//tensorflow/compiler/xla/service:flatten_call_graph",
        "//tensorflow/compiler/xla/service:hlo",
        "//tensorflow/compiler/xla/service:hlo_casting_utils",
@@ -31,6 +31,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/protobuf_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/dot_as_convolution_util.h"
#include "tensorflow/compiler/xla/service/flatten_call_graph.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
@@ -2905,6 +2906,46 @@ Status SpmdPartitioningVisitor::HandleConvolutionTiledLhsAndRhs(
}

Status SpmdPartitioningVisitor::HandleConvolution(HloInstruction* hlo) {
  auto dot_dnums = dot_as_convolution_util::ParseDotGeneralFromConvolution(hlo);
  if (dot_dnums) {
    // Use HandleDotHelper() for convs that are actually einsums.
    spmd::DotGeneralDimsMapping mapping;
    for (const auto& dims : dot_dnums->batch_dims) {
      mapping.batch_dims.emplace_back();
      mapping.batch_dims.back().lhs = dims.lhs;
      mapping.batch_dims.back().rhs = dims.rhs;
      mapping.batch_dims.back().output = dims.output;
    }
    for (const auto& dims : dot_dnums->contracting_dims) {
      mapping.contracting_dims.emplace_back();
      mapping.contracting_dims.back().lhs = dims.lhs;
      mapping.contracting_dims.back().rhs = dims.rhs;
      mapping.contracting_dims.back().output = dims.output;
    }
    for (const auto& dims : dot_dnums->lhs_non_contracting_dims) {
      mapping.lhs_non_contracting_dims.emplace_back();
      mapping.lhs_non_contracting_dims.back().lhs = dims.lhs;
      mapping.lhs_non_contracting_dims.back().rhs = dims.rhs;
      mapping.lhs_non_contracting_dims.back().output = dims.output;
    }
    for (const auto& dims : dot_dnums->rhs_non_contracting_dims) {
      mapping.rhs_non_contracting_dims.emplace_back();
      mapping.rhs_non_contracting_dims.back().lhs = dims.lhs;
      mapping.rhs_non_contracting_dims.back().rhs = dims.rhs;
      mapping.rhs_non_contracting_dims.back().output = dims.output;
    }
    auto create_sharded_conv =
        [&](HloInstruction* lhs_hlo, HloInstruction* rhs_hlo,
            spmd::SpmdBuilder* b) -> StatusOr<HloInstruction*> {
      TF_ASSIGN_OR_RETURN(
          auto sharded_conv,
          dot_as_convolution_util::CreateShardedConvForDotGeneralConvolution(
              *hlo, *dot_dnums, lhs_hlo, rhs_hlo));
      return b->AddInstruction(std::move(sharded_conv));
    };
    return HandleDotHelper(hlo, mapping, create_sharded_conv);
  }

  auto lhs = GetPartitionedHlo(hlo->operand(0));
  auto rhs = GetPartitionedHlo(hlo->operand(1));
  const HloSharding& sharding = hlo->sharding();