[TF:XLA] Add error message about the non-strict XLA jit.
PiperOrigin-RevId: 232836591
This commit is contained in:
parent
42d8403ace
commit
dde1ed7309
@ -282,7 +282,6 @@ cc_library(
|
||||
"//tensorflow/compiler/tf2xla:common",
|
||||
"//tensorflow/compiler/tf2xla:dump_graph",
|
||||
"//tensorflow/compiler/tf2xla:xla_compiler",
|
||||
"//tensorflow/compiler/xla:debug_options_flags",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla/client:client_library",
|
||||
"//tensorflow/compiler/xla/client:local_client",
|
||||
@ -465,7 +464,6 @@ cc_library(
|
||||
"//tensorflow/compiler/tf2xla:tf2xla_util",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:graph",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
|
@ -15,13 +15,17 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/compiler/jit/encapsulate_xla_computations_pass.h"
|
||||
|
||||
#include "absl/algorithm/container.h"
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
#include "absl/memory/memory.h"
|
||||
#include "absl/strings/ascii.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
|
||||
#include "tensorflow/compiler/tf2xla/dump_graph.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/lib/core/stringpiece.h"
|
||||
#include "tensorflow/core/lib/hash/hash.h"
|
||||
#include "tensorflow/core/lib/strings/proto_serialization.h"
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
@ -36,6 +40,25 @@ namespace {
|
||||
|
||||
const char* const kXlaClusterOutput = "XlaClusterOutput";
|
||||
|
||||
// Reports whether the nodes marked for compilation in `graph` target only CPU
// or GPU devices.
//
// Nodes that do not carry the kXlaClusterAttr attribute are skipped. Returns
// false as soon as a marked node has a parseable device name whose type is
// neither DEVICE_CPU nor DEVICE_GPU; returns true otherwise (including when no
// node carries the attribute).
bool IsCpuGpuCompile(const Graph* graph) {
|
||||
for (Node* n : graph->nodes()) {
|
||||
string name;
|
||||
// Only consider nodes being compiled.
|
||||
if (!GetNodeAttr(n->attrs(),
|
||||
EncapsulateXlaComputationsPass::kXlaClusterAttr, &name)
|
||||
.ok())
|
||||
continue;
|
||||
// Early return for any node with a device that is not a CPU or GPU.
|
||||
DeviceNameUtils::ParsedName parsed;
|
||||
// A device string that fails to parse does not disqualify the graph; the
// node is simply passed over.
if (DeviceNameUtils::ParseFullName(n->def().device(), &parsed)) {
|
||||
if (parsed.type != DEVICE_CPU && parsed.type != DEVICE_GPU) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// Checks if a graph node is marked to be a guaranteed constant.
|
||||
bool is_guaranteed_constant(const Node& n) {
|
||||
bool guaranteed_constant = false;
|
||||
@ -352,12 +375,19 @@ Status EncapsulateXlaComputationsPass::Run(
|
||||
<< dump_graph::DumpGraphToFile("encapsulate_xla_computations_before",
|
||||
**options.graph, options.flib_def);
|
||||
|
||||
TF_RETURN_IF_ERROR(Encapsulate(options.graph, options.flib_def));
|
||||
const char* additional_help =
|
||||
IsCpuGpuCompile(options.graph->get())
|
||||
? xla::status_macros::kPossibleAutoJitAlternative
|
||||
: "";
|
||||
|
||||
TF_RETURN_WITH_CONTEXT_IF_ERROR(Encapsulate(options.graph, options.flib_def),
|
||||
additional_help);
|
||||
VLOG(1) << "EncapsulateXlaComputations() half-way: "
|
||||
<< dump_graph::DumpGraphToFile("encapsulate_xla_computations_halfway",
|
||||
**options.graph, options.flib_def);
|
||||
|
||||
TF_RETURN_IF_ERROR(BuildXlaLaunchOps(options.graph->get()));
|
||||
TF_RETURN_WITH_CONTEXT_IF_ERROR(BuildXlaLaunchOps(options.graph->get()),
|
||||
additional_help);
|
||||
VLOG(1) << "EncapsulateXlaComputations() finished: "
|
||||
<< dump_graph::DumpGraphToFile("encapsulate_xla_computations_after",
|
||||
**options.graph, options.flib_def);
|
||||
|
@ -19,6 +19,7 @@ cc_library(
|
||||
"//tensorflow/compiler/tf2xla:common",
|
||||
"//tensorflow/compiler/tf2xla:tf2xla_util",
|
||||
"//tensorflow/compiler/tf2xla:xla_compiler",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla/client:client_library",
|
||||
"//tensorflow/compiler/xla/client:local_client",
|
||||
|
@ -25,6 +25,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
#include "tensorflow/compiler/xla/client/client_library.h"
|
||||
#include "tensorflow/compiler/xla/client/local_client.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
#include "tensorflow/core/common_runtime/dma_helper.h"
|
||||
#include "tensorflow/core/common_runtime/function.h"
|
||||
@ -35,6 +36,8 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/kernels/variable_ops.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
|
||||
#include "tensorflow/core/util/stream_executor_util.h"
|
||||
@ -304,10 +307,19 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
|
||||
xla::LocalExecutable* executable;
|
||||
std::map<int, OptionalTensor> variables;
|
||||
|
||||
OP_REQUIRES_OK(
|
||||
ctx, CompileToLocalExecutable(ctx, function_, platform_info_, resources_,
|
||||
constants_, /*lazy=*/false, &client,
|
||||
&variables, &kernel, &executable));
|
||||
{
|
||||
Status s = CompileToLocalExecutable(
|
||||
ctx, function_, platform_info_, resources_, constants_, /*lazy=*/false,
|
||||
&client, &variables, &kernel, &executable);
|
||||
if (!s.ok() && (platform_info_.device_type().type_string() == DEVICE_CPU ||
|
||||
platform_info_.device_type().type_string() == DEVICE_GPU)) {
|
||||
// Suggest auto jit if the failure was with GPU or CPU.
|
||||
errors::AppendToMessage(&s,
|
||||
xla::status_macros::kPossibleAutoJitAlternative);
|
||||
}
|
||||
|
||||
OP_REQUIRES_OK(ctx, s);
|
||||
}
|
||||
|
||||
se::Stream* stream =
|
||||
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
|
||||
|
@ -150,8 +150,6 @@ cc_library(
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":status",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/stream_executor/lib",
|
||||
],
|
||||
)
|
||||
@ -194,7 +192,6 @@ cc_library(
|
||||
":types",
|
||||
":util",
|
||||
"//tensorflow/core:lib",
|
||||
"@com_google_absl//absl/synchronization",
|
||||
],
|
||||
)
|
||||
|
||||
@ -833,7 +830,6 @@ cc_library(
|
||||
"//tensorflow/compiler/xla:xla_proto",
|
||||
"//tensorflow/compiler/xla/service:hlo",
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core:lib",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
@ -25,6 +25,13 @@ limitations under the License.
|
||||
namespace xla {
|
||||
namespace status_macros {
|
||||
|
||||
// Help text suggesting auto_jit (via OptimizerOptions.global_jit_level or the
// TF_XLA_FLAGS environment variable) as a less strict alternative when an
// error may come from the use of xla.compile.
ABSL_CONST_INIT const char kPossibleAutoJitAlternative[] =
|
||||
"This error might be occurring with the use of xla.compile. If it is not "
|
||||
"necessary that every Op be compiled with XLA, an alternative is to use "
|
||||
"auto_jit with OptimizerOptions.global_jit_level = ON_2 or the environment "
|
||||
"variable TF_XLA_FLAGS=\"tf_xla_auto_jit=2\" which will attempt to use xla "
|
||||
"to compile as much of the graph as the compiler is able to.";
|
||||
|
||||
// Builds a Status from the given error `code` and `message`.
static Status MakeStatus(tensorflow::error::Code code, const string& message) {
|
||||
return Status(code, message);
|
||||
}
|
||||
|
@ -30,6 +30,10 @@ limitations under the License.
|
||||
namespace xla {
|
||||
namespace status_macros {
|
||||
|
||||
// This is a useful error message when encountering XLA Compiler errors that
|
||||
// could be handled with the non-strict AutoJit mode.
|
||||
extern const char kPossibleAutoJitAlternative[];
|
||||
|
||||
// Stream object used to collect error messages in MAKE_ERROR macros
|
||||
// or append error messages with APPEND_ERROR. It accepts any
|
||||
// arguments with operator<< to build an error string, and then has an
|
||||
|
Loading…
Reference in New Issue
Block a user