[XLA:GPU] Print thunk kind in thunk schedule.

Particularly helpful for distinguishing between kWhile and kFor loops.  We've
got this info in the backend-config, but that's not as obvious.

PiperOrigin-RevId: 238470458
This commit is contained in:
Justin Lebar 2019-03-14 10:31:22 -07:00 committed by TensorFlower Gardener
parent 32edfdd8e4
commit 425d4f2089
4 changed files with 53 additions and 23 deletions

View File

@@ -372,7 +372,7 @@ cc_library(
":hlo_execution_profiler", ":hlo_execution_profiler",
":infeed_manager", ":infeed_manager",
":ir_emission_utils", ":ir_emission_utils",
":nccl_all_reduce_thunk", ":nccl_all_reduce_thunk", # fixdeps: keep
":outfeed_manager", ":outfeed_manager",
":partition_assignment", ":partition_assignment",
":stream_assignment", ":stream_assignment",
@@ -407,6 +407,7 @@ cc_library(
"//tensorflow/stream_executor", "//tensorflow/stream_executor",
"//tensorflow/stream_executor:blas", "//tensorflow/stream_executor:blas",
"//tensorflow/stream_executor:device_memory", "//tensorflow/stream_executor:device_memory",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:flat_hash_set",

View File

@@ -18,48 +18,52 @@ limitations under the License.
namespace xla { namespace xla {
namespace gpu { namespace gpu {
std::ostream& operator<<(std::ostream& os, Thunk::Kind kind) { absl::string_view ThunkKindToString(Thunk::Kind kind) {
switch (kind) { switch (kind) {
case Thunk::kCholesky: case Thunk::kCholesky:
return os << "kCholesky"; return "kCholesky";
case Thunk::kConditional: case Thunk::kConditional:
return os << "kConditional"; return "kConditional";
case Thunk::kConvolution: case Thunk::kConvolution:
return os << "kConvolution"; return "kConvolution";
case Thunk::kCopy: case Thunk::kCopy:
return os << "kCopy"; return "kCopy";
case Thunk::kCudnnBatchNormBackward: case Thunk::kCudnnBatchNormBackward:
return os << "kCudnnBatchNormBackward"; return "kCudnnBatchNormBackward";
case Thunk::kCudnnBatchNormForwardInference: case Thunk::kCudnnBatchNormForwardInference:
return os << "kCudnnBatchNormForwardInference"; return "kCudnnBatchNormForwardInference";
case Thunk::kCudnnBatchNormForwardTraining: case Thunk::kCudnnBatchNormForwardTraining:
return os << "kCudnnBatchNormForwardTraining"; return "kCudnnBatchNormForwardTraining";
case Thunk::kNcclAllReduce: case Thunk::kNcclAllReduce:
return os << "kNcclAllReduce"; return "kNcclAllReduce";
case Thunk::kFft: case Thunk::kFft:
return os << "kFft"; return "kFft";
case Thunk::kGemm: case Thunk::kGemm:
return os << "kGemm"; return "kGemm";
case Thunk::kInfeed: case Thunk::kInfeed:
return os << "kInfeed"; return "kInfeed";
case Thunk::kKernel: case Thunk::kKernel:
return os << "kKernel"; return "kKernel";
case Thunk::kMemset32BitValue: case Thunk::kMemset32BitValue:
return os << "kMemset32BitValue"; return "kMemset32BitValue";
case Thunk::kMemzero: case Thunk::kMemzero:
return os << "kMemzero"; return "kMemzero";
case Thunk::kOutfeed: case Thunk::kOutfeed:
return os << "kOutfeed"; return "kOutfeed";
case Thunk::kSequential: case Thunk::kSequential:
return os << "kSequential"; return "kSequential";
case Thunk::kTriangularSolve: case Thunk::kTriangularSolve:
return os << "kTriangularSolve"; return "kTriangularSolve";
case Thunk::kTuple: case Thunk::kTuple:
return os << "kTuple"; return "kTuple";
case Thunk::kWhile: case Thunk::kWhile:
return os << "kWhile"; return "kWhile";
} }
} }
std::ostream& operator<<(std::ostream& os, Thunk::Kind kind) {
return os << ThunkKindToString(kind);
}
} // namespace gpu } // namespace gpu
} // namespace xla } // namespace xla

View File

@@ -106,6 +106,7 @@ class Thunk {
// A sequence of thunks. // A sequence of thunks.
using ThunkSequence = std::vector<std::unique_ptr<Thunk>>; using ThunkSequence = std::vector<std::unique_ptr<Thunk>>;
absl::string_view ThunkKindToString(Thunk::Kind);
std::ostream& operator<<(std::ostream& os, Thunk::Kind kind); std::ostream& operator<<(std::ostream& os, Thunk::Kind kind);
} // namespace gpu } // namespace gpu

View File

@@ -14,7 +14,10 @@ limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "tensorflow/compiler/xla/service/gpu/thunk_schedule.h" #include "tensorflow/compiler/xla/service/gpu/thunk_schedule.h"
#include <algorithm>
#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_map.h"
#include "absl/strings/str_format.h"
#include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/types.h"
@@ -144,11 +147,32 @@ const std::list<const Thunk*>& ThunkSchedule::DependsOn(
} }
string ThunkSchedule::ToString() const { string ThunkSchedule::ToString() const {
if (thunk_total_order_.empty()) {
return "No thunks.";
}
const Thunk* thunk_with_longest_kind = *absl::c_max_element(
thunk_total_order_, [](const Thunk* a, const Thunk* b) {
return ThunkKindToString(a->kind()).length() <
ThunkKindToString(b->kind()).length();
});
int64 max_thunk_kind_len =
ThunkKindToString(thunk_with_longest_kind->kind()).length();
string result = "Total order:\n"; string result = "Total order:\n";
for (Thunk* thunk : thunk_total_order_) { for (Thunk* thunk : thunk_total_order_) {
absl::StrAppend(&result, "\t", thunk->hlo_instruction()->ToString(), "\n"); // Write out the thunk kind, padded out to max_thunk_kind_len.
absl::string_view kind_str = ThunkKindToString(thunk->kind());
absl::StrAppend(&result, kind_str,
string(max_thunk_kind_len - kind_str.length(), ' '), "\t");
if (thunk->hlo_instruction() != nullptr) {
absl::StrAppend(&result, thunk->hlo_instruction()->ToString());
} else {
absl::StrAppend(&result, "(no HloInstruction)");
} }
absl::StrAppend(&result, "Dependencies:\n"); absl::StrAppend(&result, "\n");
}
absl::StrAppend(&result, "\nDependencies:\n");
for (const auto& entry : depends_on_) { for (const auto& entry : depends_on_) {
const Thunk* dependent = entry.first; const Thunk* dependent = entry.first;
for (const Thunk* dependency : entry.second) { for (const Thunk* dependency : entry.second) {