[TF:XLA] Add error message about the non-strict XLA jit.

PiperOrigin-RevId: 232836591
A. Unique TensorFlower 2019-02-07 02:35:14 -08:00 committed by TensorFlower Gardener
parent 42d8403ace
commit dde1ed7309
7 changed files with 60 additions and 12 deletions


@@ -282,7 +282,6 @@ cc_library(
        "//tensorflow/compiler/tf2xla:common",
        "//tensorflow/compiler/tf2xla:dump_graph",
        "//tensorflow/compiler/tf2xla:xla_compiler",
        "//tensorflow/compiler/xla:debug_options_flags",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/compiler/xla/client:client_library",
        "//tensorflow/compiler/xla/client:local_client",
@@ -465,7 +464,6 @@ cc_library(
        "//tensorflow/compiler/tf2xla:tf2xla_util",
        "//tensorflow/core:framework",
        "//tensorflow/core:graph",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:optional",


@@ -15,13 +15,17 @@ limitations under the License.
#include "tensorflow/compiler/jit/encapsulate_xla_computations_pass.h"
#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "absl/strings/ascii.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
#include "tensorflow/compiler/tf2xla/dump_graph.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/proto_serialization.h"
#include "tensorflow/core/lib/strings/str_util.h"
@@ -36,6 +40,25 @@ namespace {
 const char* const kXlaClusterOutput = "XlaClusterOutput";
 
+bool IsCpuGpuCompile(const Graph* graph) {
+  for (Node* n : graph->nodes()) {
+    string name;
+    // Only consider nodes being compiled.
+    if (!GetNodeAttr(n->attrs(),
+                     EncapsulateXlaComputationsPass::kXlaClusterAttr, &name)
+             .ok())
+      continue;
+    // Early return for any node with a device that is not a CPU or GPU.
+    DeviceNameUtils::ParsedName parsed;
+    if (DeviceNameUtils::ParseFullName(n->def().device(), &parsed)) {
+      if (parsed.type != DEVICE_CPU && parsed.type != DEVICE_GPU) {
+        return false;
+      }
+    }
+  }
+  return true;
+}
+
 // Checks if a graph node is marked to be a guaranteed constant.
 bool is_guaranteed_constant(const Node& n) {
   bool guaranteed_constant = false;
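For context: TensorFlow device strings take the form /job:<name>/replica:<n>/task:<n>/device:<TYPE>:<id>, and IsCpuGpuCompile returns false as soon as any node marked for compilation sits on a device type other than CPU or GPU. The standalone sketch below is not part of this commit; DeviceTypeOf is a hypothetical, simplified stand-in for DeviceNameUtils::ParseFullName, used only to illustrate the same early-return check.

// Illustrative sketch; not TensorFlow code.
#include <iostream>
#include <string>
#include <vector>

// Extracts the device type (e.g. "CPU", "GPU", "TPU") from a full device
// name such as "/job:worker/replica:0/task:0/device:GPU:0". Returns an
// empty string if the name has no "/device:" component.
std::string DeviceTypeOf(const std::string& full_name) {
  const std::string key = "/device:";
  auto pos = full_name.find(key);
  if (pos == std::string::npos) return "";
  pos += key.size();
  auto colon = full_name.find(':', pos);
  return full_name.substr(pos, colon - pos);
}

// Mirrors the early-return logic above: false as soon as any device is
// neither CPU nor GPU; nodes with unparsable names are skipped.
bool AllDevicesCpuOrGpu(const std::vector<std::string>& device_names) {
  for (const auto& name : device_names) {
    std::string type = DeviceTypeOf(name);
    if (!type.empty() && type != "CPU" && type != "GPU") return false;
  }
  return true;
}

int main() {
  std::cout << AllDevicesCpuOrGpu({"/job:w/replica:0/task:0/device:GPU:0"})
            << "\n";  // prints 1
  std::cout << AllDevicesCpuOrGpu({"/job:w/replica:0/task:0/device:TPU:0"})
            << "\n";  // prints 0
}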
@@ -352,12 +375,19 @@ Status EncapsulateXlaComputationsPass::Run(
           << dump_graph::DumpGraphToFile("encapsulate_xla_computations_before",
                                          **options.graph, options.flib_def);
-  TF_RETURN_IF_ERROR(Encapsulate(options.graph, options.flib_def));
+  const char* additional_help =
+      IsCpuGpuCompile(options.graph->get())
+          ? xla::status_macros::kPossibleAutoJitAlternative
+          : "";
+  TF_RETURN_WITH_CONTEXT_IF_ERROR(Encapsulate(options.graph, options.flib_def),
+                                  additional_help);
   VLOG(1) << "EncapsulateXlaComputations() half-way: "
           << dump_graph::DumpGraphToFile("encapsulate_xla_computations_halfway",
                                          **options.graph, options.flib_def);
-  TF_RETURN_IF_ERROR(BuildXlaLaunchOps(options.graph->get()));
+  TF_RETURN_WITH_CONTEXT_IF_ERROR(BuildXlaLaunchOps(options.graph->get()),
+                                  additional_help);
   VLOG(1) << "EncapsulateXlaComputations() finished: "
           << dump_graph::DumpGraphToFile("encapsulate_xla_computations_after",
                                          **options.graph, options.flib_def);
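TF_RETURN_WITH_CONTEXT_IF_ERROR (from tensorflow/core/lib/core/errors.h) evaluates an expression and, if the resulting Status is not OK, appends the extra context message before returning it to the caller. The following is a minimal, self-contained sketch of that pattern, not the real macro: the Status struct here is a toy, not tensorflow::Status, and the function names are hypothetical.

// Illustrative sketch of the return-with-context pattern; not TF code.
#include <iostream>
#include <string>

struct Status {
  bool ok;             // true on success
  std::string message; // error text, empty on success
};

// Mimics errors::AppendToMessage: attach extra context to a failure.
void AppendToMessage(Status* s, const std::string& extra) {
  if (!extra.empty()) s->message += " " + extra;
}

// Sketch of what TF_RETURN_WITH_CONTEXT_IF_ERROR does: evaluate, and on
// failure append the context and propagate the Status to the caller.
#define RETURN_WITH_CONTEXT_IF_ERROR(expr, context) \
  do {                                              \
    Status _s = (expr);                             \
    if (!_s.ok) {                                   \
      AppendToMessage(&_s, (context));              \
      return _s;                                    \
    }                                               \
  } while (0)

Status Encapsulate() { return {false, "Op Foo is not compilable."}; }

Status Run(bool is_cpu_gpu_compile) {
  // As in the pass above: attach the hint only for CPU/GPU compiles.
  const char* additional_help =
      is_cpu_gpu_compile ? "Consider the non-strict auto-jit instead." : "";
  RETURN_WITH_CONTEXT_IF_ERROR(Encapsulate(), additional_help);
  return {true, ""};
}

int main() { std::cout << Run(true).message << "\n"; }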


@@ -19,6 +19,7 @@ cc_library(
         "//tensorflow/compiler/tf2xla:common",
         "//tensorflow/compiler/tf2xla:tf2xla_util",
         "//tensorflow/compiler/tf2xla:xla_compiler",
+        "//tensorflow/compiler/xla:status_macros",
         "//tensorflow/compiler/xla:statusor",
         "//tensorflow/compiler/xla/client:client_library",
         "//tensorflow/compiler/xla/client:local_client",


@@ -25,6 +25,7 @@ limitations under the License.
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
 #include "tensorflow/compiler/xla/client/client_library.h"
 #include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/status_macros.h"
 #include "tensorflow/compiler/xla/statusor.h"
 #include "tensorflow/core/common_runtime/dma_helper.h"
 #include "tensorflow/core/common_runtime/function.h"
@@ -35,6 +36,8 @@ limitations under the License.
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/kernels/variable_ops.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/platform/stream_executor_no_cuda.h"
 #include "tensorflow/core/util/stream_executor_util.h"
#include "tensorflow/core/util/stream_executor_util.h"
@@ -304,10 +307,19 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
   xla::LocalExecutable* executable;
   std::map<int, OptionalTensor> variables;
-  OP_REQUIRES_OK(
-      ctx, CompileToLocalExecutable(ctx, function_, platform_info_, resources_,
-                                    constants_, /*lazy=*/false, &client,
-                                    &variables, &kernel, &executable));
+  {
+    Status s = CompileToLocalExecutable(
+        ctx, function_, platform_info_, resources_, constants_, /*lazy=*/false,
+        &client, &variables, &kernel, &executable);
+    if (!s.ok() && (platform_info_.device_type().type_string() == DEVICE_CPU ||
+                    platform_info_.device_type().type_string() == DEVICE_GPU)) {
+      // Suggest auto jit if the failure was with GPU or CPU.
+      errors::AppendToMessage(&s,
+                              xla::status_macros::kPossibleAutoJitAlternative);
+    }
+    OP_REQUIRES_OK(ctx, s);
+  }
   se::Stream* stream =
       ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
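On the kernel side the Status is enriched in place before OP_REQUIRES_OK consumes it, and only when the kernel is placed on CPU or GPU, where the suggested auto-jit mode is actually available. A toy model of that control flow follows; none of these types are the real TensorFlow ones, and the names are hypothetical.

// Illustrative sketch of the kernel-side flow; not TF code.
#include <iostream>
#include <string>

struct Status {
  bool ok;
  std::string message;
};

struct OpKernelContext {
  Status status{true, ""};  // where OP_REQUIRES_OK records a failure
};

// Sketch of OP_REQUIRES_OK: on error, park the Status on the context and
// abort the current Compute() call.
#define OP_REQUIRES_OK(ctx, expr) \
  do {                            \
    Status _s = (expr);           \
    if (!_s.ok) {                 \
      (ctx)->status = _s;         \
      return;                     \
    }                             \
  } while (0)

Status CompileToExecutable() {
  return {false, "XLA compilation failed for op Bar."};
}

void Compute(OpKernelContext* ctx, const std::string& device_type) {
  Status s = CompileToExecutable();
  // As in the diff above: suggest auto-jit only for CPU/GPU placements.
  if (!s.ok && (device_type == "CPU" || device_type == "GPU")) {
    s.message += " [hint: the non-strict auto-jit may handle this graph]";
  }
  OP_REQUIRES_OK(ctx, s);
  // ... launch the compiled executable (not reached on error) ...
}

int main() {
  OpKernelContext ctx;
  Compute(&ctx, "GPU");
  std::cout << ctx.status.message << "\n";
}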


@@ -150,8 +150,6 @@ cc_library(
    visibility = ["//visibility:public"],
    deps = [
        ":status",
        "//tensorflow/core:lib",
        "//tensorflow/core:lib_internal",
        "//tensorflow/stream_executor/lib",
    ],
)
@@ -194,7 +192,6 @@ cc_library(
        ":types",
        ":util",
        "//tensorflow/core:lib",
        "@com_google_absl//absl/synchronization",
    ],
)
@@ -833,7 +830,6 @@ cc_library(
        "//tensorflow/compiler/xla:xla_proto",
        "//tensorflow/compiler/xla/service:hlo",
        "//tensorflow/core:framework_internal",
        "//tensorflow/core:lib",
        "@com_google_absl//absl/strings",
    ],
)


@@ -25,6 +25,13 @@ limitations under the License.
 namespace xla {
 namespace status_macros {
 
+ABSL_CONST_INIT const char kPossibleAutoJitAlternative[] =
+    "This error might be occurring with the use of xla.compile. If it is not "
+    "necessary that every Op be compiled with XLA, an alternative is to use "
+    "auto_jit with OptimizerOptions.global_jit_level = ON_2 or the environment "
+    "variable TF_XLA_FLAGS=\"tf_xla_auto_jit=2\" which will attempt to use xla "
+    "to compile as much of the graph as the compiler is able to.";
+
 static Status MakeStatus(tensorflow::error::Code code, const string& message) {
   return Status(code, message);
 }
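For reference, the two alternatives this message points at can be set up as follows. This is a hedged sketch assuming a TF 1.x C++ client and a POSIX environment: the ConfigProto fields are real TensorFlow protobuf accessors, while the TF_XLA_FLAGS value is copied verbatim from the message above, and building this requires a TensorFlow source tree.

// Sketch: enabling the non-strict auto-jit suggested by the message.
#include <cstdlib>

#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/public/session_options.h"

// Option 1: set global_jit_level = ON_2 in the session's ConfigProto.
tensorflow::SessionOptions AutoJitSessionOptions() {
  tensorflow::SessionOptions options;
  options.config.mutable_graph_options()
      ->mutable_optimizer_options()
      ->set_global_jit_level(tensorflow::OptimizerOptions::ON_2);
  return options;
}

int main() {
  // Option 2: the environment variable named in the message, set before
  // TensorFlow initializes (value taken verbatim from the message).
  setenv("TF_XLA_FLAGS", "tf_xla_auto_jit=2", /*overwrite=*/1);
  tensorflow::SessionOptions options = AutoJitSessionOptions();
  // ... create a Session with `options` and run the graph ...
}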


@@ -30,6 +30,10 @@ limitations under the License.
 namespace xla {
 namespace status_macros {
 
+// This is a useful error message when encountering XLA Compiler errors that
+// could be handled with the non-strict AutoJit mode.
+extern const char kPossibleAutoJitAlternative[];
+
 // Stream object used to collect error messages in MAKE_ERROR macros
 // or append error messages with APPEND_ERROR. It accepts any
 // arguments with operator<< to build an error string, and then has an