[TPU] Move TPU node and system device initializers to compiler/jit
This is part of a series of changes to move TPU-related code to better locations so that the TensorFlow build isn't confused and TPU-based TF can be built without the define=framework_shared_object=false flag. PiperOrigin-RevId: 352726495 Change-Id: Idc23455a8289c4a2546edad9ca59e9207a7492ce
This commit is contained in:
parent
0fa8e3c7dd
commit
c5f474d1c8
tensorflow
@ -4,7 +4,7 @@ load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
|
||||
load("//tensorflow:tensorflow.bzl", "cc_header_only_library", "if_mlir", "tf_cc_test")
|
||||
|
||||
# buildifier: disable=same-origin-load
|
||||
load("//tensorflow:tensorflow.bzl", "if_libtpu", "tf_copts")
|
||||
load("//tensorflow:tensorflow.bzl", "if_libtpu", "if_with_tpu_support", "tf_copts")
|
||||
load("//tensorflow/stream_executor:build_defs.bzl", "if_cuda_or_rocm")
|
||||
|
||||
# buildifier: disable=same-origin-load
|
||||
@ -67,7 +67,7 @@ cc_library(
|
||||
] + if_cuda_or_rocm([
|
||||
":xla_gpu_device",
|
||||
":xla_gpu_jit",
|
||||
]),
|
||||
]) + if_with_tpu_support([":xla_tpu_device"]),
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
@ -153,6 +153,42 @@ cc_library(
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "xla_tpu_device",
|
||||
srcs = ["xla_tpu_device.cc"],
|
||||
hdrs = ["xla_tpu_device.h"],
|
||||
visibility = [":friends"],
|
||||
deps = [
|
||||
":jit_compilation_passes",
|
||||
":xla_device",
|
||||
":xla_kernel_creator", # buildcleaner: keep
|
||||
"//tensorflow/compiler/jit/kernels:xla_ops",
|
||||
"//tensorflow/compiler/tf2xla:common",
|
||||
"//tensorflow/compiler/tf2xla:tf2xla_util",
|
||||
"//tensorflow/compiler/tf2xla:xla_op_registry",
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core:lib_proto_parsing",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:session_options",
|
||||
"//tensorflow/core/common_runtime:copy_tensor",
|
||||
"//tensorflow/core/common_runtime:device",
|
||||
"//tensorflow/core/common_runtime:device_factory",
|
||||
"//tensorflow/core/common_runtime:dma_helper",
|
||||
"//tensorflow/core/platform:status",
|
||||
"//tensorflow/core/tpu:tpu_api",
|
||||
"//tensorflow/core/tpu:tpu_defs",
|
||||
"//tensorflow/core/tpu:tpu_node_device_util",
|
||||
"//tensorflow/core/tpu:virtual_device",
|
||||
"//tensorflow/stream_executor/tpu:c_api_conversions",
|
||||
"//tensorflow/stream_executor/tpu:status_helper",
|
||||
"//tensorflow/stream_executor/tpu:tpu_executor_base",
|
||||
"//tensorflow/stream_executor/tpu:tpu_node_context",
|
||||
"//tensorflow/stream_executor/tpu:tpu_platform_interface",
|
||||
"//tensorflow/stream_executor/tpu:tpu_stream_interface",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "xla_tensor",
|
||||
srcs = ["xla_tensor.cc"],
|
||||
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/tpu/tpu_node_device.h"
|
||||
#include "tensorflow/compiler/jit/xla_tpu_device.h"
|
||||
|
||||
#include "tensorflow/compiler/jit/kernels/xla_ops.h"
|
||||
#include "tensorflow/compiler/jit/xla_device.h"
|
||||
@ -32,9 +32,11 @@ limitations under the License.
|
||||
#include "tensorflow/core/tpu/tpu_api.h"
|
||||
#include "tensorflow/core/tpu/tpu_defs.h"
|
||||
#include "tensorflow/core/tpu/tpu_node_device_util.h"
|
||||
#include "tensorflow/core/tpu/virtual_device.h"
|
||||
#include "tensorflow/stream_executor/tpu/c_api_conversions.h"
|
||||
#include "tensorflow/stream_executor/tpu/status_helper.h"
|
||||
#include "tensorflow/stream_executor/tpu/tpu_node_context.h"
|
||||
#include "tensorflow/stream_executor/tpu/tpu_platform.h"
|
||||
#include "tensorflow/stream_executor/tpu/tpu_platform_interface.h"
|
||||
#include "tensorflow/stream_executor/tpu/tpu_stream_interface.h"
|
||||
|
||||
@ -314,7 +316,7 @@ Status TpuNodeDeviceFactory::ListPhysicalDevices(std::vector<string>* devices) {
|
||||
int device_count = platform->VisibleDeviceCount();
|
||||
|
||||
for (int i = 0; i < device_count; ++i) {
|
||||
const string device_name = strings::StrCat("/physical_device:TPU:", i);
|
||||
const string device_name = absl::StrCat("/physical_device:TPU:", i);
|
||||
devices->push_back(device_name);
|
||||
}
|
||||
|
||||
@ -393,6 +395,52 @@ Status TpuNodeDeviceFactory::CreateDevices(
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
class TpuSystemDeviceFactory : public DeviceFactory {
|
||||
public:
|
||||
Status ListPhysicalDevices(std::vector<string>* devices) override;
|
||||
Status CreateDevices(const SessionOptions& options, const string& name_prefix,
|
||||
std::vector<std::unique_ptr<Device>>* devices) override;
|
||||
};
|
||||
|
||||
Status TpuSystemDeviceFactory::ListPhysicalDevices(
|
||||
std::vector<string>* devices) {
|
||||
int device_count = 0;
|
||||
TF_RETURN_IF_ERROR(tpu::TpuPlatform::TpusPerHost(&device_count));
|
||||
if (device_count == 0) {
|
||||
VLOG(1) << "Host has no TPUs, not creating a TPU_SYSTEM device";
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
devices->push_back("/physical_device:TPU_SYSTEM:0");
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status TpuSystemDeviceFactory::CreateDevices(
|
||||
const SessionOptions& options, const string& name_prefix,
|
||||
std::vector<std::unique_ptr<Device>>* devices) {
|
||||
int device_count = 0;
|
||||
TF_RETURN_IF_ERROR(tpu::TpuPlatform::TpusPerHost(&device_count));
|
||||
if (device_count == 0) {
|
||||
VLOG(1) << "Host has no TPUs, not creating a TPU_SYSTEM device";
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
int64 memory_limit;
|
||||
TF_RETURN_IF_ERROR(tpu::TpuPlatform::TpuMemoryLimit(&memory_limit));
|
||||
|
||||
// Creates a device that represents a TPU distributed system.
|
||||
const DeviceAttributes attrs = Device::BuildDeviceAttributes(
|
||||
absl::StrCat(name_prefix, "/device:", DEVICE_TPU_SYSTEM, ":", 0),
|
||||
DeviceType(DEVICE_TPU_SYSTEM), Bytes(memory_limit), DeviceLocality(),
|
||||
absl::StrCat("device: ", DEVICE_TPU_SYSTEM, " device"));
|
||||
devices->push_back(absl::make_unique<VirtualDevice>(options.env, attrs));
|
||||
VLOG(1) << "Created TPU_SYSTEM device. This host has " << device_count
|
||||
<< " TPUs";
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void RegisterTpuDeviceToDeviceCopy() {
|
||||
@ -410,12 +458,29 @@ void RegisterTpuNodeDevice(
|
||||
tpu_use_substreams_for_cross_tpu_device_transfers_flag =
|
||||
tpu_use_substreams_for_cross_tpu_device_transfers;
|
||||
|
||||
REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_TPU_NODE, TpuNodeDeviceFactory);
|
||||
|
||||
REGISTER_XLA_LAUNCH_KERNEL(DEVICE_TPU_NODE, XlaLocalLaunchOp, kTpuAllTypes);
|
||||
REGISTER_XLA_COMPILE_KERNEL(DEVICE_TPU_NODE, XlaCompileOp, kTpuAllTypes);
|
||||
REGISTER_XLA_RUN_KERNEL(DEVICE_TPU_NODE, XlaRunOp, kTpuAllTypes);
|
||||
REGISTER_XLA_DEVICE_KERNELS(DEVICE_TPU_NODE, kTpuAllTypes);
|
||||
REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_TPU_NODE, TpuNodeDeviceFactory);
|
||||
}
|
||||
|
||||
void RegisterTpuSystemDevice() {
|
||||
REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_TPU_SYSTEM, TpuSystemDeviceFactory);
|
||||
}
|
||||
|
||||
#if !defined(PLATFORM_GOOGLE)
|
||||
|
||||
// We automatically register this if we are building for open source. For
|
||||
// Google platforms, we initialize these devices in other places.
|
||||
|
||||
REGISTER_XLA_LAUNCH_KERNEL(DEVICE_TPU_NODE, XlaLocalLaunchOp, kTpuAllTypes);
|
||||
REGISTER_XLA_COMPILE_KERNEL(DEVICE_TPU_NODE, XlaCompileOp, kTpuAllTypes);
|
||||
REGISTER_XLA_RUN_KERNEL(DEVICE_TPU_NODE, XlaRunOp, kTpuAllTypes);
|
||||
REGISTER_XLA_DEVICE_KERNELS(DEVICE_TPU_NODE, kTpuAllTypes);
|
||||
REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_TPU_NODE, TpuNodeDeviceFactory);
|
||||
REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_TPU_SYSTEM, TpuSystemDeviceFactory);
|
||||
|
||||
#endif // PLATFORM_GOOGLE
|
||||
|
||||
} // namespace tensorflow
|
@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CORE_TPU_TPU_NODE_DEVICE_H_
|
||||
#define TENSORFLOW_CORE_TPU_TPU_NODE_DEVICE_H_
|
||||
#ifndef TENSORFLOW_COMPILER_JIT_XLA_TPU_DEVICE_H_
|
||||
#define TENSORFLOW_COMPILER_JIT_XLA_TPU_DEVICE_H_
|
||||
|
||||
#include "tensorflow/core/common_runtime/device.h"
|
||||
#include "tensorflow/core/common_runtime/device_factory.h"
|
||||
@ -29,6 +29,8 @@ void RegisterTpuNodeDevice(
|
||||
bool tpu_autoclustering, bool tpu_xla_device_failure_closes_chips,
|
||||
bool tpu_use_substreams_for_cross_tpu_device_transfers);
|
||||
|
||||
void RegisterTpuSystemDevice();
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_TPU_TPU_NODE_DEVICE_H_
|
||||
#endif // TENSORFLOW_COMPILER_JIT_XLA_TPU_DEVICE_H_
|
@ -94,6 +94,7 @@ cc_library(
|
||||
name = "tpu_defs",
|
||||
srcs = ["tpu_defs.cc"],
|
||||
hdrs = ["tpu_defs.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = ["//tensorflow/core:protos_all_cc"],
|
||||
)
|
||||
|
||||
@ -151,9 +152,7 @@ cc_library(
|
||||
":tpu_compilation_device",
|
||||
":tpu_executor_init_fns",
|
||||
":tpu_library_init_fns",
|
||||
":tpu_node_device",
|
||||
":tpu_ops_c_api_hdrs",
|
||||
":tpu_system_device",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core/tpu/graph_rewrite:tpu_rewrite_pass_registration",
|
||||
"//tensorflow/stream_executor/tpu:tpu_computation_placer",
|
||||
@ -200,49 +199,6 @@ cc_library(
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tpu_node_device",
|
||||
srcs = ["tpu_node_device.cc"],
|
||||
hdrs = ["tpu_node_device.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":tpu_api",
|
||||
":tpu_defs",
|
||||
":tpu_node_device_util",
|
||||
"//tensorflow/compiler/jit:xla_device",
|
||||
"//tensorflow/compiler/jit/kernels:xla_ops",
|
||||
"//tensorflow/compiler/tf2xla:common",
|
||||
"//tensorflow/compiler/tf2xla:tf2xla_util",
|
||||
"//tensorflow/compiler/tf2xla:xla_compiler",
|
||||
"//tensorflow/compiler/tf2xla:xla_op_registry",
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:session_options",
|
||||
"//tensorflow/stream_executor/tpu:c_api_conversions",
|
||||
"//tensorflow/stream_executor/tpu:status_helper",
|
||||
"//tensorflow/stream_executor/tpu:tpu_node_context",
|
||||
"//tensorflow/stream_executor/tpu:tpu_platform_interface",
|
||||
"//tensorflow/stream_executor/tpu:tpu_stream_interface",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tpu_system_device",
|
||||
srcs = ["tpu_system_device.cc"],
|
||||
hdrs = ["tpu_system_device.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":virtual_device",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:session_options",
|
||||
"//tensorflow/stream_executor/tpu:tpu_executor_base",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "virtual_device",
|
||||
srcs = ["virtual_device.cc"],
|
||||
@ -329,8 +285,6 @@ cc_library(
|
||||
deps = [
|
||||
":tpu_api_dlsym_initializer",
|
||||
":tpu_compilation_device",
|
||||
":tpu_node_device",
|
||||
":tpu_system_device",
|
||||
"//tensorflow/core/tpu:tpu_on_demand_compiler",
|
||||
"//tensorflow/core/tpu/graph_rewrite:tpu_rewrite_pass_registration",
|
||||
"//tensorflow/core/tpu/ops",
|
||||
|
@ -22,8 +22,6 @@ limitations under the License.
|
||||
#include "tensorflow/core/tpu/tpu_api_dlsym_set_fn.h"
|
||||
#if !defined(PLATFORM_GOOGLE)
|
||||
#include "tensorflow/core/tpu/tpu_api.h"
|
||||
#include "tensorflow/core/tpu/tpu_node_device.h"
|
||||
#include "tensorflow/core/tpu/tpu_system_device.h"
|
||||
#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
|
||||
#include "tensorflow/stream_executor/tpu/tpu_platform.h"
|
||||
#endif
|
||||
@ -55,11 +53,6 @@ Status InitializeTpuLibrary(void* library_handle) {
|
||||
(*initialize_fn)(/*init_library=*/true, /*argc=*/0, /*argv=*/nullptr);
|
||||
|
||||
RegisterTpuPlatform();
|
||||
RegisterTpuSystemDevice();
|
||||
RegisterTpuNodeDevice(
|
||||
/*tpu_autoclustering=*/false,
|
||||
/*tpu_xla_device_failure_closes_chips=*/true,
|
||||
/*tpu_use_substreams_for_cross_tpu_device_transfers=*/true);
|
||||
}
|
||||
|
||||
return s;
|
||||
|
@ -1,80 +0,0 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/common_runtime/device_factory.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/public/session_options.h"
|
||||
#include "tensorflow/core/tpu/virtual_device.h"
|
||||
#include "tensorflow/stream_executor/tpu/tpu_platform.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace tpu {
|
||||
namespace {
|
||||
|
||||
class TpuSystemDeviceFactory : public DeviceFactory {
|
||||
public:
|
||||
Status ListPhysicalDevices(std::vector<string>* devices) override;
|
||||
Status CreateDevices(const SessionOptions& options, const string& name_prefix,
|
||||
std::vector<std::unique_ptr<Device>>* devices) override;
|
||||
};
|
||||
|
||||
Status TpuSystemDeviceFactory::ListPhysicalDevices(
|
||||
std::vector<string>* devices) {
|
||||
int device_count = 0;
|
||||
TF_RETURN_IF_ERROR(TpuPlatform::TpusPerHost(&device_count));
|
||||
if (device_count == 0) {
|
||||
VLOG(1) << "Host has no TPUs, not creating a TPU_SYSTEM device";
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
devices->push_back("/physical_device:TPU_SYSTEM:0");
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status TpuSystemDeviceFactory::CreateDevices(
|
||||
const SessionOptions& options, const string& name_prefix,
|
||||
std::vector<std::unique_ptr<Device>>* devices) {
|
||||
int device_count = 0;
|
||||
TF_RETURN_IF_ERROR(TpuPlatform::TpusPerHost(&device_count));
|
||||
if (device_count == 0) {
|
||||
VLOG(1) << "Host has no TPUs, not creating a TPU_SYSTEM device";
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
int64 memory_limit;
|
||||
TF_RETURN_IF_ERROR(TpuPlatform::TpuMemoryLimit(&memory_limit));
|
||||
|
||||
// Creates a device that represents a Jellyfish distributed system.
|
||||
const DeviceAttributes attrs = Device::BuildDeviceAttributes(
|
||||
strings::StrCat(name_prefix, "/device:", DEVICE_TPU_SYSTEM, ":", 0),
|
||||
DeviceType(DEVICE_TPU_SYSTEM), Bytes(memory_limit), DeviceLocality(),
|
||||
strings::StrCat("device: ", DEVICE_TPU_SYSTEM, " device"));
|
||||
devices->push_back(absl::make_unique<VirtualDevice>(options.env, attrs));
|
||||
VLOG(1) << "Created TPU_SYSTEM device. This host has " << device_count
|
||||
<< " TPUs";
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void RegisterTpuSystemDevice() {
|
||||
REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_TPU_SYSTEM, TpuSystemDeviceFactory);
|
||||
}
|
||||
|
||||
} // namespace tpu
|
||||
} // namespace tensorflow
|
@ -1,27 +0,0 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CORE_TPU_TPU_SYSTEM_DEVICE_H_
|
||||
#define TENSORFLOW_CORE_TPU_TPU_SYSTEM_DEVICE_H_
|
||||
|
||||
namespace tensorflow {
|
||||
namespace tpu {
|
||||
|
||||
void RegisterTpuSystemDevice();
|
||||
|
||||
} // namespace tpu
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_TPU_TPU_SYSTEM_DEVICE_H_
|
@ -5,6 +5,7 @@ load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
|
||||
package(
|
||||
default_visibility = [
|
||||
"//learning/brain/experimental/dtensor:__subpackages__",
|
||||
"//tensorflow/compiler/jit:__subpackages__",
|
||||
"//tensorflow/compiler/xrt:__subpackages__",
|
||||
"//tensorflow/core/profiler/internal/tpu:__subpackages__",
|
||||
"//tensorflow/core/tpu:__subpackages__",
|
||||
|
@ -269,12 +269,19 @@ def if_nccl(if_true, if_false = []):
|
||||
})
|
||||
|
||||
def if_libtpu(if_true, if_false = []):
|
||||
"""Shorthand for select()ing whether to build support for using TPUs via libtpu.so"""
|
||||
"""Shorthand for select()ing whether to build backend support for TPUs when building libtpu.so"""
|
||||
return select({
|
||||
str(Label("//tensorflow:with_tpu_support")): if_true,
|
||||
"//conditions:default": if_false,
|
||||
})
|
||||
|
||||
def if_with_tpu_support(if_true, if_false = []):
|
||||
"""Shorthand for select()ing whether to build API support for TPUs when building TensorFlow"""
|
||||
return select({
|
||||
"//tensorflow:with_tpu_support": if_true,
|
||||
"//conditions:default": if_false,
|
||||
})
|
||||
|
||||
def if_registration_v2(if_true, if_false = []):
|
||||
return select({
|
||||
"//tensorflow:registration_v2": if_true,
|
||||
|
Loading…
Reference in New Issue
Block a user