Rollback of rollback of disabling XLA:CPU/GPU devices by default.
PiperOrigin-RevId: 326084520 Change-Id: Id537bc29ac9d4c4c3c8fe7533e8d4151adb4cadf
This commit is contained in:
parent d9d7f37118
commit 68016e2697

Changed files:
RELEASE.md
tensorflow/compiler/jit: BUILD, flags.cc, kernels, mark_for_compilation_pass_test.cc, partially_decluster_pass_test.cc, xla_compile_on_demand_op.cc, xla_device.cc, xla_device.h, xla_ops_on_regular_devices.cc, xla_platform_info.cc, xla_platform_info.h, tests
tensorflow/compiler/tf2xla
tensorflow/compiler/xrt
tensorflow/core/grappler/optimizers
tensorflow/python
@@ -33,6 +33,9 @@
   shape assumptions (note that you can pass shapes with `None` entries for axes
   that are meant to be dynamic). You can also disable the input checking
   entirely by setting `model.input_spec = None`.
+* XLA:CPU and XLA:GPU devices are no longer registered by default. Use
+  `TF_XLA_FLAGS=--tf_xla_enable_xla_devices` if you really need them (to be
+  removed).

 ## Known Caveats
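The release note above is the user-facing switch. As a rough usage sketch (an assumption of typical usage, not part of this change: TF_XLA_FLAGS is read when the TensorFlow runtime initializes, so the variable must be set before the first import):

# Sketch: opting back into the deprecated XLA:CPU/XLA:GPU devices via the
# TF_XLA_FLAGS variable named in the release note above.
import os
os.environ["TF_XLA_FLAGS"] = "--tf_xla_enable_xla_devices"

import tensorflow as tf  # imported after setting the flag on purpose

# With the flag set, XLA devices such as "/device:XLA_CPU:0" are registered
# again and appear in the logical device list.
print([d.name for d in tf.config.list_logical_devices()])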
@@ -206,6 +206,7 @@ cc_library(
         "xla_device.cc",
         "xla_device_context.cc",
         "xla_device_ops.cc",
+        "xla_ops_on_regular_devices.cc",
         "xla_platform_info.cc",
     ],
     hdrs = [
@@ -159,7 +159,7 @@ void AllocateAndParseFlags() {

   device_flags = new XlaDeviceFlags;
   device_flags->tf_xla_compile_on_demand = false;
-  device_flags->tf_xla_enable_xla_devices = true;
+  device_flags->tf_xla_enable_xla_devices = false;

   ops_flags = new XlaOpsCommonFlags;
   ops_flags->tf_xla_always_defer_compilation = false;
@@ -191,7 +191,7 @@ static Status CompileToLocalExecutable(

   absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter;
   XlaCompiler::Options options = GenerateCompilerOptions(
-      cache, ctx, platform_info, has_ref_vars, &tf_allocator_adapter);
+      *cache, ctx, platform_info, has_ref_vars, &tf_allocator_adapter);

   std::map<int, Tensor> constant_args;
   for (int i : constants) {
@@ -44,6 +44,11 @@ using ::tensorflow::testing::FindNodeByName;
 namespace tensorflow {
 namespace {

+static bool Initialized = [] {
+  tensorflow::GetXlaDeviceFlags()->tf_xla_enable_xla_devices = true;
+  return true;
+}();
+
 REGISTER_OP("UncompilableNullary").Output("o: float");
 REGISTER_OP("UncompilableUnary").Input("a: float").Output("o: float");
@@ -406,37 +406,6 @@ TEST(PartiallyDeclusterPassTest, DontDeclusterXlaDeviceOps) {
   EXPECT_EQ(GetXlaClusterForNode(*n), "cluster_0");
 }

-TEST(PartiallyDeclusterPassTest, DontDeclusterNonTensorFlowOps) {
-  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
-  Output dynamic_slice_operand =
-      ops::Placeholder(s.WithOpName("dynamic_slice_operand"), DT_INT32,
-                       ops::Placeholder::Attrs{});
-  Output dynamic_slice_begin = ops::Placeholder(
-      s.WithOpName("dynamic_slice_begin"), DT_INT32, ops::Placeholder::Attrs{});
-  Output dynamic_slice_size = ops::Placeholder(
-      s.WithOpName("dynamic_slice_size"), DT_INT32, ops::Placeholder::Attrs{});
-  Output dynamic_slice =
-      ops::XlaDynamicSlice(s.WithOpName("dynamic_slice"), dynamic_slice_operand,
-                           dynamic_slice_begin, dynamic_slice_size);
-
-  Output reshape_input = ops::Placeholder(s.WithOpName("reshape_input"),
-                                          DT_FLOAT, ops::Placeholder::Attrs{});
-  Output reshape =
-      ops::Reshape(s.WithOpName("reshape"), reshape_input, dynamic_slice);
-
-  AddToCluster({dynamic_slice.node(), reshape.node()}, "cluster_0");
-
-  std::unique_ptr<Graph> graph = absl::make_unique<Graph>(OpRegistry::Global());
-  TF_ASSERT_OK(s.ToGraph(graph.get()));
-
-  Node* n = FindNodeByName(*graph, "dynamic_slice");
-  ASSERT_NE(n, nullptr);
-
-  TF_ASSERT_OK(PartiallyDecluster(&graph));
-
-  EXPECT_EQ(GetXlaClusterForNode(*n), "cluster_0");
-}
-
 TEST(PartiallyDeclusterPassTest, EliminatedUnusedNodes) {
   const char* const kClusteredProducer0Name = "ClusteredProducer0";
   const char* const kClusteredProducer1Name = "ClusteredProducer1";
@@ -48,9 +48,11 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx,
     const ResourceVarsSnapshot& variable_args) {
   xla::LocalClient* client = static_cast<xla::LocalClient*>(cache->client());

+  absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter;
+  se::DeviceMemoryAllocator* allocator =
+      GetAllocator(&tf_allocator_adapter, ctx, platform_info_);
   XlaComputationLaunchContext launch_context(
-      client, client->backend().memory_allocator(),
-      client->default_device_ordinal(),
+      client, allocator, client->default_device_ordinal(),
       /*allocate_xla_tensors=*/platform_info_.xla_device_metadata() != nullptr,
       platform_info_.xla_device_metadata()
           ? platform_info_.xla_device_metadata()->UseMultipleStreams()
@@ -76,7 +78,7 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx,
   VLOG(2) << "Executing computation: " << name();
   xla::ExecutableRunOptions run_options;
   run_options.set_stream(stream);
-  run_options.set_allocator(client->backend().memory_allocator());
+  run_options.set_allocator(allocator);
   run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device());
   run_options.set_rng_seed(GetXLARandomSeed());

@@ -108,6 +110,7 @@ Status XlaCompileOnDemandOp::Compile(

   for (int64 i = 0; i < ctx->num_inputs(); ++i) {
     const Tensor& device_tensor = ctx->input(i);
+
     if (const XlaTensor* xla_tensor = XlaTensor::FromTensor(&device_tensor)) {
       if (xla_tensor->has_host_tensor()) {
         if (absl::c_binary_search(constant_input_indices, i)) {
@@ -118,24 +121,30 @@ Status XlaCompileOnDemandOp::Compile(

     if (!constant_arguments.count(i)) {
       if (absl::c_binary_search(constant_input_indices, i)) {
-        // Slow path; the argument is not available as a host constant so we
-        // must fetch it synchronously.
-        Tensor host_tensor;
-        AllocatorAttributes attrs;
-        attrs.set_on_host(true);
-        TF_RETURN_IF_ERROR(ctx->allocate_temp(
-            device_tensor.dtype(), device_tensor.shape(), &host_tensor, attrs));
-        Status status = ctx->op_device_context()->CopyDeviceTensorToCPUSync(
-            &device_tensor, "ConstantArgument",
-            reinterpret_cast<Device*>(ctx->device()), &host_tensor);
-        if (!status.ok()) {
-          LOG(ERROR) << "Copying tensor of shape "
-                     << device_tensor.shape().DebugString() << " from "
-                     << ctx->device()->name() << "to CPU failed with "
-                     << status.ToString();
-          return status;
-        }
-        constant_arguments[i] = host_tensor;
+        if (ctx->input_memory_type(i) != HOST_MEMORY &&
+            ctx->op_device_context()) {
+          // Slow path; the argument is not available as a host constant so we
+          // must fetch it synchronously.
+          Tensor host_tensor;
+          AllocatorAttributes attrs;
+          attrs.set_on_host(true);
+          TF_RETURN_IF_ERROR(ctx->allocate_temp(device_tensor.dtype(),
+                                                device_tensor.shape(),
+                                                &host_tensor, attrs));
+          Status status = ctx->op_device_context()->CopyDeviceTensorToCPUSync(
+              &device_tensor, "ConstantArgument",
+              reinterpret_cast<Device*>(ctx->device()), &host_tensor);
+          if (!status.ok()) {
+            LOG(ERROR) << "Copying tensor of shape "
+                       << device_tensor.shape().DebugString() << " from "
+                       << ctx->device()->name() << "to CPU failed with "
+                       << status.ToString();
+            return status;
+          }
+          constant_arguments[i] = host_tensor;
+        } else {
+          constant_arguments[i] = device_tensor;
+        }
       }
     }
   }
@@ -153,7 +162,7 @@ Status XlaCompileOnDemandOp::Compile(

   absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter;
   XlaCompiler::Options options =
-      GenerateCompilerOptions(*cache, ctx, platform_info_,
+      GenerateCompilerOptions(**cache, ctx, platform_info_,
                               /*has_ref_vars=*/true, &tf_allocator_adapter);

   XlaCompiler::CompileOptions compile_options;
@@ -184,6 +193,8 @@ void XlaCompileOnDemandOp::Compute(OpKernelContext* ctx) {
   xla::LocalExecutable* executable;
   ResourceVarsSnapshot variable_args;
   XlaCompilationCache* cache;
+  OP_REQUIRES(ctx, ctx->function_library(),
+              errors::Internal("Function library missing"));
   OP_REQUIRES_OK(ctx,
                  Compile(ctx, &result, &cache, &variable_args, &executable));

@@ -61,6 +61,21 @@ limitations under the License.

 namespace tensorflow {

+// Default PaddedShapeFn implementation that simply returns the unpadded
+// on-device shape. This is accurate for CPU and GPU devices that neither
+// transpose nor pad tensors.
+Status DefaultPaddedShapeFn(const Tensor& tensor, xla::Shape* shape) {
+  const tensorflow::XlaTensor* xla_tensor =
+      tensorflow::XlaTensor::FromTensor(&tensor);
+  if (xla_tensor == nullptr) {
+    return TensorShapeToXLAShape(tensor.dtype(), tensor.shape(), shape);
+  }
+
+  const xla::ShapedBuffer& shaped_buffer = xla_tensor->shaped_buffer();
+  *shape = shaped_buffer.on_device_shape();
+  return Status::OK();
+}
+
 // Caches a XlaDeviceAllocator per <backend, device ordinal> pair. A
 // XlaDeviceAllocator is created on demand and is associated with a
 // XlaDevice. It outlives the device itself (for instance, the buffer
@@ -116,20 +131,6 @@ XlaDeviceAllocator* XlaDeviceAllocatorState::GetOrCreateXlaDeviceAllocator(

 namespace {

-// Default PaddedShapeFn implementation that simply returns the unpadded
-// on-device shape. This is accurate for CPU and GPU devices that neither
-// transpose nor pad tensors.
-Status DefaultPaddedShapeFn(const Tensor& tensor, xla::Shape* shape) {
-  const tensorflow::XlaTensor* xla_tensor =
-      tensorflow::XlaTensor::FromTensor(&tensor);
-  if (xla_tensor == nullptr) {
-    return TensorShapeToXLAShape(tensor.dtype(), tensor.shape(), shape);
-  }
-
-  const xla::ShapedBuffer& shaped_buffer = xla_tensor->shaped_buffer();
-  *shape = shaped_buffer.on_device_shape();
-  return Status::OK();
-}
-
 static DeviceAttributes BuildXlaDeviceAttributes(const string& name_prefix,
                                                  const string& device_name,
@@ -280,6 +280,8 @@ struct XlaDeviceOpRegistrations {
 XlaDeviceOpRegistrations* RegisterXlaDeviceKernels(const char* device,
                                                    const char* jit_device);

+Status DefaultPaddedShapeFn(const Tensor& tensor, xla::Shape* shape);
+
 }  // namespace tensorflow

 #endif  // TENSORFLOW_COMPILER_JIT_XLA_DEVICE_H_
tensorflow/compiler/jit/xla_ops_on_regular_devices.cc (new file, 89 lines)
@@ -0,0 +1,89 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Register XlaXXX operations on regular CPU/GPU devices using
+// `XlaCompileOnDemandOp`.
+#include "tensorflow/compiler/jit/xla_compile_on_demand_op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+
+namespace tensorflow {
+
+#define REGISTER_XLA_OPS_ON_DEVICE(DEVICE) \
+  REGISTER_KERNEL_BUILDER(Name("XlaConv") \
+                              .HostMemory("window_strides") \
+                              .HostMemory("padding") \
+                              .HostMemory("lhs_dilation") \
+                              .HostMemory("rhs_dilation") \
+                              .HostMemory("feature_group_count") \
+                              .Device(DEVICE), \
+                          XlaCompileOnDemandOp); \
+  REGISTER_KERNEL_BUILDER( \
+      Name("XlaBroadcastHelper").HostMemory("broadcast_dims").Device(DEVICE), \
+      XlaCompileOnDemandOp); \
+  REGISTER_KERNEL_BUILDER(Name("XlaSelfAdjointEig").Device(DEVICE), \
+                          XlaCompileOnDemandOp); \
+  REGISTER_KERNEL_BUILDER(Name("XlaSvd").Device(DEVICE), \
+                          XlaCompileOnDemandOp); \
+  REGISTER_KERNEL_BUILDER(Name("XlaDot").Device(DEVICE), \
+                          XlaCompileOnDemandOp); \
+  REGISTER_KERNEL_BUILDER(Name("XlaDynamicSlice").Device(DEVICE), \
+                          XlaCompileOnDemandOp); \
+  REGISTER_KERNEL_BUILDER(Name("XlaDynamicUpdateSlice").Device(DEVICE), \
+                          XlaCompileOnDemandOp); \
+  REGISTER_KERNEL_BUILDER(Name("XlaIf").Device(DEVICE), XlaCompileOnDemandOp); \
+  REGISTER_KERNEL_BUILDER(Name("XlaPad").Device(DEVICE), \
+                          XlaCompileOnDemandOp); \
+  REGISTER_KERNEL_BUILDER(Name("XlaRecv").Device(DEVICE), \
+                          XlaCompileOnDemandOp); \
+  REGISTER_KERNEL_BUILDER(Name("XlaReduce").Device(DEVICE), \
+                          XlaCompileOnDemandOp); \
+  REGISTER_KERNEL_BUILDER(Name("XlaReduceWindow").Device(DEVICE), \
+                          XlaCompileOnDemandOp); \
+  REGISTER_KERNEL_BUILDER(Name("XlaSelectAndScatter") \
+                              .HostMemory("window_dimensions") \
+                              .HostMemory("window_strides") \
+                              .HostMemory("padding") \
+                              .Device(DEVICE), \
+                          XlaCompileOnDemandOp); \
+  REGISTER_KERNEL_BUILDER(Name("XlaSend").Device(DEVICE), \
+                          XlaCompileOnDemandOp); \
+  REGISTER_KERNEL_BUILDER(Name("XlaSort").Device(DEVICE), \
+                          XlaCompileOnDemandOp); \
+  REGISTER_KERNEL_BUILDER(Name("XlaKeyValueSort").Device(DEVICE), \
+                          XlaCompileOnDemandOp); \
+  REGISTER_KERNEL_BUILDER(Name("XlaWhile").Device(DEVICE), \
+                          XlaCompileOnDemandOp); \
+  REGISTER_KERNEL_BUILDER(Name("XlaDequantize").Device(DEVICE), \
+                          XlaCompileOnDemandOp); \
+  REGISTER_KERNEL_BUILDER(Name("XlaEinsum").Device(DEVICE), \
+                          XlaCompileOnDemandOp); \
+  REGISTER_KERNEL_BUILDER(Name("XlaSpmdShardToFullShape").Device(DEVICE), \
+                          XlaCompileOnDemandOp); \
+  REGISTER_KERNEL_BUILDER(Name("XlaSharding").Device(DEVICE), \
+                          XlaCompileOnDemandOp); \
+  REGISTER_KERNEL_BUILDER(Name("XlaReplicaId").Device(DEVICE), \
+                          XlaCompileOnDemandOp); \
+  REGISTER_KERNEL_BUILDER(Name("XlaGather") \
+                              .HostMemory("start_indices") \
+                              .HostMemory("slice_sizes") \
+                              .Device(DEVICE), \
+                          XlaCompileOnDemandOp); \
+  REGISTER_KERNEL_BUILDER(Name("XlaScatter").Device(DEVICE), \
+                          XlaCompileOnDemandOp);
+
+REGISTER_XLA_OPS_ON_DEVICE(DEVICE_CPU);
+REGISTER_XLA_OPS_ON_DEVICE(DEVICE_GPU);
+
+}  // namespace tensorflow
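The new file above routes every XlaXXX op to XlaCompileOnDemandOp on the plain CPU and GPU devices, so these ops no longer need an XLA:CPU or XLA:GPU device to run. A minimal sketch of what that enables from Python follows; the tf.raw_ops.XlaSort endpoint and its `input` argument name are assumptions for illustration, not taken from this commit.

# Sketch (assumed API surface): run an XlaXXX op eagerly on a regular device.
# The kernels registered above compile the op on demand instead of requiring
# an XLA:CPU or XLA:GPU device.
import tensorflow as tf

x = tf.constant([3.0, 1.0, 2.0])
with tf.device("/device:CPU:0"):    # a regular device, not XLA_CPU
  y = tf.raw_ops.XlaSort(input=x)   # handled by XlaCompileOnDemandOp
print(y.numpy())                    # [1. 2. 3.]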
@@ -128,16 +128,17 @@ se::DeviceMemoryAllocator* GetAllocator(
 }

 XlaCompiler::Options GenerateCompilerOptions(
-    XlaCompilationCache* cache, OpKernelContext* ctx,
+    const XlaCompilationCache& cache, OpKernelContext* ctx,
     const XlaPlatformInfo& platform_info, bool has_ref_vars,
     absl::optional<se::TfAllocatorAdapter>* tf_allocator_adapter) {
+  CHECK(ctx->function_library());
   XlaCompiler::Options options;
-  options.client = static_cast<xla::LocalClient*>(cache->client());
+  options.client = static_cast<xla::LocalClient*>(cache.client());
   if (ctx->op_device_context() != nullptr) {
     options.device_ordinal =
         ctx->op_device_context()->stream()->parent()->device_ordinal();
   }
-  options.device_type = cache->device_type();
+  options.device_type = cache.device_type();
   options.flib_def = ctx->function_library()->GetFunctionLibraryDefinition();
   options.graph_def_version = ctx->function_library()->graph_def_version();
   options.allow_cpu_custom_calls =
@@ -99,7 +99,7 @@ se::DeviceMemoryAllocator* GetAllocator(
 // Returns created options for the XLA compiler, and writes the used allocator
 // into `tf_allocator_adapter`.
 XlaCompiler::Options GenerateCompilerOptions(
-    XlaCompilationCache* cache, OpKernelContext* ctx,
+    const XlaCompilationCache& cache, OpKernelContext* ctx,
     const XlaPlatformInfo& platform_info, bool has_ref_vars,
     absl::optional<se::TfAllocatorAdapter>* tf_allocator_adapter);

@@ -1687,6 +1687,7 @@ tf_cuda_cc_test(
     deps = [
         "//tensorflow/cc:cc_ops",
         "//tensorflow/compiler/jit",
+        "//tensorflow/compiler/jit:flags",
         "//tensorflow/compiler/jit:xla_kernel_creator",
         "//tensorflow/compiler/tf2xla:xla_compiler",
         "//tensorflow/core:core_cpu",
@@ -20,6 +20,7 @@ limitations under the License.
 #include <vector>

 #include "absl/synchronization/notification.h"
+#include "tensorflow/compiler/jit/flags.h"
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
 #include "tensorflow/core/common_runtime/device.h"
 #include "tensorflow/core/common_runtime/device_factory.h"

@@ -43,6 +44,11 @@ limitations under the License.
 namespace tensorflow {
 namespace {

+static bool Initialized = [] {
+  tensorflow::GetXlaDeviceFlags()->tf_xla_enable_xla_devices = true;
+  return true;
+}();
+
 class UnaryOpsCompositionTest : public OpsTestBase {
  protected:
  template <typename T>
@@ -19,6 +19,7 @@ from __future__ import division
 from __future__ import print_function

 from tensorflow.python.client import session as session_lib
+from tensorflow.python.eager import context
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops

@@ -27,6 +28,10 @@ from tensorflow.python.platform import test

 class XlaDeviceGpuTest(test.TestCase):

+  def __init__(self, method_name="runTest"):
+    super(XlaDeviceGpuTest, self).__init__(method_name)
+    context.context().enable_xla_devices()
+
   def testCopiesToAndFromGpuWork(self):
     """Tests that copies between GPU and XLA devices work."""
     if not test.is_gpu_available():
@@ -83,6 +83,8 @@ class XLATestCase(test.TestCase):

   def __init__(self, method_name='runTest'):
     super(XLATestCase, self).__init__(method_name)
+    if 'XLA' in FLAGS.test_device:
+      context.context().enable_xla_devices()
     context.context().enable_mlir_bridge = test_util.is_mlir_bridge_enabled()

     self.device = FLAGS.test_device
@@ -787,6 +787,7 @@ tf_cc_test(
         "//tensorflow/cc:function_ops",
         "//tensorflow/cc:functional_ops",
         "//tensorflow/cc:ops",
+        "//tensorflow/compiler/jit:flags",
         "//tensorflow/compiler/jit:xla_cluster_util",
         "//tensorflow/compiler/tf2xla/kernels:xla_ops",
         "//tensorflow/core:core_cpu_internal",

@@ -1087,6 +1088,7 @@ tf_cuda_cc_test(
         "//tensorflow/cc:ops",
         "//tensorflow/cc:scope",
         "//tensorflow/compiler/jit",
+        "//tensorflow/compiler/jit:flags",
         "//tensorflow/core:core_cpu",
         "//tensorflow/core:framework",
         "//tensorflow/core:framework_internal",
@@ -21,6 +21,7 @@ limitations under the License.
 #include "tensorflow/cc/ops/function_ops.h"
 #include "tensorflow/cc/ops/functional_ops.h"
 #include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/compiler/jit/flags.h"
 #include "tensorflow/compiler/jit/xla_cluster_util.h"
 #include "tensorflow/core/common_runtime/process_function_library_runtime.h"
 #include "tensorflow/core/graph/algorithm.h"

@@ -217,5 +218,10 @@ TEST(ConstAnalysisTest, RespectExplicitAttr_1) {
   EXPECT_EQ(const_args, std::vector<bool>({true}));
 }

+static bool Initialized = [] {
+  tensorflow::GetXlaDeviceFlags()->tf_xla_enable_xla_devices = true;
+  return true;
+}();
+
 }  // namespace
 }  // namespace tensorflow
@@ -26,6 +26,7 @@ limitations under the License.
 #include "tensorflow/cc/ops/array_ops.h"
 #include "tensorflow/cc/ops/const_op.h"
 #include "tensorflow/cc/ops/nn_ops.h"
+#include "tensorflow/compiler/jit/flags.h"
 #include "tensorflow/core/framework/device_attributes.pb.h"
 #include "tensorflow/core/framework/graph.pb.h"
 #include "tensorflow/core/framework/tensor.h"

@@ -139,5 +140,11 @@ TEST(FusedBatchnormReserveSpaceTest, Test) {
   test::ExpectClose(results[0], results[1], /*atol=*/1e-4);
   test::ExpectClose(results[2], results[3], /*atol=*/1e-4);
 }
+
+static bool Initialized = [] {
+  tensorflow::GetXlaDeviceFlags()->tf_xla_enable_xla_devices = true;
+  return true;
+}();
+
 }  // namespace
 }  // namespace tensorflow
@@ -96,6 +96,7 @@ tf_gen_op_libs(
         "xrt_execute_op",
     ],
     deps = [
+        "//tensorflow/compiler/jit:flags",
         "//tensorflow/core:lib",
     ],
 )
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/

+#include "tensorflow/compiler/jit/flags.h"
 #include "tensorflow/core/framework/common_shape_fns.h"
 #include "tensorflow/core/framework/op.h"
 #include "tensorflow/core/framework/shape_inference.h"

@@ -20,6 +21,11 @@ limitations under the License.

 namespace tensorflow {

+static bool Initialized = [] {
+  tensorflow::GetXlaDeviceFlags()->tf_xla_enable_xla_devices = true;
+  return true;
+}();
+
 REGISTER_OP("XRTAllocate")
     .Input("allocation: string")
     .Output("handle: int64")
@@ -44,26 +44,6 @@ TEST_F(PinToHostOptimizerTest, TryFindHostDeviceCpuXlaGpu) {
             "/device:CPU:0");
 }

-TEST_F(PinToHostOptimizerTest, TryFindHostDeviceXlaCpuXlaGpu) {
-  gtl::FlatSet<string> devices = {"/device:XLA_CPU:0", "/device:XLA_GPU:0"};
-
-  EXPECT_EQ(internal::TryFindHostDevice(devices, false, ""), "");
-  EXPECT_EQ(internal::TryFindHostDevice(devices, false, "/device:XLA_GPU:0"),
-            "/device:XLA_CPU:0");
-  EXPECT_EQ(internal::TryFindHostDevice(devices, false, "/device:XLA_GPU:*"),
-            "/device:XLA_CPU:0");
-}
-
-TEST_F(PinToHostOptimizerTest, TryFindHostDeviceXlaGpu) {
-  gtl::FlatSet<string> devices = {"/device:XLA_GPU:0"};
-
-  EXPECT_EQ(internal::TryFindHostDevice(devices, false, ""), "");
-  EXPECT_EQ(internal::TryFindHostDevice(devices, false, "/device:XLA_GPU:0"),
-            "");
-  EXPECT_EQ(internal::TryFindHostDevice(devices, false, "/device:XLA_GPU:*"),
-            "");
-}
-
 TEST_F(PinToHostOptimizerTest, OptimizeSmallOpsToHost) {
   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
   Output a = ops::Const(s.WithOpName("a"), 1, {1024, 1024});
@@ -42,6 +42,7 @@ from tensorflow.python.framework import device as pydev
 from tensorflow.python.util import compat
 from tensorflow.python.util import is_in_graph_mode
 from tensorflow.python.util import tf_contextlib
+from tensorflow.python.util.deprecation import deprecated
 from tensorflow.python.util.tf_export import tf_export

 GRAPH_MODE = 0

@@ -1254,12 +1255,7 @@ class Context(object):
         p: i for i, p in enumerate(self._physical_devices)
     }

-    # Construct the visible device list from all physical devices but ignore
-    # XLA devices
-    self._visible_device_list = [
-        d for d in self._physical_devices
-        if not d.device_type.startswith("XLA")
-    ]
+    self._visible_device_list = list(self._physical_devices)
     self._memory_growth_map = {
         d: None for d in self._physical_devices if d.device_type == "GPU"
     }

@@ -1493,6 +1489,12 @@ class Context(object):

     self._virtual_device_map[dev] = virtual_devices

+  @deprecated(
+      None, "XLA:CPU and XLA:GPU devices are deprecated", warn_once=True)
+  def enable_xla_devices(self):
+    """Enables XLA:CPU and XLA:GPU devices registration."""
+    pywrap_tfe.TF_EnableXlaDevices()
+
   @property
   def enable_mlir_bridge(self):
     return pywrap_tfe.TF_IsMlirBridgeEnabled()
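The enable_xla_devices() method added above is what the test changes elsewhere in this commit call from their __init__ hooks. A small usage sketch, assembled from imports and device strings that appear in this diff:

# Sketch: re-enable XLA device registration from Python, then place an op on
# an XLA device. enable_xla_devices() is deprecated and simply flips the
# tf_xla_enable_xla_devices flag through pywrap_tfe.TF_EnableXlaDevices().
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops

context.context().enable_xla_devices()  # call before devices are initialized

with ops.device("/device:XLA_CPU:0"):
  a = array_ops.identity(1.0)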
@@ -435,9 +435,6 @@ class DeviceTest(test.TestCase):
     self.assertEqual(len(config.get_visible_devices('CPU')), 1)
     self.assertGreater(len(config.get_visible_devices('GPU')), 0)

-    # get_visible_devices filters out XLA_* devices. list_logical_devices does
-    # not, but we can't call it here because it initializes the devices and
-    # calling set_visible_devices after that is disallowed.
     self.assertEqual(len(config.get_visible_devices('XLA_GPU')), 0)

     config.set_visible_devices(cpus[0])

@@ -451,12 +448,6 @@ class DeviceTest(test.TestCase):
       a = array_ops.identity(1.0)
       self.evaluate(a)

-    with self.assertRaisesRegex(errors.InvalidArgumentError,
-                                'Could not satisfy'):
-      with ops.device('/device:XLA_GPU:0'):
-        a = array_ops.identity(1.0)
-        self.evaluate(a)
-
     # Modifying the visible devices is not supported
     with self.assertRaisesRegex(RuntimeError, 'cannot be modified'):
       config.set_visible_devices(gpus)
@@ -22,6 +22,7 @@ from __future__ import print_function
 from tensorflow.compiler.tf2xla.python import xla as xla_ops
 from tensorflow.python.compiler.xla import jit
 from tensorflow.python.compiler.xla import xla
+from tensorflow.python.eager import context
 from tensorflow.python.eager import def_function
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import test_util

@@ -39,6 +40,10 @@ from tensorflow.python.platform import test
 @test_util.run_all_in_graph_and_eager_modes
 class PForTest(PForTestCase):

+  def __init__(self, method_name="runTest"):
+    super(PForTest, self).__init__(method_name)
+    context.context().enable_xla_devices()
+
   def test_xla_einsum(self):
     num_loop = 10
     x_series = random_ops.random_uniform([num_loop, 9, 9])
@@ -444,6 +444,9 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
   m.def("TF_EnableMlirBridge", [](bool enabled) {
     tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge = enabled;
   });
+  m.def("TF_EnableXlaDevices", [] {
+    tensorflow::GetXlaDeviceFlags()->tf_xla_enable_xla_devices = true;
+  });

   // // TFE_Context Logic
   m.def(