Return an error if the user enabled XLA without linking it in.
PiperOrigin-RevId: 221873554
Parent: a460218961 · Commit: 6b5bef9216
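
For context, a minimal sketch of the configuration this change guards against, using the standard C++ Session API (the snippet is illustrative and not part of this commit): a binary that requests the XLA JIT through its session options but does not depend on //tensorflow/compiler/jit:xla_cpu_jit or :xla_gpu_jit now gets an InvalidArgument error from graph building instead of silently running without XLA.

#include <memory>

#include "tensorflow/core/public/session.h"
#include "tensorflow/core/public/session_options.h"

int main() {
  tensorflow::SessionOptions options;
  // Request the XLA JIT via the session options ...
  options.config.mutable_graph_options()
      ->mutable_optimizer_options()
      ->set_global_jit_level(tensorflow::OptimizerOptions::ON_1);
  // ... in a binary that does not link in the XLA CPU or GPU JIT.
  std::unique_ptr<tensorflow::Session> session(
      tensorflow::NewSession(options));
  // Running a graph eventually reaches GraphExecutionState::BuildGraph, which
  // now fails with: "The XLA JIT is enabled in the session options but XLA is
  // not linked in. Please either link in XLA or disable the JIT in the
  // session options."
  return 0;
}
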
@@ -45,15 +45,37 @@ cc_library(
    alwayslink = 1,
)

cc_library(
    name = "register_xla_cpu_jit",
    srcs = ["register_xla_cpu_jit.cc"],
    visibility = ["//visibility:public"],
    deps = [
        "//tensorflow/core:tf_xla_stub",
    ],
    alwayslink = 1,
)

cc_library(
    name = "xla_cpu_jit",
    visibility = ["//visibility:public"],
    deps = [
        ":jit_compilation_passes",
        ":register_xla_cpu_jit",
        "//tensorflow/compiler/jit/kernels:xla_ops",
        "//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops",
        "//tensorflow/compiler/tf2xla/kernels:xla_ops",
        "//tensorflow/compiler/xla/service:cpu_plugin",
        "//tensorflow/core:tf_xla_stub",
    ],
    alwayslink = 1,
)

cc_library(
    name = "register_xla_gpu_jit",
    srcs = ["register_xla_gpu_jit.cc"],
    visibility = ["//visibility:public"],
    deps = [
        "//tensorflow/core:tf_xla_stub",
    ],
    alwayslink = 1,
)
@@ -63,6 +85,8 @@ cc_library(
    visibility = ["//visibility:public"],
    deps = if_cuda([
        ":jit_compilation_passes",
        ":register_xla_gpu_jit",
        "//tensorflow/core:tf_xla_stub",
        "//tensorflow/compiler/jit/kernels:xla_ops",
        "//tensorflow/compiler/tf2xla/kernels:xla_ops",
        "//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops",
@@ -78,6 +102,7 @@ cc_library(
    deps = [
        ":flags",
        ":jit_compilation_passes",
        ":register_xla_cpu_jit",
        ":xla_device",
        "//tensorflow/compiler/jit/kernels:xla_ops",
        "//tensorflow/compiler/tf2xla:xla_compiler",
@@ -96,6 +121,7 @@ cc_library(
    visibility = [":friends"],
    deps = [
        ":jit_compilation_passes",
        ":register_xla_gpu_jit",
        ":xla_device",
        "//tensorflow/compiler/jit/kernels:xla_ops",
        "//tensorflow/compiler/tf2xla:xla_compiler",

tensorflow/compiler/jit/register_xla_cpu_jit.cc (new file, 22 lines)
@@ -0,0 +1,22 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/common_runtime/tf_xla_stub.h"

namespace tensorflow {
namespace {
XlaCpuJitIsLinkedIn register_xla_cpu_jit;
}
} // namespace tensorflow

tensorflow/compiler/jit/register_xla_gpu_jit.cc (new file, 22 lines)
@@ -0,0 +1,22 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/common_runtime/tf_xla_stub.h"

namespace tensorflow {
namespace {
XlaGpuJitIsLinkedIn register_xla_gpu_jit;
}
|

@@ -2976,6 +2976,19 @@ tf_cuda_library(
    ] + if_static([":core_cpu_impl"]) + tf_protos_all() + tf_protos_grappler(),
)

tf_cuda_library(
    name = "tf_xla_stub",
    srcs = ["common_runtime/tf_xla_stub.cc"],
    hdrs = ["common_runtime/tf_xla_stub.h"],
    copts = tf_copts(),
    visibility = ["//visibility:public"],
    deps = [
        ":lib",
        ":proto_text",
        ":session_options",
    ],
)

tf_cuda_library(
    name = "core_cpu_internal",
    srcs = [
@@ -2989,6 +3002,7 @@ tf_cuda_library(
        ":framework",
        ":graph",
        ":lib",
        ":tf_xla_stub",
        ":proto_text",
        ":protos_all_cc",
        "//tensorflow/core/grappler:grappler_item",

@@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/common_runtime/placer.h"
#include "tensorflow/core/common_runtime/tf_xla_stub.h"
#include "tensorflow/core/framework/graph.pb_text.h"
#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/framework/node_def.pb.h"
@@ -720,6 +721,8 @@ Status GraphExecutionState::BuildGraph(const BuildGraphOptions& options,
  CHECK_EQ(options.callable_options.fetch_size(),
           rewrite_metadata.fetch_types.size());

  TF_RETURN_IF_ERROR(CheckXlaJitOptimizerOptions(session_options_));

  // TODO(andydavis): Clarify optimization pass requirements around CostModel.
  GraphOptimizationPassOptions optimization_options;
  optimization_options.session_options = session_options_;

tensorflow/core/common_runtime/tf_xla_stub.cc (new file, 82 lines)
@@ -0,0 +1,82 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cstdlib>

#include "tensorflow/core/common_runtime/tf_xla_stub.h"
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {
namespace {

bool is_xla_gpu_jit_registered = false;
bool is_xla_cpu_jit_registered = false;

struct XlaEnvVars {
  bool xla_flags_env_var_present;
  bool tf_xla_flags_env_var_present;
};

XlaEnvVars ComputeEnvVarHasXlaFlags() {
  XlaEnvVars env_vars;
  env_vars.xla_flags_env_var_present = getenv("XLA_FLAGS") != nullptr;
  env_vars.tf_xla_flags_env_var_present = getenv("TF_XLA_FLAGS") != nullptr;
  return env_vars;
}

} // namespace

XlaGpuJitIsLinkedIn::XlaGpuJitIsLinkedIn() { is_xla_gpu_jit_registered = true; }
XlaCpuJitIsLinkedIn::XlaCpuJitIsLinkedIn() { is_xla_cpu_jit_registered = true; }

Status CheckXlaJitOptimizerOptions(const SessionOptions* session_options) {
  static XlaEnvVars env_vars = ComputeEnvVarHasXlaFlags();

  if (is_xla_cpu_jit_registered || is_xla_gpu_jit_registered) {
    return Status::OK();
  }

  if (env_vars.xla_flags_env_var_present) {
    return errors::InvalidArgument(
        "The XLA JIT is not linked in but the \"XLA_FLAGS\" environment "
        "variable is set. Please either link in XLA or remove \"XLA_FLAGS\" "
        "from the environment.");
  }

  if (env_vars.tf_xla_flags_env_var_present) {
    return errors::InvalidArgument(
        "The XLA JIT is not linked in but the \"TF_XLA_FLAGS\" environment "
        "variable is set. Please either link in XLA or remove "
        "\"TF_XLA_FLAGS\" from the environment.");
  }

  if (session_options) {
    OptimizerOptions::GlobalJitLevel jit_level =
        session_options->config.graph_options()
            .optimizer_options()
            .global_jit_level();

    if (jit_level == OptimizerOptions::ON_1 ||
        jit_level == OptimizerOptions::ON_2) {
      return errors::InvalidArgument(
          "The XLA JIT is enabled in the session options but XLA is not linked "
"in. Plesae either link in XLA or disable the JIT in the session "
|
||||
"options.");
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace tensorflow
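
An illustrative usage sketch of the new check (the call site and the TF_XLA_FLAGS value are assumptions, not from this commit): the environment-variable path reports an error even when the session options leave the JIT off, provided neither JIT registration object was linked in.

#include <cstdlib>

#include "tensorflow/core/common_runtime/tf_xla_stub.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/public/session_options.h"

int main() {
  // Simulate a user who sets TF_XLA_FLAGS in a binary without the XLA JIT.
  // The value is irrelevant; the check only tests whether the variable is set.
  setenv("TF_XLA_FLAGS", "--tf_xla_auto_jit=2", /*overwrite=*/1);

  tensorflow::SessionOptions options;
  tensorflow::Status s = tensorflow::CheckXlaJitOptimizerOptions(&options);
  // Without :xla_cpu_jit or :xla_gpu_jit linked in, `s` is an InvalidArgument
  // asking the user to either link in XLA or unset TF_XLA_FLAGS.
  LOG(INFO) << s;
  return 0;
}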

tensorflow/core/common_runtime/tf_xla_stub.h (new file, 50 lines)
@@ -0,0 +1,50 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_TF_XLA_STUB_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_TF_XLA_STUB_H_

#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/public/session_options.h"

namespace tensorflow {
// Returns an error if the XLA JIT is enabled via `session_options` or if the
// TF_XLA_FLAGS or XLA_FLAGS environment variables are set, but neither the
// XLA CPU JIT nor the XLA GPU JIT are linked in.
//
// If `session_options` is null then only the environment variables are checked.
Status CheckXlaJitOptimizerOptions(const SessionOptions* session_options);

// The XLA CPU JIT creates a static instance of this class to notify
// `CheckXlaJitOptimizerOptions` that the XLA CPU JIT is linked in.
//
// NB! The constructor of this class (if run at all) needs to be ordered (via
// happens before) before any call to `CheckXlaJitOptimizerOptions`.
class XlaCpuJitIsLinkedIn {
 public:
  XlaCpuJitIsLinkedIn();
};

// The XLA GPU JIT creates a static instance of this class to notify
// `CheckXlaJitOptimizerOptions` that the XLA GPU JIT is linked in.
//
// NB! The constructor of this class (if run at all) needs to be ordered (via
// happens before) before any call to `CheckXlaJitOptimizerOptions`.
class XlaGpuJitIsLinkedIn {
 public:
  XlaGpuJitIsLinkedIn();
};
} // namespace tensorflow

#endif // TENSORFLOW_CORE_COMMON_RUNTIME_TF_XLA_STUB_H_
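
Finally, a sketch of a test one might write against this header (the test name is hypothetical and the test is not part of this commit): with no JIT registration object linked in and no XLA environment variables set, the check is a no-op both for default session options and for a null pointer.

#include "tensorflow/core/common_runtime/tf_xla_stub.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/public/session_options.h"

TEST(TfXlaStubTest, NoXlaRequestedNoError) {
  // Default options leave global_jit_level at its DEFAULT value, so neither
  // the session-options branch nor the environment-variable branches fire.
  tensorflow::SessionOptions options;
  TF_EXPECT_OK(tensorflow::CheckXlaJitOptimizerOptions(&options));
  // A null pointer skips the session-options check entirely.
  TF_EXPECT_OK(tensorflow::CheckXlaJitOptimizerOptions(nullptr));
}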