[TF2XLA] Add a kill switch to disable XLA compilation
PiperOrigin-RevId: 328788854 Change-Id: Ifab260cb5d9ceda01f155b77354136787c69a084
This commit is contained in:
parent
5b4cc9b10a
commit
99e8952ee2
@ -329,6 +329,7 @@ cc_library(
|
||||
srcs = ["xla_compilation_cache.cc"],
|
||||
hdrs = ["xla_compilation_cache.h"],
|
||||
deps = [
|
||||
":flags",
|
||||
":xla_activity_listener",
|
||||
":xla_activity_proto_cc",
|
||||
"//tensorflow/compiler/mlir/tensorflow:compile_mlir_util_no_tf_dialect_passes",
|
||||
@ -361,8 +362,11 @@ tf_cc_test(
|
||||
"xla_compilation_cache_test.cc",
|
||||
],
|
||||
deps = [
|
||||
":flags",
|
||||
":xla_compilation_cache",
|
||||
":xla_cpu_jit",
|
||||
"//tensorflow/compiler/tf2xla:common",
|
||||
"//tensorflow/compiler/xla/client:client_library",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
],
|
||||
|
||||
@ -268,4 +268,10 @@ void AppendMarkForCompilationPassFlags(std::vector<Flag>* flag_list) {
|
||||
AppendMarkForCompilationPassFlagsInternal(flag_list);
|
||||
}
|
||||
|
||||
static std::atomic<bool> xla_compilation_disabled(false);
|
||||
|
||||
void DisableXlaCompilation() { xla_compilation_disabled = true; }
|
||||
|
||||
bool FailOnXlaCompilation() { return xla_compilation_disabled; }
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
@ -162,6 +162,13 @@ MlirCommonFlags* GetMlirCommonFlags();
|
||||
void AppendMarkForCompilationPassFlags(
|
||||
std::vector<tensorflow::Flag>* flag_list);
|
||||
|
||||
// Disables XLA compilation, forces it to return an error message instead. Can
|
||||
// be used by a server to ensure that JIT compilation is opt-in.
|
||||
void DisableXlaCompilation();
|
||||
|
||||
// Returns `false` unless `DisableXlaCompilation` was called.
|
||||
bool FailOnXlaCompilation();
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_JIT_FLAGS_H_
|
||||
|
||||
@ -20,6 +20,7 @@ limitations under the License.
|
||||
#include "absl/base/call_once.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/strings/str_join.h"
|
||||
#include "tensorflow/compiler/jit/flags.h"
|
||||
#include "tensorflow/compiler/jit/xla_activity.pb.h"
|
||||
#include "tensorflow/compiler/jit/xla_activity_listener.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h"
|
||||
@ -323,6 +324,10 @@ Status XlaCompilationCache::CompileImpl(
|
||||
absl::optional<int64> compile_threshold,
|
||||
const XlaCompiler::CompilationResult** out_compilation_result,
|
||||
xla::LocalExecutable** out_executable) {
|
||||
if (FailOnXlaCompilation()) {
|
||||
return errors::Internal("XLA compilation disabled");
|
||||
}
|
||||
|
||||
DCHECK_NE(out_executable, nullptr);
|
||||
VLOG(2) << "XlaCompilationCache::Compile " << DebugString();
|
||||
|
||||
|
||||
@ -15,7 +15,9 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/compiler/jit/xla_compilation_cache.h"
|
||||
|
||||
#include "tensorflow/compiler/jit/flags.h"
|
||||
#include "tensorflow/compiler/tf2xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/client/client_library.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/platform/test_benchmark.h"
|
||||
|
||||
@ -52,6 +54,30 @@ TEST(XlaCompilationCacheTest, SignatureEquality) {
|
||||
}
|
||||
}
|
||||
|
||||
TEST(XlaCompilationCacheTest, TestDisabledXlaCompilation) {
|
||||
NameAttrList fn;
|
||||
fn.set_name("afunction");
|
||||
|
||||
DisableXlaCompilation();
|
||||
|
||||
xla::LocalClient* client = xla::ClientLibrary::LocalClientOrDie();
|
||||
DeviceType device_type = DeviceType(DEVICE_CPU_XLA_JIT);
|
||||
|
||||
const XlaCompiler::CompilationResult* compilation_result;
|
||||
xla::LocalExecutable* executable;
|
||||
|
||||
auto cache = new XlaCompilationCache(client, device_type);
|
||||
core::ScopedUnref cache_ref(cache);
|
||||
|
||||
Status status = cache->Compile(XlaCompiler::Options{}, fn, {},
|
||||
XlaCompiler::CompileOptions{},
|
||||
XlaCompilationCache::CompileMode::kStrict,
|
||||
&compilation_result, &executable);
|
||||
EXPECT_FALSE(status.ok());
|
||||
EXPECT_TRUE(
|
||||
absl::StrContains(status.error_message(), "XLA compilation disabled"));
|
||||
}
|
||||
|
||||
static void BM_BuildSignature(int iters, int n_args) {
|
||||
NameAttrList fn;
|
||||
fn.set_name("afunction");
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user