diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index 682c0f0cb05..bfc0b0070ee 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -45,15 +45,37 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "register_xla_cpu_jit", + srcs = ["register_xla_cpu_jit.cc"], + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/core:tf_xla_stub", + ], + alwayslink = 1, +) + cc_library( name = "xla_cpu_jit", visibility = ["//visibility:public"], deps = [ ":jit_compilation_passes", + ":register_xla_cpu_jit", "//tensorflow/compiler/jit/kernels:xla_ops", "//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops", "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/compiler/xla/service:cpu_plugin", + "//tensorflow/core:tf_xla_stub", + ], + alwayslink = 1, +) + +cc_library( + name = "register_xla_gpu_jit", + srcs = ["register_xla_gpu_jit.cc"], + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/core:tf_xla_stub", ], alwayslink = 1, ) @@ -63,6 +85,8 @@ cc_library( visibility = ["//visibility:public"], deps = if_cuda([ ":jit_compilation_passes", + ":register_xla_gpu_jit", + "//tensorflow/core:tf_xla_stub", "//tensorflow/compiler/jit/kernels:xla_ops", "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops", @@ -78,6 +102,7 @@ cc_library( deps = [ ":flags", ":jit_compilation_passes", + ":register_xla_cpu_jit", ":xla_device", "//tensorflow/compiler/jit/kernels:xla_ops", "//tensorflow/compiler/tf2xla:xla_compiler", @@ -96,6 +121,7 @@ cc_library( visibility = [":friends"], deps = [ ":jit_compilation_passes", + ":register_xla_gpu_jit", ":xla_device", "//tensorflow/compiler/jit/kernels:xla_ops", "//tensorflow/compiler/tf2xla:xla_compiler", diff --git a/tensorflow/compiler/jit/register_xla_cpu_jit.cc b/tensorflow/compiler/jit/register_xla_cpu_jit.cc new file mode 100644 index 00000000000..9cbef631271 --- /dev/null +++ b/tensorflow/compiler/jit/register_xla_cpu_jit.cc @@ -0,0 +1,22 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/common_runtime/tf_xla_stub.h" + +namespace tensorflow { +namespace { +XlaCpuJitIsLinkedIn register_xla_cpu_jit; +} +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/register_xla_gpu_jit.cc b/tensorflow/compiler/jit/register_xla_gpu_jit.cc new file mode 100644 index 00000000000..7399a41d25f --- /dev/null +++ b/tensorflow/compiler/jit/register_xla_gpu_jit.cc @@ -0,0 +1,22 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/common_runtime/tf_xla_stub.h" + +namespace tensorflow { +namespace { +XlaGpuJitIsLinkedIn register_xla_gpu_jit; +} +} // namespace tensorflow diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 73e8db58a83..1162c47e6bf 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -2976,6 +2976,19 @@ tf_cuda_library( ] + if_static([":core_cpu_impl"]) + tf_protos_all() + tf_protos_grappler(), ) +tf_cuda_library( + name = "tf_xla_stub", + srcs = ["common_runtime/tf_xla_stub.cc"], + hdrs = ["common_runtime/tf_xla_stub.h"], + copts = tf_copts(), + visibility = ["//visibility:public"], + deps = [ + ":lib", + ":proto_text", + ":session_options", + ], +) + tf_cuda_library( name = "core_cpu_internal", srcs = [ @@ -2989,6 +3002,7 @@ tf_cuda_library( ":framework", ":graph", ":lib", + ":tf_xla_stub", ":proto_text", ":protos_all_cc", "//tensorflow/core/grappler:grappler_item", diff --git a/tensorflow/core/common_runtime/graph_execution_state.cc b/tensorflow/core/common_runtime/graph_execution_state.cc index 0d36930324a..c18268ad7bc 100644 --- a/tensorflow/core/common_runtime/graph_execution_state.cc +++ b/tensorflow/core/common_runtime/graph_execution_state.cc @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/optimization_registry.h" #include "tensorflow/core/common_runtime/placer.h" +#include "tensorflow/core/common_runtime/tf_xla_stub.h" #include "tensorflow/core/framework/graph.pb_text.h" #include "tensorflow/core/framework/graph_def_util.h" #include "tensorflow/core/framework/node_def.pb.h" @@ -720,6 +721,8 @@ Status GraphExecutionState::BuildGraph(const BuildGraphOptions& options, CHECK_EQ(options.callable_options.fetch_size(), rewrite_metadata.fetch_types.size()); + TF_RETURN_IF_ERROR(CheckXlaJitOptimizerOptions(session_options_)); + // TODO(andydavis): Clarify optimization pass requirements around CostModel. GraphOptimizationPassOptions optimization_options; optimization_options.session_options = session_options_; diff --git a/tensorflow/core/common_runtime/tf_xla_stub.cc b/tensorflow/core/common_runtime/tf_xla_stub.cc new file mode 100644 index 00000000000..d463693669f --- /dev/null +++ b/tensorflow/core/common_runtime/tf_xla_stub.cc @@ -0,0 +1,82 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/core/common_runtime/tf_xla_stub.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace tensorflow { +namespace { + +bool is_xla_gpu_jit_registered = false; +bool is_xla_cpu_jit_registered = false; + +struct XlaEnvVars { + bool xla_flags_env_var_present; + bool tf_xla_flags_env_var_present; +}; + +XlaEnvVars ComputeEnvVarHasXlaFlags() { + XlaEnvVars env_vars; + env_vars.xla_flags_env_var_present = getenv("XLA_FLAGS") != nullptr; + env_vars.tf_xla_flags_env_var_present = getenv("TF_XLA_FLAGS") != nullptr; + return env_vars; +} + +} // namespace + +XlaGpuJitIsLinkedIn::XlaGpuJitIsLinkedIn() { is_xla_gpu_jit_registered = true; } +XlaCpuJitIsLinkedIn::XlaCpuJitIsLinkedIn() { is_xla_cpu_jit_registered = true; } + +Status CheckXlaJitOptimizerOptions(const SessionOptions* session_options) { + static XlaEnvVars env_vars = ComputeEnvVarHasXlaFlags(); + + if (is_xla_cpu_jit_registered || is_xla_gpu_jit_registered) { + return Status::OK(); + } + + if (env_vars.xla_flags_env_var_present) { + return errors::InvalidArgument( + "The XLA JIT is not linked in but the \"XLA_FLAGS\" environment " + "variable is set. Please either link in XLA or remove \"XLA_FLAGS\" " + "from the environment."); + } + + if (env_vars.tf_xla_flags_env_var_present) { + return errors::InvalidArgument( + "The XLA JIT is not linked in but the \"TF_XLA_FLAGS\" environment " + "variable is set. Please either link in XLA or remove " + "\"TF_XLA_FLAGS\" from the environment."); + } + + if (session_options) { + OptimizerOptions::GlobalJitLevel jit_level = + session_options->config.graph_options() + .optimizer_options() + .global_jit_level(); + + if (jit_level == OptimizerOptions::ON_1 || + jit_level == OptimizerOptions::ON_2) { + return errors::InvalidArgument( + "The XLA JIT is enabled in the session options but XLA is not linked " + "in. Plesae either link in XLA or disable the JIT in the session " + "options."); + } + } + + return Status::OK(); +} +} // namespace tensorflow diff --git a/tensorflow/core/common_runtime/tf_xla_stub.h b/tensorflow/core/common_runtime/tf_xla_stub.h new file mode 100644 index 00000000000..723b2b5cd2e --- /dev/null +++ b/tensorflow/core/common_runtime/tf_xla_stub.h @@ -0,0 +1,50 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_TF_XLA_STUB_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_TF_XLA_STUB_H_ + +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/public/session_options.h" + +namespace tensorflow { +// Returns an error if the XLA JIT is enabled via `session_options` or if the +// TF_XLA_FLAGS or XLA_FLAGS environment variables are set, but neither the +// XLA CPU JIT nor the XLA GPU JIT are linked in. +// +// If `session_options` is null then only the environment variables are checked. +Status CheckXlaJitOptimizerOptions(const SessionOptions* session_options); + +// The XLA CPU JIT creates a static instance of this class to notify +// `CheckXlaJitOptimizerOptions` that the XLA CPU JIT is linked in. +// +// NB! The constructor of this class (if run at all) needs to be ordered (via +// happens before) before any call to `CheckXlaJitOptimizerOptions`. +class XlaCpuJitIsLinkedIn { + public: + XlaCpuJitIsLinkedIn(); +}; + +// The XLA GPU JIT creates a static instance of this class to notify +// `CheckXlaJitOptimizerOptions` that the XLA GPU JIT is linked in. +// +// NB! The constructor of this class (if run at all) needs to be ordered (via +// happens before) before any call to `CheckXlaJitOptimizerOptions`. +class XlaGpuJitIsLinkedIn { + public: + XlaGpuJitIsLinkedIn(); +}; +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_TF_XLA_STUB_H_