[XLA] Skip AllReduceCombiner when threshold is 0

PiperOrigin-RevId: 326771299
Change-Id: I248939e5e7c440722c5dd022a25968f956cfaf49
This commit is contained in:
Yuanzhong Xu 2020-08-14 19:30:06 -07:00 committed by TensorFlower Gardener
parent ea0a469bdd
commit 8ce0600f58

View File

@ -268,6 +268,11 @@ StatusOr<bool> AllReduceCombiner::Run(HloModule* module) {
VLOG(1) << "Running AllReduceCombiner with threshold of "
<< combine_threshold_in_bytes_ << " bytes";
if (combine_threshold_in_bytes_ <= 0 || combine_threshold_count_ <= 0) {
VLOG(1) << "Skip AllReduceCombiner because the threshold is zero";
return false;
}
if (hlo_query::ContainsLayoutConstrainedAllReduce(*module)) {
VLOG(1) << "Skip AllReduceCombiner because the module contains all-reduce "
"with constrained layouts";