Yuanzhong Xu f460141434 [XLA] Move sharding propagation to third party
This also moves some utilities for interpreting convolutions as dots.

PiperOrigin-RevId: 312868839
Change-Id: I90bdc30217edf6dfb301a9c80b7155653391fa1a
2020-05-22 18:18:30 -07:00

51 lines
2.0 KiB
C++

/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_SHARDING_PROPAGATION_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_SHARDING_PROPAGATION_H_
#include <memory>
#include <vector>
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
#include "tensorflow/compiler/xla/statusor.h"
namespace xla {

// ShardingPropagation is an HLO module pass that fills in missing sharding
// annotations across the graph. Instructions that already carry a sharding are
// left exactly as they are; every other instruction is assigned a sharding
// chosen by a simple, local, greedy heuristic.
class ShardingPropagation : public HloModulePass {
 public:
  // Builds the pass. `is_spmd` (default: false) is stored for use by Run();
  // presumably it selects SPMD-specific behaviour — TODO confirm in the .cc.
  explicit ShardingPropagation(bool is_spmd = false) : is_spmd_{is_spmd} {}

  // Identifier of this pass: "sharding-propagation".
  absl::string_view name() const override { return "sharding-propagation"; }

  // Executes sharding propagation over `module`. The bool result follows the
  // HloModulePass::Run contract (presumably "module was modified").
  StatusOr<bool> Run(HloModule* module) override;

  // Helper that applies a spatially partitioned sharding to `domain`: the
  // sharding is written onto the domain's exit edges, and the remainder of
  // sharding propagation is then relied upon so that the intermediate nodes
  // end up with the correct sharding.
  static Status NormalizeDomain(const DomainMetadata::Domain& domain,
                                const DomainMetadata* metadata);

 private:
  // Flag captured from the constructor argument of the same name.
  bool is_spmd_;
};

}  // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_SHARDING_PROPAGATION_H_