Rollback of rollback of disabling XLA:CPU/GPU devices by default.

PiperOrigin-RevId: 326084520
Change-Id: Id537bc29ac9d4c4c3c8fe7533e8d4151adb4cadf
George Karpenkov, 2020-08-11 13:02:39 -07:00, committed by TensorFlower Gardener
parent d9d7f37118
commit 68016e2697
26 changed files with 206 additions and 107 deletions

View File

@ -33,6 +33,9 @@
shape assumptions (note that you can pass shapes with `None` entries for axes
that are meant to be dynamic). You can also disable the input checking
entirely by setting `model.input_spec = None`.
* XLA:CPU and XLA:GPU devices are no longer registered by default. Use
`TF_XLA_FLAGS=--tf_xla_enable_xla_devices` if you really need them; note that
these devices, and this opt-in flag, are deprecated and will eventually be
removed.
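To make the opt-in concrete, here is a minimal sketch of how a script might request the deprecated devices. It is not part of this commit; it only relies on the `TF_XLA_FLAGS` environment variable mentioned above and the standard `tf.config` API, and the variable has to be set before TensorFlow initializes its devices.

```python
# Sketch only: re-enabling the deprecated XLA:CPU/XLA:GPU devices via the
# environment flag. Set it before the first TensorFlow call that touches
# devices (simplest: before importing TensorFlow).
import os
os.environ["TF_XLA_FLAGS"] = "--tf_xla_enable_xla_devices"

import tensorflow as tf

# With the flag set, XLA devices are registered again and show up here;
# without it, only the regular CPU/GPU devices are listed.
print([d.name for d in tf.config.list_logical_devices()])
```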
## Known Caveats

View File

@ -206,6 +206,7 @@ cc_library(
"xla_device.cc",
"xla_device_context.cc",
"xla_device_ops.cc",
"xla_ops_on_regular_devices.cc",
"xla_platform_info.cc",
],
hdrs = [

View File

@ -159,7 +159,7 @@ void AllocateAndParseFlags() {
device_flags = new XlaDeviceFlags;
device_flags->tf_xla_compile_on_demand = false;
device_flags->tf_xla_enable_xla_devices = true;
device_flags->tf_xla_enable_xla_devices = false;
ops_flags = new XlaOpsCommonFlags;
ops_flags->tf_xla_always_defer_compilation = false;
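The hunk above flips the default of `tf_xla_enable_xla_devices` to `false`, so code that used to pin work onto an XLA device should now treat those devices as optional. A hedged sketch of that guard follows; `tf.config.list_logical_devices` and `tf.device` are standard TF 2.x APIs, everything else is illustrative.

```python
# Sketch: with the new default, XLA devices are simply absent, so guard any
# explicit placement on them instead of assuming they exist.
import tensorflow as tf

if tf.config.list_logical_devices("XLA_GPU"):
    with tf.device("/device:XLA_GPU:0"):
        y = tf.constant(1.0) + 1.0  # runs on the XLA device when registered
else:
    y = tf.constant(1.0) + 1.0      # falls back to the regular device
print(y.numpy())
```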

View File

@ -191,7 +191,7 @@ static Status CompileToLocalExecutable(
absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter;
XlaCompiler::Options options = GenerateCompilerOptions(
cache, ctx, platform_info, has_ref_vars, &tf_allocator_adapter);
*cache, ctx, platform_info, has_ref_vars, &tf_allocator_adapter);
std::map<int, Tensor> constant_args;
for (int i : constants) {

View File

@ -44,6 +44,11 @@ using ::tensorflow::testing::FindNodeByName;
namespace tensorflow {
namespace {
static bool Initialized = [] {
tensorflow::GetXlaDeviceFlags()->tf_xla_enable_xla_devices = true;
return true;
}();
REGISTER_OP("UncompilableNullary").Output("o: float");
REGISTER_OP("UncompilableUnary").Input("a: float").Output("o: float");

View File

@ -406,37 +406,6 @@ TEST(PartiallyDeclusterPassTest, DontDeclusterXlaDeviceOps) {
EXPECT_EQ(GetXlaClusterForNode(*n), "cluster_0");
}
TEST(PartiallyDeclusterPassTest, DontDeclusterNonTensorFlowOps) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output dynamic_slice_operand =
ops::Placeholder(s.WithOpName("dynamic_slice_operand"), DT_INT32,
ops::Placeholder::Attrs{});
Output dynamic_slice_begin = ops::Placeholder(
s.WithOpName("dynamic_slice_begin"), DT_INT32, ops::Placeholder::Attrs{});
Output dynamic_slice_size = ops::Placeholder(
s.WithOpName("dynamic_slice_size"), DT_INT32, ops::Placeholder::Attrs{});
Output dynamic_slice =
ops::XlaDynamicSlice(s.WithOpName("dynamic_slice"), dynamic_slice_operand,
dynamic_slice_begin, dynamic_slice_size);
Output reshape_input = ops::Placeholder(s.WithOpName("reshape_input"),
DT_FLOAT, ops::Placeholder::Attrs{});
Output reshape =
ops::Reshape(s.WithOpName("reshape"), reshape_input, dynamic_slice);
AddToCluster({dynamic_slice.node(), reshape.node()}, "cluster_0");
std::unique_ptr<Graph> graph = absl::make_unique<Graph>(OpRegistry::Global());
TF_ASSERT_OK(s.ToGraph(graph.get()));
Node* n = FindNodeByName(*graph, "dynamic_slice");
ASSERT_NE(n, nullptr);
TF_ASSERT_OK(PartiallyDecluster(&graph));
EXPECT_EQ(GetXlaClusterForNode(*n), "cluster_0");
}
TEST(PartiallyDeclusterPassTest, EliminatedUnusedNodes) {
const char* const kClusteredProducer0Name = "ClusteredProducer0";
const char* const kClusteredProducer1Name = "ClusteredProducer1";

View File

@ -48,9 +48,11 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx,
const ResourceVarsSnapshot& variable_args) {
xla::LocalClient* client = static_cast<xla::LocalClient*>(cache->client());
absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter;
se::DeviceMemoryAllocator* allocator =
GetAllocator(&tf_allocator_adapter, ctx, platform_info_);
XlaComputationLaunchContext launch_context(
client, client->backend().memory_allocator(),
client->default_device_ordinal(),
client, allocator, client->default_device_ordinal(),
/*allocate_xla_tensors=*/platform_info_.xla_device_metadata() != nullptr,
platform_info_.xla_device_metadata()
? platform_info_.xla_device_metadata()->UseMultipleStreams()
@ -76,7 +78,7 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx,
VLOG(2) << "Executing computation: " << name();
xla::ExecutableRunOptions run_options;
run_options.set_stream(stream);
run_options.set_allocator(client->backend().memory_allocator());
run_options.set_allocator(allocator);
run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device());
run_options.set_rng_seed(GetXLARandomSeed());
@ -108,6 +110,7 @@ Status XlaCompileOnDemandOp::Compile(
for (int64 i = 0; i < ctx->num_inputs(); ++i) {
const Tensor& device_tensor = ctx->input(i);
if (const XlaTensor* xla_tensor = XlaTensor::FromTensor(&device_tensor)) {
if (xla_tensor->has_host_tensor()) {
if (absl::c_binary_search(constant_input_indices, i)) {
@ -118,24 +121,30 @@ Status XlaCompileOnDemandOp::Compile(
if (!constant_arguments.count(i)) {
if (absl::c_binary_search(constant_input_indices, i)) {
// Slow path; the argument is not available as a host constant so we
// must fetch it synchronously.
Tensor host_tensor;
AllocatorAttributes attrs;
attrs.set_on_host(true);
TF_RETURN_IF_ERROR(ctx->allocate_temp(
device_tensor.dtype(), device_tensor.shape(), &host_tensor, attrs));
Status status = ctx->op_device_context()->CopyDeviceTensorToCPUSync(
&device_tensor, "ConstantArgument",
reinterpret_cast<Device*>(ctx->device()), &host_tensor);
if (!status.ok()) {
LOG(ERROR) << "Copying tensor of shape "
<< device_tensor.shape().DebugString() << " from "
<< ctx->device()->name() << "to CPU failed with "
<< status.ToString();
return status;
if (ctx->input_memory_type(i) != HOST_MEMORY &&
ctx->op_device_context()) {
// Slow path; the argument is not available as a host constant so we
// must fetch it synchronously.
Tensor host_tensor;
AllocatorAttributes attrs;
attrs.set_on_host(true);
TF_RETURN_IF_ERROR(ctx->allocate_temp(device_tensor.dtype(),
device_tensor.shape(),
&host_tensor, attrs));
Status status = ctx->op_device_context()->CopyDeviceTensorToCPUSync(
&device_tensor, "ConstantArgument",
reinterpret_cast<Device*>(ctx->device()), &host_tensor);
if (!status.ok()) {
LOG(ERROR) << "Copying tensor of shape "
<< device_tensor.shape().DebugString() << " from "
<< ctx->device()->name() << "to CPU failed with "
<< status.ToString();
return status;
}
constant_arguments[i] = host_tensor;
} else {
constant_arguments[i] = device_tensor;
}
constant_arguments[i] = host_tensor;
}
}
}
@ -153,7 +162,7 @@ Status XlaCompileOnDemandOp::Compile(
absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter;
XlaCompiler::Options options =
GenerateCompilerOptions(*cache, ctx, platform_info_,
GenerateCompilerOptions(**cache, ctx, platform_info_,
/*has_ref_vars=*/true, &tf_allocator_adapter);
XlaCompiler::CompileOptions compile_options;
@ -184,6 +193,8 @@ void XlaCompileOnDemandOp::Compute(OpKernelContext* ctx) {
xla::LocalExecutable* executable;
ResourceVarsSnapshot variable_args;
XlaCompilationCache* cache;
OP_REQUIRES(ctx, ctx->function_library(),
errors::Internal("Function library missing"));
OP_REQUIRES_OK(ctx,
Compile(ctx, &result, &cache, &variable_args, &executable));
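The constant-argument handling in this file matters for ops whose shapes depend on input values, e.g. `XlaDynamicSlice`, whose `size_indices` input must be known at compile time; when such an op runs on-demand on a regular device, that input may have to be copied back to the host (the "slow path" in the hunk above). A rough Python-level illustration, assuming the auto-generated `tf.raw_ops.XlaDynamicSlice` wrapper with these argument names and eager execution:

```python
# Sketch: XlaDynamicSlice compiled on demand on a regular device. Its
# size_indices input is a compile-time constant for XLA, so the kernel may
# fetch it synchronously from device memory before compiling.
import tensorflow as tf

x = tf.reshape(tf.range(12, dtype=tf.float32), [3, 4])
y = tf.raw_ops.XlaDynamicSlice(
    input=x,
    start_indices=tf.constant([1, 1]),
    size_indices=tf.constant([2, 2]))
print(y.numpy())  # the 2x2 block starting at (1, 1)
```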

View File

@ -61,6 +61,21 @@ limitations under the License.
namespace tensorflow {
// Default PaddedShapeFn implementation that simply returns the unpadded
// on-device shape. This is accurate for CPU and GPU devices that neither
// transpose nor pad tensors.
Status DefaultPaddedShapeFn(const Tensor& tensor, xla::Shape* shape) {
const tensorflow::XlaTensor* xla_tensor =
tensorflow::XlaTensor::FromTensor(&tensor);
if (xla_tensor == nullptr) {
return TensorShapeToXLAShape(tensor.dtype(), tensor.shape(), shape);
}
const xla::ShapedBuffer& shaped_buffer = xla_tensor->shaped_buffer();
*shape = shaped_buffer.on_device_shape();
return Status::OK();
}
// Caches a XlaDeviceAllocator per <backend, device ordinal> pair. A
// XlaDeviceAllocator is created on demand and is associated with a
// XlaDevice. It outlives the device itself (for instance, the buffer
@ -116,20 +131,6 @@ XlaDeviceAllocator* XlaDeviceAllocatorState::GetOrCreateXlaDeviceAllocator(
namespace {
// Default PaddedShapeFn implementation that simply returns the unpadded
// on-device shape. This is accurate for CPU and GPU devices that neither
// transpose nor pad tensors.
Status DefaultPaddedShapeFn(const Tensor& tensor, xla::Shape* shape) {
const tensorflow::XlaTensor* xla_tensor =
tensorflow::XlaTensor::FromTensor(&tensor);
if (xla_tensor == nullptr) {
return TensorShapeToXLAShape(tensor.dtype(), tensor.shape(), shape);
}
const xla::ShapedBuffer& shaped_buffer = xla_tensor->shaped_buffer();
*shape = shaped_buffer.on_device_shape();
return Status::OK();
}
static DeviceAttributes BuildXlaDeviceAttributes(const string& name_prefix,
const string& device_name,

View File

@ -280,6 +280,8 @@ struct XlaDeviceOpRegistrations {
XlaDeviceOpRegistrations* RegisterXlaDeviceKernels(const char* device,
const char* jit_device);
Status DefaultPaddedShapeFn(const Tensor& tensor, xla::Shape* shape);
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_JIT_XLA_DEVICE_H_

View File

@ -0,0 +1,89 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Register XlaXXX operations on regular CPU/GPU devices using
// `XlaCompileOnDemandOp`.
#include "tensorflow/compiler/jit/xla_compile_on_demand_op.h"
#include "tensorflow/core/framework/op_kernel.h"
namespace tensorflow {
#define REGISTER_XLA_OPS_ON_DEVICE(DEVICE) \
REGISTER_KERNEL_BUILDER(Name("XlaConv") \
.HostMemory("window_strides") \
.HostMemory("padding") \
.HostMemory("lhs_dilation") \
.HostMemory("rhs_dilation") \
.HostMemory("feature_group_count") \
.Device(DEVICE), \
XlaCompileOnDemandOp); \
REGISTER_KERNEL_BUILDER( \
Name("XlaBroadcastHelper").HostMemory("broadcast_dims").Device(DEVICE), \
XlaCompileOnDemandOp); \
REGISTER_KERNEL_BUILDER(Name("XlaSelfAdjointEig").Device(DEVICE), \
XlaCompileOnDemandOp); \
REGISTER_KERNEL_BUILDER(Name("XlaSvd").Device(DEVICE), \
XlaCompileOnDemandOp); \
REGISTER_KERNEL_BUILDER(Name("XlaDot").Device(DEVICE), \
XlaCompileOnDemandOp); \
REGISTER_KERNEL_BUILDER(Name("XlaDynamicSlice").Device(DEVICE), \
XlaCompileOnDemandOp); \
REGISTER_KERNEL_BUILDER(Name("XlaDynamicUpdateSlice").Device(DEVICE), \
XlaCompileOnDemandOp); \
REGISTER_KERNEL_BUILDER(Name("XlaIf").Device(DEVICE), XlaCompileOnDemandOp); \
REGISTER_KERNEL_BUILDER(Name("XlaPad").Device(DEVICE), \
XlaCompileOnDemandOp); \
REGISTER_KERNEL_BUILDER(Name("XlaRecv").Device(DEVICE), \
XlaCompileOnDemandOp); \
REGISTER_KERNEL_BUILDER(Name("XlaReduce").Device(DEVICE), \
XlaCompileOnDemandOp); \
REGISTER_KERNEL_BUILDER(Name("XlaReduceWindow").Device(DEVICE), \
XlaCompileOnDemandOp); \
REGISTER_KERNEL_BUILDER(Name("XlaSelectAndScatter") \
.HostMemory("window_dimensions") \
.HostMemory("window_strides") \
.HostMemory("padding") \
.Device(DEVICE), \
XlaCompileOnDemandOp); \
REGISTER_KERNEL_BUILDER(Name("XlaSend").Device(DEVICE), \
XlaCompileOnDemandOp); \
REGISTER_KERNEL_BUILDER(Name("XlaSort").Device(DEVICE), \
XlaCompileOnDemandOp); \
REGISTER_KERNEL_BUILDER(Name("XlaKeyValueSort").Device(DEVICE), \
XlaCompileOnDemandOp); \
REGISTER_KERNEL_BUILDER(Name("XlaWhile").Device(DEVICE), \
XlaCompileOnDemandOp); \
REGISTER_KERNEL_BUILDER(Name("XlaDequantize").Device(DEVICE), \
XlaCompileOnDemandOp); \
REGISTER_KERNEL_BUILDER(Name("XlaEinsum").Device(DEVICE), \
XlaCompileOnDemandOp); \
REGISTER_KERNEL_BUILDER(Name("XlaSpmdShardToFullShape").Device(DEVICE), \
XlaCompileOnDemandOp); \
REGISTER_KERNEL_BUILDER(Name("XlaSharding").Device(DEVICE), \
XlaCompileOnDemandOp); \
REGISTER_KERNEL_BUILDER(Name("XlaReplicaId").Device(DEVICE), \
XlaCompileOnDemandOp); \
REGISTER_KERNEL_BUILDER(Name("XlaGather") \
.HostMemory("start_indices") \
.HostMemory("slice_sizes") \
.Device(DEVICE), \
XlaCompileOnDemandOp); \
REGISTER_KERNEL_BUILDER(Name("XlaScatter").Device(DEVICE), \
XlaCompileOnDemandOp);
REGISTER_XLA_OPS_ON_DEVICE(DEVICE_CPU);
REGISTER_XLA_OPS_ON_DEVICE(DEVICE_GPU);
} // namespace tensorflow
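The net effect of this new registration file is that the `XlaXXX` ops listed above can run directly on plain CPU/GPU devices, with `XlaCompileOnDemandOp` compiling them at execution time. A hedged end-to-end sketch, assuming the auto-generated `tf.raw_ops` wrapper for `XlaSort`; nothing here is taken from the commit itself.

```python
# Sketch: running one of the newly registered XlaXXX ops on an ordinary
# device. XlaSort sorts along the innermost dimension and is compiled on
# demand by XlaCompileOnDemandOp.
import tensorflow as tf

with tf.device("/device:CPU:0"):
    sorted_vals = tf.raw_ops.XlaSort(input=tf.constant([3.0, 1.0, 2.0]))
print(sorted_vals.numpy())  # [1. 2. 3.]
```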

View File

@ -128,16 +128,17 @@ se::DeviceMemoryAllocator* GetAllocator(
}
XlaCompiler::Options GenerateCompilerOptions(
XlaCompilationCache* cache, OpKernelContext* ctx,
const XlaCompilationCache& cache, OpKernelContext* ctx,
const XlaPlatformInfo& platform_info, bool has_ref_vars,
absl::optional<se::TfAllocatorAdapter>* tf_allocator_adapter) {
CHECK(ctx->function_library());
XlaCompiler::Options options;
options.client = static_cast<xla::LocalClient*>(cache->client());
options.client = static_cast<xla::LocalClient*>(cache.client());
if (ctx->op_device_context() != nullptr) {
options.device_ordinal =
ctx->op_device_context()->stream()->parent()->device_ordinal();
}
options.device_type = cache->device_type();
options.device_type = cache.device_type();
options.flib_def = ctx->function_library()->GetFunctionLibraryDefinition();
options.graph_def_version = ctx->function_library()->graph_def_version();
options.allow_cpu_custom_calls =

View File

@ -99,7 +99,7 @@ se::DeviceMemoryAllocator* GetAllocator(
// Returns created options for the XLA compiler, and writes the used allocator
// into `tf_allocator_adapter`.
XlaCompiler::Options GenerateCompilerOptions(
XlaCompilationCache* cache, OpKernelContext* ctx,
const XlaCompilationCache& cache, OpKernelContext* ctx,
const XlaPlatformInfo& platform_info, bool has_ref_vars,
absl::optional<se::TfAllocatorAdapter>* tf_allocator_adapter);

View File

@ -1687,6 +1687,7 @@ tf_cuda_cc_test(
deps = [
"//tensorflow/cc:cc_ops",
"//tensorflow/compiler/jit",
"//tensorflow/compiler/jit:flags",
"//tensorflow/compiler/jit:xla_kernel_creator",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/core:core_cpu",

View File

@ -20,6 +20,7 @@ limitations under the License.
#include <vector>
#include "absl/synchronization/notification.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
@ -43,6 +44,11 @@ limitations under the License.
namespace tensorflow {
namespace {
static bool Initialized = [] {
tensorflow::GetXlaDeviceFlags()->tf_xla_enable_xla_devices = true;
return true;
}();
class UnaryOpsCompositionTest : public OpsTestBase {
protected:
template <typename T>

View File

@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python.client import session as session_lib
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
@ -27,6 +28,10 @@ from tensorflow.python.platform import test
class XlaDeviceGpuTest(test.TestCase):
def __init__(self, method_name="runTest"):
super(XlaDeviceGpuTest, self).__init__(method_name)
context.context().enable_xla_devices()
def testCopiesToAndFromGpuWork(self):
"""Tests that copies between GPU and XLA devices work."""
if not test.is_gpu_available():

View File

@ -83,6 +83,8 @@ class XLATestCase(test.TestCase):
def __init__(self, method_name='runTest'):
super(XLATestCase, self).__init__(method_name)
if 'XLA' in FLAGS.test_device:
context.context().enable_xla_devices()
context.context().enable_mlir_bridge = test_util.is_mlir_bridge_enabled()
self.device = FLAGS.test_device

View File

@ -787,6 +787,7 @@ tf_cc_test(
"//tensorflow/cc:function_ops",
"//tensorflow/cc:functional_ops",
"//tensorflow/cc:ops",
"//tensorflow/compiler/jit:flags",
"//tensorflow/compiler/jit:xla_cluster_util",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/core:core_cpu_internal",
@ -1087,6 +1088,7 @@ tf_cuda_cc_test(
"//tensorflow/cc:ops",
"//tensorflow/cc:scope",
"//tensorflow/compiler/jit",
"//tensorflow/compiler/jit:flags",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",

View File

@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/cc/ops/function_ops.h"
#include "tensorflow/cc/ops/functional_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/xla_cluster_util.h"
#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
#include "tensorflow/core/graph/algorithm.h"
@ -217,5 +218,10 @@ TEST(ConstAnalysisTest, RespectExplicitAttr_1) {
EXPECT_EQ(const_args, std::vector<bool>({true}));
}
static bool Initialized = [] {
tensorflow::GetXlaDeviceFlags()->tf_xla_enable_xla_devices = true;
return true;
}();
} // namespace
} // namespace tensorflow

View File

@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/cc/ops/array_ops.h"
#include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/cc/ops/nn_ops.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/tensor.h"
@ -139,5 +140,11 @@ TEST(FusedBatchnormReserveSpaceTest, Test) {
test::ExpectClose(results[0], results[1], /*atol=*/1e-4);
test::ExpectClose(results[2], results[3], /*atol=*/1e-4);
}
static bool Initialized = [] {
tensorflow::GetXlaDeviceFlags()->tf_xla_enable_xla_devices = true;
return true;
}();
} // namespace
} // namespace tensorflow

View File

@ -96,6 +96,7 @@ tf_gen_op_libs(
"xrt_execute_op",
],
deps = [
"//tensorflow/compiler/jit:flags",
"//tensorflow/core:lib",
],
)

View File

@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
@ -20,6 +21,11 @@ limitations under the License.
namespace tensorflow {
static bool Initialized = [] {
tensorflow::GetXlaDeviceFlags()->tf_xla_enable_xla_devices = true;
return true;
}();
REGISTER_OP("XRTAllocate")
.Input("allocation: string")
.Output("handle: int64")

View File

@ -44,26 +44,6 @@ TEST_F(PinToHostOptimizerTest, TryFindHostDeviceCpuXlaGpu) {
"/device:CPU:0");
}
TEST_F(PinToHostOptimizerTest, TryFindHostDeviceXlaCpuXlaGpu) {
gtl::FlatSet<string> devices = {"/device:XLA_CPU:0", "/device:XLA_GPU:0"};
EXPECT_EQ(internal::TryFindHostDevice(devices, false, ""), "");
EXPECT_EQ(internal::TryFindHostDevice(devices, false, "/device:XLA_GPU:0"),
"/device:XLA_CPU:0");
EXPECT_EQ(internal::TryFindHostDevice(devices, false, "/device:XLA_GPU:*"),
"/device:XLA_CPU:0");
}
TEST_F(PinToHostOptimizerTest, TryFindHostDeviceXlaGpu) {
gtl::FlatSet<string> devices = {"/device:XLA_GPU:0"};
EXPECT_EQ(internal::TryFindHostDevice(devices, false, ""), "");
EXPECT_EQ(internal::TryFindHostDevice(devices, false, "/device:XLA_GPU:0"),
"");
EXPECT_EQ(internal::TryFindHostDevice(devices, false, "/device:XLA_GPU:*"),
"");
}
TEST_F(PinToHostOptimizerTest, OptimizeSmallOpsToHost) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output a = ops::Const(s.WithOpName("a"), 1, {1024, 1024});

View File

@ -42,6 +42,7 @@ from tensorflow.python.framework import device as pydev
from tensorflow.python.util import compat
from tensorflow.python.util import is_in_graph_mode
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util.deprecation import deprecated
from tensorflow.python.util.tf_export import tf_export
GRAPH_MODE = 0
@ -1254,12 +1255,7 @@ class Context(object):
p: i for i, p in enumerate(self._physical_devices)
}
# Construct the visible device list from all physical devices but ignore
# XLA devices
self._visible_device_list = [
d for d in self._physical_devices
if not d.device_type.startswith("XLA")
]
self._visible_device_list = list(self._physical_devices)
self._memory_growth_map = {
d: None for d in self._physical_devices if d.device_type == "GPU"
}
@ -1493,6 +1489,12 @@ class Context(object):
self._virtual_device_map[dev] = virtual_devices
@deprecated(
None, "XLA:CPU and XLA:GPU devices are deprecated", warn_once=True)
def enable_xla_devices(self):
"""Enables XLA:CPU and XLA:GPU devices registration."""
pywrap_tfe.TF_EnableXlaDevices()
@property
def enable_mlir_bridge(self):
return pywrap_tfe.TF_IsMlirBridgeEnabled()
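The new `enable_xla_devices()` method above is the Python-level opt-in. A short usage sketch; the import path is the internal `tensorflow.python.eager.context` module shown in this hunk, and the call presumably has to happen before TensorFlow initializes its devices for the registration to take effect.

```python
# Sketch: Python-level opt-in to XLA:CPU/XLA:GPU registration. Deprecated;
# emits a one-time warning and simply flips the same flag as TF_XLA_FLAGS.
from tensorflow.python.eager import context

context.context().enable_xla_devices()
```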

View File

@ -435,9 +435,6 @@ class DeviceTest(test.TestCase):
self.assertEqual(len(config.get_visible_devices('CPU')), 1)
self.assertGreater(len(config.get_visible_devices('GPU')), 0)
# get_visible_devices filters out XLA_* devices. list_logical_devices does
# not, but we can't call it here because it initializes the devices and
# calling set_visible_devices after that is disallowed.
self.assertEqual(len(config.get_visible_devices('XLA_GPU')), 0)
config.set_visible_devices(cpus[0])
@ -451,12 +448,6 @@ class DeviceTest(test.TestCase):
a = array_ops.identity(1.0)
self.evaluate(a)
with self.assertRaisesRegex(errors.InvalidArgumentError,
'Could not satisfy'):
with ops.device('/device:XLA_GPU:0'):
a = array_ops.identity(1.0)
self.evaluate(a)
# Modifying the visible devices is not supported
with self.assertRaisesRegex(RuntimeError, 'cannot be modified'):
config.set_visible_devices(gpus)

View File

@ -22,6 +22,7 @@ from __future__ import print_function
from tensorflow.compiler.tf2xla.python import xla as xla_ops
from tensorflow.python.compiler.xla import jit
from tensorflow.python.compiler.xla import xla
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import test_util
@ -39,6 +40,10 @@ from tensorflow.python.platform import test
@test_util.run_all_in_graph_and_eager_modes
class PForTest(PForTestCase):
def __init__(self, method_name="runTest"):
super(PForTest, self).__init__(method_name)
context.context().enable_xla_devices()
def test_xla_einsum(self):
num_loop = 10
x_series = random_ops.random_uniform([num_loop, 9, 9])

View File

@ -444,6 +444,9 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
m.def("TF_EnableMlirBridge", [](bool enabled) {
tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge = enabled;
});
m.def("TF_EnableXlaDevices", [] {
tensorflow::GetXlaDeviceFlags()->tf_xla_enable_xla_devices = true;
});
// // TFE_Context Logic
m.def(