STT-tensorflow/tensorflow/compiler/xla/service/all_gather_decomposer.h
Yuanzhong Xu 33e390481a [XLA] Fix some all-gather issues.
- Fix a wrong shape inference check.
 - Remove the partition_count argument from AllGatherDecomposer: it is a per-HLO property related to the replica groups.
 - Change ID types from U32 to S32 to match replica ID type.

PiperOrigin-RevId: 312391312
Change-Id: I00ead2e7fd3653c7dbde15fa7b623104a78b9a8c
2020-05-19 17:55:08 -07:00

49 lines
1.9 KiB
C++

/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_ALL_GATHER_DECOMPOSER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_ALL_GATHER_DECOMPOSER_H_
#include <functional>
#include <utility>

#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
namespace xla {
// AllGatherDecomposer is a pass which converts all-gathers (intended for the
// ones the backend does not support natively) into dynamic-update-slices and
// all-reduces. Which all-gathers are rewritten is decided per instruction by a
// `should_decompose` predicate; the default constructor installs a predicate
// that accepts every all-gather.
class AllGatherDecomposer : public HloModulePass {
 public:
  // Creates a pass that decomposes only the all-gathers for which
  // `should_decompose` returns true.
  explicit AllGatherDecomposer(
      std::function<bool(const HloAllGatherInstruction&)> should_decompose)
      : should_decompose_(std::move(should_decompose)) {}

  // Creates a pass whose predicate returns true for every all-gather.
  AllGatherDecomposer()
      : AllGatherDecomposer(
            [](const HloAllGatherInstruction& /*ag*/) { return true; }) {}

  absl::string_view name() const override { return "all_gather_decomposer"; }

  // Run AllGatherDecomposer pass on computations in 'module'.
  // Returns whether the 'module' was changed.
  StatusOr<bool> Run(HloModule* module) override;

 private:
  // Per-instruction predicate selecting which all-gathers to decompose.
  // The partition count is intentionally not stored here: it is a per-HLO
  // property derived from each all-gather's replica groups.
  std::function<bool(const HloAllGatherInstruction&)> should_decompose_;
};
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_ALL_GATHER_DECOMPOSER_H_