parent
209c7dae93
commit
0c610a2be5
@ -45,37 +45,15 @@ cc_library(
|
|||||||
alwayslink = 1,
|
alwayslink = 1,
|
||||||
)
|
)
|
||||||
|
|
||||||
cc_library(
|
|
||||||
name = "register_xla_cpu_jit",
|
|
||||||
srcs = ["register_xla_cpu_jit.cc"],
|
|
||||||
visibility = ["//visibility:public"],
|
|
||||||
deps = [
|
|
||||||
"//tensorflow/core:tf_xla_stub",
|
|
||||||
],
|
|
||||||
alwayslink = 1,
|
|
||||||
)
|
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "xla_cpu_jit",
|
name = "xla_cpu_jit",
|
||||||
visibility = ["//visibility:public"],
|
visibility = ["//visibility:public"],
|
||||||
deps = [
|
deps = [
|
||||||
":jit_compilation_passes",
|
":jit_compilation_passes",
|
||||||
":register_xla_cpu_jit",
|
|
||||||
"//tensorflow/compiler/jit/kernels:xla_ops",
|
"//tensorflow/compiler/jit/kernels:xla_ops",
|
||||||
"//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops",
|
"//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops",
|
||||||
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
|
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
|
||||||
"//tensorflow/compiler/xla/service:cpu_plugin",
|
"//tensorflow/compiler/xla/service:cpu_plugin",
|
||||||
"//tensorflow/core:tf_xla_stub",
|
|
||||||
],
|
|
||||||
alwayslink = 1,
|
|
||||||
)
|
|
||||||
|
|
||||||
cc_library(
|
|
||||||
name = "register_xla_gpu_jit",
|
|
||||||
srcs = ["register_xla_gpu_jit.cc"],
|
|
||||||
visibility = ["//visibility:public"],
|
|
||||||
deps = [
|
|
||||||
"//tensorflow/core:tf_xla_stub",
|
|
||||||
],
|
],
|
||||||
alwayslink = 1,
|
alwayslink = 1,
|
||||||
)
|
)
|
||||||
@ -85,8 +63,6 @@ cc_library(
|
|||||||
visibility = ["//visibility:public"],
|
visibility = ["//visibility:public"],
|
||||||
deps = if_cuda([
|
deps = if_cuda([
|
||||||
":jit_compilation_passes",
|
":jit_compilation_passes",
|
||||||
":register_xla_gpu_jit",
|
|
||||||
"//tensorflow/core:tf_xla_stub",
|
|
||||||
"//tensorflow/compiler/jit/kernels:xla_ops",
|
"//tensorflow/compiler/jit/kernels:xla_ops",
|
||||||
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
|
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
|
||||||
"//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops",
|
"//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops",
|
||||||
@ -102,7 +78,6 @@ cc_library(
|
|||||||
deps = [
|
deps = [
|
||||||
":flags",
|
":flags",
|
||||||
":jit_compilation_passes",
|
":jit_compilation_passes",
|
||||||
":register_xla_cpu_jit",
|
|
||||||
":xla_device",
|
":xla_device",
|
||||||
"//tensorflow/compiler/jit/kernels:xla_ops",
|
"//tensorflow/compiler/jit/kernels:xla_ops",
|
||||||
"//tensorflow/compiler/tf2xla:xla_compiler",
|
"//tensorflow/compiler/tf2xla:xla_compiler",
|
||||||
@ -121,7 +96,6 @@ cc_library(
|
|||||||
visibility = [":friends"],
|
visibility = [":friends"],
|
||||||
deps = [
|
deps = [
|
||||||
":jit_compilation_passes",
|
":jit_compilation_passes",
|
||||||
":register_xla_gpu_jit",
|
|
||||||
":xla_device",
|
":xla_device",
|
||||||
"//tensorflow/compiler/jit/kernels:xla_ops",
|
"//tensorflow/compiler/jit/kernels:xla_ops",
|
||||||
"//tensorflow/compiler/tf2xla:xla_compiler",
|
"//tensorflow/compiler/tf2xla:xla_compiler",
|
||||||
|
@ -1,22 +0,0 @@
|
|||||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
==============================================================================*/
|
|
||||||
|
|
||||||
#include "tensorflow/core/common_runtime/tf_xla_stub.h"
|
|
||||||
|
|
||||||
namespace tensorflow {
|
|
||||||
namespace {
|
|
||||||
XlaCpuJitIsLinkedIn register_xla_cpu_jit;
|
|
||||||
}
|
|
||||||
} // namespace tensorflow
|
|
@ -1,22 +0,0 @@
|
|||||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
==============================================================================*/
|
|
||||||
|
|
||||||
#include "tensorflow/core/common_runtime/tf_xla_stub.h"
|
|
||||||
|
|
||||||
namespace tensorflow {
|
|
||||||
namespace {
|
|
||||||
XlaGpuJitIsLinkedIn register_xla_gpu_jit;
|
|
||||||
}
|
|
||||||
} // namespace tensorflow
|
|
@ -2976,19 +2976,6 @@ tf_cuda_library(
|
|||||||
] + if_static([":core_cpu_impl"]) + tf_protos_all() + tf_protos_grappler(),
|
] + if_static([":core_cpu_impl"]) + tf_protos_all() + tf_protos_grappler(),
|
||||||
)
|
)
|
||||||
|
|
||||||
tf_cuda_library(
|
|
||||||
name = "tf_xla_stub",
|
|
||||||
srcs = ["common_runtime/tf_xla_stub.cc"],
|
|
||||||
hdrs = ["common_runtime/tf_xla_stub.h"],
|
|
||||||
copts = tf_copts(),
|
|
||||||
visibility = ["//visibility:public"],
|
|
||||||
deps = [
|
|
||||||
":lib",
|
|
||||||
":proto_text",
|
|
||||||
":session_options",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
tf_cuda_library(
|
tf_cuda_library(
|
||||||
name = "core_cpu_internal",
|
name = "core_cpu_internal",
|
||||||
srcs = [
|
srcs = [
|
||||||
@ -3002,7 +2989,6 @@ tf_cuda_library(
|
|||||||
":framework",
|
":framework",
|
||||||
":graph",
|
":graph",
|
||||||
":lib",
|
":lib",
|
||||||
":tf_xla_stub",
|
|
||||||
":proto_text",
|
":proto_text",
|
||||||
":protos_all_cc",
|
":protos_all_cc",
|
||||||
"//tensorflow/core/grappler:grappler_item",
|
"//tensorflow/core/grappler:grappler_item",
|
||||||
|
@ -25,7 +25,6 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/common_runtime/device.h"
|
#include "tensorflow/core/common_runtime/device.h"
|
||||||
#include "tensorflow/core/common_runtime/optimization_registry.h"
|
#include "tensorflow/core/common_runtime/optimization_registry.h"
|
||||||
#include "tensorflow/core/common_runtime/placer.h"
|
#include "tensorflow/core/common_runtime/placer.h"
|
||||||
#include "tensorflow/core/common_runtime/tf_xla_stub.h"
|
|
||||||
#include "tensorflow/core/framework/graph.pb_text.h"
|
#include "tensorflow/core/framework/graph.pb_text.h"
|
||||||
#include "tensorflow/core/framework/graph_def_util.h"
|
#include "tensorflow/core/framework/graph_def_util.h"
|
||||||
#include "tensorflow/core/framework/node_def.pb.h"
|
#include "tensorflow/core/framework/node_def.pb.h"
|
||||||
@ -721,8 +720,6 @@ Status GraphExecutionState::BuildGraph(const BuildGraphOptions& options,
|
|||||||
CHECK_EQ(options.callable_options.fetch_size(),
|
CHECK_EQ(options.callable_options.fetch_size(),
|
||||||
rewrite_metadata.fetch_types.size());
|
rewrite_metadata.fetch_types.size());
|
||||||
|
|
||||||
TF_RETURN_IF_ERROR(CheckXlaJitOptimizerOptions(session_options_));
|
|
||||||
|
|
||||||
// TODO(andydavis): Clarify optimization pass requirements around CostModel.
|
// TODO(andydavis): Clarify optimization pass requirements around CostModel.
|
||||||
GraphOptimizationPassOptions optimization_options;
|
GraphOptimizationPassOptions optimization_options;
|
||||||
optimization_options.session_options = session_options_;
|
optimization_options.session_options = session_options_;
|
||||||
|
@ -1,82 +0,0 @@
|
|||||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
==============================================================================*/
|
|
||||||
|
|
||||||
#include <cstdlib>
|
|
||||||
|
|
||||||
#include "tensorflow/core/common_runtime/tf_xla_stub.h"
|
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
|
||||||
|
|
||||||
namespace tensorflow {
|
|
||||||
namespace {
|
|
||||||
|
|
||||||
bool is_xla_gpu_jit_registered = false;
|
|
||||||
bool is_xla_cpu_jit_registered = false;
|
|
||||||
|
|
||||||
struct XlaEnvVars {
|
|
||||||
bool xla_flags_env_var_present;
|
|
||||||
bool tf_xla_flags_env_var_present;
|
|
||||||
};
|
|
||||||
|
|
||||||
XlaEnvVars ComputeEnvVarHasXlaFlags() {
|
|
||||||
XlaEnvVars env_vars;
|
|
||||||
env_vars.xla_flags_env_var_present = getenv("XLA_FLAGS") != nullptr;
|
|
||||||
env_vars.tf_xla_flags_env_var_present = getenv("TF_XLA_FLAGS") != nullptr;
|
|
||||||
return env_vars;
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
XlaGpuJitIsLinkedIn::XlaGpuJitIsLinkedIn() { is_xla_gpu_jit_registered = true; }
|
|
||||||
XlaCpuJitIsLinkedIn::XlaCpuJitIsLinkedIn() { is_xla_cpu_jit_registered = true; }
|
|
||||||
|
|
||||||
Status CheckXlaJitOptimizerOptions(const SessionOptions* session_options) {
|
|
||||||
static XlaEnvVars env_vars = ComputeEnvVarHasXlaFlags();
|
|
||||||
|
|
||||||
if (is_xla_cpu_jit_registered || is_xla_gpu_jit_registered) {
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
|
|
||||||
if (env_vars.xla_flags_env_var_present) {
|
|
||||||
return errors::InvalidArgument(
|
|
||||||
"The XLA JIT is not linked in but the \"XLA_FLAGS\" environment "
|
|
||||||
"variable is set. Please either link in XLA or remove \"XLA_FLAGS\" "
|
|
||||||
"from the environment.");
|
|
||||||
}
|
|
||||||
|
|
||||||
if (env_vars.tf_xla_flags_env_var_present) {
|
|
||||||
return errors::InvalidArgument(
|
|
||||||
"The XLA JIT is not linked in but the \"TF_XLA_FLAGS\" environment "
|
|
||||||
"variable is set. Please either link in XLA or remove "
|
|
||||||
"\"TF_XLA_FLAGS\" from the environment.");
|
|
||||||
}
|
|
||||||
|
|
||||||
if (session_options) {
|
|
||||||
OptimizerOptions::GlobalJitLevel jit_level =
|
|
||||||
session_options->config.graph_options()
|
|
||||||
.optimizer_options()
|
|
||||||
.global_jit_level();
|
|
||||||
|
|
||||||
if (jit_level == OptimizerOptions::ON_1 ||
|
|
||||||
jit_level == OptimizerOptions::ON_2) {
|
|
||||||
return errors::InvalidArgument(
|
|
||||||
"The XLA JIT is enabled in the session options but XLA is not linked "
|
|
||||||
"in. Plesae either link in XLA or disable the JIT in the session "
|
|
||||||
"options.");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
} // namespace tensorflow
|
|
@ -1,50 +0,0 @@
|
|||||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
==============================================================================*/
|
|
||||||
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_TF_XLA_STUB_H_
|
|
||||||
#define TENSORFLOW_CORE_COMMON_RUNTIME_TF_XLA_STUB_H_
|
|
||||||
|
|
||||||
#include "tensorflow/core/lib/core/status.h"
|
|
||||||
#include "tensorflow/core/public/session_options.h"
|
|
||||||
|
|
||||||
namespace tensorflow {
|
|
||||||
// Returns an error if the XLA JIT is enabled via `session_options` or if the
|
|
||||||
// TF_XLA_FLAGS or XLA_FLAGS environment variables are set, but neither the
|
|
||||||
// XLA CPU JIT nor the XLA GPU JIT are linked in.
|
|
||||||
//
|
|
||||||
// If `session_options` is null then only the environment variables are checked.
|
|
||||||
Status CheckXlaJitOptimizerOptions(const SessionOptions* session_options);
|
|
||||||
|
|
||||||
// The XLA CPU JIT creates a static instance of this class to notify
|
|
||||||
// `CheckXlaJitOptimizerOptions` that the XLA CPU JIT is linked in.
|
|
||||||
//
|
|
||||||
// NB! The constructor of this class (if run at all) needs to be ordered (via
|
|
||||||
// happens before) before any call to `CheckXlaJitOptimizerOptions`.
|
|
||||||
class XlaCpuJitIsLinkedIn {
|
|
||||||
public:
|
|
||||||
XlaCpuJitIsLinkedIn();
|
|
||||||
};
|
|
||||||
|
|
||||||
// The XLA GPU JIT creates a static instance of this class to notify
|
|
||||||
// `CheckXlaJitOptimizerOptions` that the XLA GPU JIT is linked in.
|
|
||||||
//
|
|
||||||
// NB! The constructor of this class (if run at all) needs to be ordered (via
|
|
||||||
// happens before) before any call to `CheckXlaJitOptimizerOptions`.
|
|
||||||
class XlaGpuJitIsLinkedIn {
|
|
||||||
public:
|
|
||||||
XlaGpuJitIsLinkedIn();
|
|
||||||
};
|
|
||||||
} // namespace tensorflow
|
|
||||||
|
|
||||||
#endif // TENSORFLOW_CORE_COMMON_RUNTIME_TF_XLA_STUB_H_
|
|
Loading…
Reference in New Issue
Block a user