[TF2XLA] Add a kill switch to disable XLA compilation

PiperOrigin-RevId: 328788854
Change-Id: Ifab260cb5d9ceda01f155b77354136787c69a084
This commit is contained in:
George Karpenkov 2020-08-27 12:16:09 -07:00 committed by TensorFlower Gardener
parent 5b4cc9b10a
commit 99e8952ee2
5 changed files with 48 additions and 0 deletions

View File

@ -329,6 +329,7 @@ cc_library(
srcs = ["xla_compilation_cache.cc"],
hdrs = ["xla_compilation_cache.h"],
deps = [
":flags",
":xla_activity_listener",
":xla_activity_proto_cc",
"//tensorflow/compiler/mlir/tensorflow:compile_mlir_util_no_tf_dialect_passes",
@ -361,8 +362,11 @@ tf_cc_test(
"xla_compilation_cache_test.cc",
],
deps = [
":flags",
":xla_compilation_cache",
":xla_cpu_jit",
"//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],

View File

@ -268,4 +268,10 @@ void AppendMarkForCompilationPassFlags(std::vector<Flag>* flag_list) {
AppendMarkForCompilationPassFlagsInternal(flag_list);
}
static std::atomic<bool> xla_compilation_disabled(false);
void DisableXlaCompilation() { xla_compilation_disabled = true; }
bool FailOnXlaCompilation() { return xla_compilation_disabled; }
} // namespace tensorflow

View File

@ -162,6 +162,13 @@ MlirCommonFlags* GetMlirCommonFlags();
void AppendMarkForCompilationPassFlags(
std::vector<tensorflow::Flag>* flag_list);
// Disables XLA compilation, forces it to return an error message instead. Can
// be used by a server to ensure that JIT compilation is opt-in.
void DisableXlaCompilation();
// Returns `false` unless `DisableXlaCompilation` was called.
bool FailOnXlaCompilation();
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_JIT_FLAGS_H_

View File

@ -20,6 +20,7 @@ limitations under the License.
#include "absl/base/call_once.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/xla_activity.pb.h"
#include "tensorflow/compiler/jit/xla_activity_listener.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h"
@ -323,6 +324,10 @@ Status XlaCompilationCache::CompileImpl(
absl::optional<int64> compile_threshold,
const XlaCompiler::CompilationResult** out_compilation_result,
xla::LocalExecutable** out_executable) {
if (FailOnXlaCompilation()) {
return errors::Internal("XLA compilation disabled");
}
DCHECK_NE(out_executable, nullptr);
VLOG(2) << "XlaCompilationCache::Compile " << DebugString();

View File

@ -15,7 +15,9 @@ limitations under the License.
#include "tensorflow/compiler/jit/xla_compilation_cache.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
@ -52,6 +54,30 @@ TEST(XlaCompilationCacheTest, SignatureEquality) {
}
}
TEST(XlaCompilationCacheTest, TestDisabledXlaCompilation) {
NameAttrList fn;
fn.set_name("afunction");
DisableXlaCompilation();
xla::LocalClient* client = xla::ClientLibrary::LocalClientOrDie();
DeviceType device_type = DeviceType(DEVICE_CPU_XLA_JIT);
const XlaCompiler::CompilationResult* compilation_result;
xla::LocalExecutable* executable;
auto cache = new XlaCompilationCache(client, device_type);
core::ScopedUnref cache_ref(cache);
Status status = cache->Compile(XlaCompiler::Options{}, fn, {},
XlaCompiler::CompileOptions{},
XlaCompilationCache::CompileMode::kStrict,
&compilation_result, &executable);
EXPECT_FALSE(status.ok());
EXPECT_TRUE(
absl::StrContains(status.error_message(), "XLA compilation disabled"));
}
static void BM_BuildSignature(int iters, int n_args) {
NameAttrList fn;
fn.set_name("afunction");