[XLA:GPU] Print thunk kind in thunk schedule.

Particularly helpful for distinguishing between kWhile and kFor loops.  We've
got this info in the backend-config, but that's not as obvious.
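
For example, with hypothetical HLO instruction names, the schedule now
prints something like (each kind padded to the longest kind string, then
a tab, then the instruction):

  Total order:
  kCopy  	%copy.3 = f32[1024]{0} copy(f32[1024]{0} %param.0)
  kWhile 	%while.7 = f32[1024]{0} while(f32[1024]{0} %copy.3), ...
  kGemm  	%dot.5 = f32[32,32]{1,0} dot(...), ...

  Dependencies:
  ...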

PiperOrigin-RevId: 238470458
Author: Justin Lebar, 2019-03-14 10:31:22 -07:00 (committed by TensorFlower Gardener)
commit 425d4f2089 (parent 32edfdd8e4)

4 changed files with 53 additions and 23 deletions

tensorflow/compiler/xla/service/gpu/BUILD

@@ -372,7 +372,7 @@ cc_library(
         ":hlo_execution_profiler",
         ":infeed_manager",
         ":ir_emission_utils",
-        ":nccl_all_reduce_thunk",
+        ":nccl_all_reduce_thunk",  # fixdeps: keep
         ":outfeed_manager",
         ":partition_assignment",
         ":stream_assignment",
@@ -407,6 +407,7 @@ cc_library(
         "//tensorflow/stream_executor",
         "//tensorflow/stream_executor:blas",
         "//tensorflow/stream_executor:device_memory",
+        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/base:core_headers",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/container:flat_hash_set",

tensorflow/compiler/xla/service/gpu/thunk.cc

@@ -18,48 +18,52 @@ limitations under the License.
 namespace xla {
 namespace gpu {
 
-std::ostream& operator<<(std::ostream& os, Thunk::Kind kind) {
+absl::string_view ThunkKindToString(Thunk::Kind kind) {
   switch (kind) {
     case Thunk::kCholesky:
-      return os << "kCholesky";
+      return "kCholesky";
     case Thunk::kConditional:
-      return os << "kConditional";
+      return "kConditional";
     case Thunk::kConvolution:
-      return os << "kConvolution";
+      return "kConvolution";
     case Thunk::kCopy:
-      return os << "kCopy";
+      return "kCopy";
     case Thunk::kCudnnBatchNormBackward:
-      return os << "kCudnnBatchNormBackward";
+      return "kCudnnBatchNormBackward";
     case Thunk::kCudnnBatchNormForwardInference:
-      return os << "kCudnnBatchNormForwardInference";
+      return "kCudnnBatchNormForwardInference";
     case Thunk::kCudnnBatchNormForwardTraining:
-      return os << "kCudnnBatchNormForwardTraining";
+      return "kCudnnBatchNormForwardTraining";
     case Thunk::kNcclAllReduce:
-      return os << "kNcclAllReduce";
+      return "kNcclAllReduce";
     case Thunk::kFft:
-      return os << "kFft";
+      return "kFft";
     case Thunk::kGemm:
-      return os << "kGemm";
+      return "kGemm";
     case Thunk::kInfeed:
-      return os << "kInfeed";
+      return "kInfeed";
     case Thunk::kKernel:
-      return os << "kKernel";
+      return "kKernel";
     case Thunk::kMemset32BitValue:
-      return os << "kMemset32BitValue";
+      return "kMemset32BitValue";
     case Thunk::kMemzero:
-      return os << "kMemzero";
+      return "kMemzero";
     case Thunk::kOutfeed:
-      return os << "kOutfeed";
+      return "kOutfeed";
     case Thunk::kSequential:
-      return os << "kSequential";
+      return "kSequential";
     case Thunk::kTriangularSolve:
-      return os << "kTriangularSolve";
+      return "kTriangularSolve";
     case Thunk::kTuple:
-      return os << "kTuple";
+      return "kTuple";
     case Thunk::kWhile:
-      return os << "kWhile";
+      return "kWhile";
   }
 }
+
+std::ostream& operator<<(std::ostream& os, Thunk::Kind kind) {
+  return os << ThunkKindToString(kind);
+}
 
 }  // namespace gpu
 }  // namespace xla
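
The shape of that change is a common C++ pattern: an exhaustive switch over
the enum that returns string literals (safe to hand back as
absl::string_view, since literals have static storage duration), with
operator<< reduced to a one-line delegate. A minimal self-contained sketch
of the same pattern, using a hypothetical Color enum and std::string_view
in place of absl (not part of the commit):

#include <iostream>
#include <ostream>
#include <string_view>

enum class Color { kRed, kGreen, kBlue };

// Returning string_view is safe here: every branch hands back a string
// literal, which has static storage duration.
std::string_view ColorToString(Color color) {
  switch (color) {
    case Color::kRed:
      return "kRed";
    case Color::kGreen:
      return "kGreen";
    case Color::kBlue:
      return "kBlue";
  }
  return "<invalid Color>";  // Only reachable for out-of-range values.
}

// Streaming delegates to the helper, so the name is also available as a
// plain string -- which is what lets ThunkSchedule::ToString measure and
// pad the kind names below.
std::ostream& operator<<(std::ostream& os, Color color) {
  return os << ColorToString(color);
}

int main() {
  std::cout << Color::kGreen << "\n";  // Prints "kGreen".
  return 0;
}

Keeping the switch exhaustive and default-free means the compiler can warn
whenever a new enum value is added without a matching case.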

tensorflow/compiler/xla/service/gpu/thunk.h

@@ -106,6 +106,7 @@ class Thunk {
 // A sequence of thunks.
 using ThunkSequence = std::vector<std::unique_ptr<Thunk>>;
 
+absl::string_view ThunkKindToString(Thunk::Kind);
 std::ostream& operator<<(std::ostream& os, Thunk::Kind kind);
 
 }  // namespace gpu

tensorflow/compiler/xla/service/gpu/thunk_schedule.cc

@@ -14,7 +14,10 @@ limitations under the License.
 ==============================================================================*/
 #include "tensorflow/compiler/xla/service/gpu/thunk_schedule.h"
 
+#include <algorithm>
+#include "absl/algorithm/container.h"
 #include "absl/container/flat_hash_map.h"
+#include "absl/strings/str_format.h"
 #include "tensorflow/compiler/xla/array2d.h"
 #include "tensorflow/compiler/xla/map_util.h"
 #include "tensorflow/compiler/xla/types.h"
@@ -144,11 +147,32 @@ const std::list<const Thunk*>& ThunkSchedule::DependsOn(
 }
 
 string ThunkSchedule::ToString() const {
+  if (thunk_total_order_.empty()) {
+    return "No thunks.";
+  }
+
+  const Thunk* thunk_with_longest_kind = *absl::c_max_element(
+      thunk_total_order_, [](const Thunk* a, const Thunk* b) {
+        return ThunkKindToString(a->kind()).length() <
+               ThunkKindToString(b->kind()).length();
+      });
+  int64 max_thunk_kind_len =
+      ThunkKindToString(thunk_with_longest_kind->kind()).length();
+
   string result = "Total order:\n";
   for (Thunk* thunk : thunk_total_order_) {
-    absl::StrAppend(&result, "\t", thunk->hlo_instruction()->ToString(), "\n");
+    // Write out the thunk kind, padded out to max_thunk_kind_len.
+    absl::string_view kind_str = ThunkKindToString(thunk->kind());
+    absl::StrAppend(&result, kind_str,
+                    string(max_thunk_kind_len - kind_str.length(), ' '), "\t");
+    if (thunk->hlo_instruction() != nullptr) {
+      absl::StrAppend(&result, thunk->hlo_instruction()->ToString());
+    } else {
+      absl::StrAppend(&result, "(no HloInstruction)");
+    }
+    absl::StrAppend(&result, "\n");
   }
-  absl::StrAppend(&result, "Dependencies:\n");
+  absl::StrAppend(&result, "\nDependencies:\n");
   for (const auto& entry : depends_on_) {
     const Thunk* dependent = entry.first;
     for (const Thunk* dependency : entry.second) {
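
The padding logic in that hunk is plain measure-then-pad column alignment.
A stand-alone sketch of the same technique, with hypothetical data and
std::max_element in place of absl::c_max_element (the real code returns
early on an empty schedule, which is what makes the dereference safe):

#include <algorithm>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

int main() {
  // (kind, description) rows, standing in for (thunk kind, HLO instruction).
  std::vector<std::pair<std::string, std::string>> rows = {
      {"kCopy", "%copy.3 = ..."},
      {"kWhile", "%while.7 = ..."},
      {"kGemm", "%dot.5 = ..."},
  };

  // Measure the widest first column, as the absl::c_max_element call does.
  const auto longest = std::max_element(
      rows.begin(), rows.end(), [](const auto& a, const auto& b) {
        return a.first.length() < b.first.length();
      });
  const size_t max_len = longest->first.length();

  // Pad every kind out to max_len so the second column starts flush.
  for (const auto& row : rows) {
    std::cout << row.first << std::string(max_len - row.first.length(), ' ')
              << "\t" << row.second << "\n";
  }
  return 0;
}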