Return an error if the user enabled XLA without linking it in.
PiperOrigin-RevId: 221873554
This commit is contained in:
parent
a460218961
commit
6b5bef9216
@ -45,15 +45,37 @@ cc_library(
|
|||||||
alwayslink = 1,
|
alwayslink = 1,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "register_xla_cpu_jit",
|
||||||
|
srcs = ["register_xla_cpu_jit.cc"],
|
||||||
|
visibility = ["//visibility:public"],
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/core:tf_xla_stub",
|
||||||
|
],
|
||||||
|
alwayslink = 1,
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "xla_cpu_jit",
|
name = "xla_cpu_jit",
|
||||||
visibility = ["//visibility:public"],
|
visibility = ["//visibility:public"],
|
||||||
deps = [
|
deps = [
|
||||||
":jit_compilation_passes",
|
":jit_compilation_passes",
|
||||||
|
":register_xla_cpu_jit",
|
||||||
"//tensorflow/compiler/jit/kernels:xla_ops",
|
"//tensorflow/compiler/jit/kernels:xla_ops",
|
||||||
"//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops",
|
"//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops",
|
||||||
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
|
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
|
||||||
"//tensorflow/compiler/xla/service:cpu_plugin",
|
"//tensorflow/compiler/xla/service:cpu_plugin",
|
||||||
|
"//tensorflow/core:tf_xla_stub",
|
||||||
|
],
|
||||||
|
alwayslink = 1,
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "register_xla_gpu_jit",
|
||||||
|
srcs = ["register_xla_gpu_jit.cc"],
|
||||||
|
visibility = ["//visibility:public"],
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/core:tf_xla_stub",
|
||||||
],
|
],
|
||||||
alwayslink = 1,
|
alwayslink = 1,
|
||||||
)
|
)
|
||||||
@ -63,6 +85,8 @@ cc_library(
|
|||||||
visibility = ["//visibility:public"],
|
visibility = ["//visibility:public"],
|
||||||
deps = if_cuda([
|
deps = if_cuda([
|
||||||
":jit_compilation_passes",
|
":jit_compilation_passes",
|
||||||
|
":register_xla_gpu_jit",
|
||||||
|
"//tensorflow/core:tf_xla_stub",
|
||||||
"//tensorflow/compiler/jit/kernels:xla_ops",
|
"//tensorflow/compiler/jit/kernels:xla_ops",
|
||||||
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
|
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
|
||||||
"//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops",
|
"//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops",
|
||||||
@ -78,6 +102,7 @@ cc_library(
|
|||||||
deps = [
|
deps = [
|
||||||
":flags",
|
":flags",
|
||||||
":jit_compilation_passes",
|
":jit_compilation_passes",
|
||||||
|
":register_xla_cpu_jit",
|
||||||
":xla_device",
|
":xla_device",
|
||||||
"//tensorflow/compiler/jit/kernels:xla_ops",
|
"//tensorflow/compiler/jit/kernels:xla_ops",
|
||||||
"//tensorflow/compiler/tf2xla:xla_compiler",
|
"//tensorflow/compiler/tf2xla:xla_compiler",
|
||||||
@ -96,6 +121,7 @@ cc_library(
|
|||||||
visibility = [":friends"],
|
visibility = [":friends"],
|
||||||
deps = [
|
deps = [
|
||||||
":jit_compilation_passes",
|
":jit_compilation_passes",
|
||||||
|
":register_xla_gpu_jit",
|
||||||
":xla_device",
|
":xla_device",
|
||||||
"//tensorflow/compiler/jit/kernels:xla_ops",
|
"//tensorflow/compiler/jit/kernels:xla_ops",
|
||||||
"//tensorflow/compiler/tf2xla:xla_compiler",
|
"//tensorflow/compiler/tf2xla:xla_compiler",
|
||||||
|
22
tensorflow/compiler/jit/register_xla_cpu_jit.cc
Normal file
22
tensorflow/compiler/jit/register_xla_cpu_jit.cc
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/core/common_runtime/tf_xla_stub.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace {
|
||||||
|
XlaCpuJitIsLinkedIn register_xla_cpu_jit;
|
||||||
|
}
|
||||||
|
} // namespace tensorflow
|
22
tensorflow/compiler/jit/register_xla_gpu_jit.cc
Normal file
22
tensorflow/compiler/jit/register_xla_gpu_jit.cc
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/core/common_runtime/tf_xla_stub.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace {
|
||||||
|
XlaGpuJitIsLinkedIn register_xla_gpu_jit;
|
||||||
|
}
|
||||||
|
} // namespace tensorflow
|
@ -2976,6 +2976,19 @@ tf_cuda_library(
|
|||||||
] + if_static([":core_cpu_impl"]) + tf_protos_all() + tf_protos_grappler(),
|
] + if_static([":core_cpu_impl"]) + tf_protos_all() + tf_protos_grappler(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
tf_cuda_library(
|
||||||
|
name = "tf_xla_stub",
|
||||||
|
srcs = ["common_runtime/tf_xla_stub.cc"],
|
||||||
|
hdrs = ["common_runtime/tf_xla_stub.h"],
|
||||||
|
copts = tf_copts(),
|
||||||
|
visibility = ["//visibility:public"],
|
||||||
|
deps = [
|
||||||
|
":lib",
|
||||||
|
":proto_text",
|
||||||
|
":session_options",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
tf_cuda_library(
|
tf_cuda_library(
|
||||||
name = "core_cpu_internal",
|
name = "core_cpu_internal",
|
||||||
srcs = [
|
srcs = [
|
||||||
@ -2989,6 +3002,7 @@ tf_cuda_library(
|
|||||||
":framework",
|
":framework",
|
||||||
":graph",
|
":graph",
|
||||||
":lib",
|
":lib",
|
||||||
|
":tf_xla_stub",
|
||||||
":proto_text",
|
":proto_text",
|
||||||
":protos_all_cc",
|
":protos_all_cc",
|
||||||
"//tensorflow/core/grappler:grappler_item",
|
"//tensorflow/core/grappler:grappler_item",
|
||||||
|
@ -25,6 +25,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/common_runtime/device.h"
|
#include "tensorflow/core/common_runtime/device.h"
|
||||||
#include "tensorflow/core/common_runtime/optimization_registry.h"
|
#include "tensorflow/core/common_runtime/optimization_registry.h"
|
||||||
#include "tensorflow/core/common_runtime/placer.h"
|
#include "tensorflow/core/common_runtime/placer.h"
|
||||||
|
#include "tensorflow/core/common_runtime/tf_xla_stub.h"
|
||||||
#include "tensorflow/core/framework/graph.pb_text.h"
|
#include "tensorflow/core/framework/graph.pb_text.h"
|
||||||
#include "tensorflow/core/framework/graph_def_util.h"
|
#include "tensorflow/core/framework/graph_def_util.h"
|
||||||
#include "tensorflow/core/framework/node_def.pb.h"
|
#include "tensorflow/core/framework/node_def.pb.h"
|
||||||
@ -720,6 +721,8 @@ Status GraphExecutionState::BuildGraph(const BuildGraphOptions& options,
|
|||||||
CHECK_EQ(options.callable_options.fetch_size(),
|
CHECK_EQ(options.callable_options.fetch_size(),
|
||||||
rewrite_metadata.fetch_types.size());
|
rewrite_metadata.fetch_types.size());
|
||||||
|
|
||||||
|
TF_RETURN_IF_ERROR(CheckXlaJitOptimizerOptions(session_options_));
|
||||||
|
|
||||||
// TODO(andydavis): Clarify optimization pass requirements around CostModel.
|
// TODO(andydavis): Clarify optimization pass requirements around CostModel.
|
||||||
GraphOptimizationPassOptions optimization_options;
|
GraphOptimizationPassOptions optimization_options;
|
||||||
optimization_options.session_options = session_options_;
|
optimization_options.session_options = session_options_;
|
||||||
|
82
tensorflow/core/common_runtime/tf_xla_stub.cc
Normal file
82
tensorflow/core/common_runtime/tf_xla_stub.cc
Normal file
@ -0,0 +1,82 @@
|
|||||||
|
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include <cstdlib>
|
||||||
|
|
||||||
|
#include "tensorflow/core/common_runtime/tf_xla_stub.h"
|
||||||
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
bool is_xla_gpu_jit_registered = false;
|
||||||
|
bool is_xla_cpu_jit_registered = false;
|
||||||
|
|
||||||
|
struct XlaEnvVars {
|
||||||
|
bool xla_flags_env_var_present;
|
||||||
|
bool tf_xla_flags_env_var_present;
|
||||||
|
};
|
||||||
|
|
||||||
|
XlaEnvVars ComputeEnvVarHasXlaFlags() {
|
||||||
|
XlaEnvVars env_vars;
|
||||||
|
env_vars.xla_flags_env_var_present = getenv("XLA_FLAGS") != nullptr;
|
||||||
|
env_vars.tf_xla_flags_env_var_present = getenv("TF_XLA_FLAGS") != nullptr;
|
||||||
|
return env_vars;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
XlaGpuJitIsLinkedIn::XlaGpuJitIsLinkedIn() { is_xla_gpu_jit_registered = true; }
|
||||||
|
XlaCpuJitIsLinkedIn::XlaCpuJitIsLinkedIn() { is_xla_cpu_jit_registered = true; }
|
||||||
|
|
||||||
|
Status CheckXlaJitOptimizerOptions(const SessionOptions* session_options) {
|
||||||
|
static XlaEnvVars env_vars = ComputeEnvVarHasXlaFlags();
|
||||||
|
|
||||||
|
if (is_xla_cpu_jit_registered || is_xla_gpu_jit_registered) {
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (env_vars.xla_flags_env_var_present) {
|
||||||
|
return errors::InvalidArgument(
|
||||||
|
"The XLA JIT is not linked in but the \"XLA_FLAGS\" environment "
|
||||||
|
"variable is set. Please either link in XLA or remove \"XLA_FLAGS\" "
|
||||||
|
"from the environment.");
|
||||||
|
}
|
||||||
|
|
||||||
|
if (env_vars.tf_xla_flags_env_var_present) {
|
||||||
|
return errors::InvalidArgument(
|
||||||
|
"The XLA JIT is not linked in but the \"TF_XLA_FLAGS\" environment "
|
||||||
|
"variable is set. Please either link in XLA or remove "
|
||||||
|
"\"TF_XLA_FLAGS\" from the environment.");
|
||||||
|
}
|
||||||
|
|
||||||
|
if (session_options) {
|
||||||
|
OptimizerOptions::GlobalJitLevel jit_level =
|
||||||
|
session_options->config.graph_options()
|
||||||
|
.optimizer_options()
|
||||||
|
.global_jit_level();
|
||||||
|
|
||||||
|
if (jit_level == OptimizerOptions::ON_1 ||
|
||||||
|
jit_level == OptimizerOptions::ON_2) {
|
||||||
|
return errors::InvalidArgument(
|
||||||
|
"The XLA JIT is enabled in the session options but XLA is not linked "
|
||||||
|
"in. Plesae either link in XLA or disable the JIT in the session "
|
||||||
|
"options.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
} // namespace tensorflow
|
50
tensorflow/core/common_runtime/tf_xla_stub.h
Normal file
50
tensorflow/core/common_runtime/tf_xla_stub.h
Normal file
@ -0,0 +1,50 @@
|
|||||||
|
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_TF_XLA_STUB_H_
|
||||||
|
#define TENSORFLOW_CORE_COMMON_RUNTIME_TF_XLA_STUB_H_
|
||||||
|
|
||||||
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
|
#include "tensorflow/core/public/session_options.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
// Returns an error if the XLA JIT is enabled via `session_options` or if the
|
||||||
|
// TF_XLA_FLAGS or XLA_FLAGS environment variables are set, but neither the
|
||||||
|
// XLA CPU JIT nor the XLA GPU JIT are linked in.
|
||||||
|
//
|
||||||
|
// If `session_options` is null then only the environment variables are checked.
|
||||||
|
Status CheckXlaJitOptimizerOptions(const SessionOptions* session_options);
|
||||||
|
|
||||||
|
// The XLA CPU JIT creates a static instance of this class to notify
|
||||||
|
// `CheckXlaJitOptimizerOptions` that the XLA CPU JIT is linked in.
|
||||||
|
//
|
||||||
|
// NB! The constructor of this class (if run at all) needs to be ordered (via
|
||||||
|
// happens before) before any call to `CheckXlaJitOptimizerOptions`.
|
||||||
|
class XlaCpuJitIsLinkedIn {
|
||||||
|
public:
|
||||||
|
XlaCpuJitIsLinkedIn();
|
||||||
|
};
|
||||||
|
|
||||||
|
// The XLA GPU JIT creates a static instance of this class to notify
|
||||||
|
// `CheckXlaJitOptimizerOptions` that the XLA GPU JIT is linked in.
|
||||||
|
//
|
||||||
|
// NB! The constructor of this class (if run at all) needs to be ordered (via
|
||||||
|
// happens before) before any call to `CheckXlaJitOptimizerOptions`.
|
||||||
|
class XlaGpuJitIsLinkedIn {
|
||||||
|
public:
|
||||||
|
XlaGpuJitIsLinkedIn();
|
||||||
|
};
|
||||||
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_CORE_COMMON_RUNTIME_TF_XLA_STUB_H_
|
Loading…
Reference in New Issue
Block a user