Remove two unused LOG(FATAL) methods from TpuNodeContext.

PiperOrigin-RevId: 322711148
Change-Id: Id928e49c8227e4cbf0982443ed08dd46f767eaea

commit f942f4a240 (parent 663cd759ad)
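
The diff below also threads an xla::Backend* through the execute path and retires TpuNodeContext's per-resource static helpers (memory_allocator(), transfer_manager(), BorrowStream()). A minimal sketch of the call-site migration, assuming the usual TF/XLA headers; the wrapper functions here are illustrative, not from the tree:

    // Old call sites: each resource had its own static accessor on TpuNodeContext.
    void OldStyle(tensorflow::tpu::TpuNodeContext* node_context) {
      se::DeviceMemoryAllocator* allocator = node_context->memory_allocator();
      xla::TransferManager* transfer_manager = node_context->transfer_manager();
      (void)allocator;
      (void)transfer_manager;
    }

    // New call sites: fetch the node's Backend once, then take everything from it.
    void NewStyle(tensorflow::tpu::TpuNodeContext* node_context) {
      xla::Backend* backend = node_context->backend();
      se::DeviceMemoryAllocator* allocator = backend->memory_allocator();
      xla::TransferManager* transfer_manager = backend->transfer_manager();
      (void)allocator;
      (void)transfer_manager;
    }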
@@ -198,8 +198,8 @@ struct InputBuffers {
 // Builds an InputBuffers object that describes the inputs to the computation.
 xla::StatusOr<std::unique_ptr<InputBuffers>> BuildComputationInputs(
     OpKernelContext* context, const xla::Shape& input_host_shape,
-    const VariableUpdateMap& variable_updates, TpuNodeContext* node_context,
-    se::Stream* stream) {
+    const VariableUpdateMap& variable_updates, xla::Backend* backend,
+    int device_ordinal, se::Stream* stream) {
   profiler::TraceMe trace_me("BuildComputationInputs", /*level=*/2);
   OpInputList arg_list;
   TF_RETURN_IF_ERROR(context->input_list("args", &arg_list));
@@ -274,10 +274,8 @@ xla::StatusOr<std::unique_ptr<InputBuffers>> BuildComputationInputs(
         validate_shape(variables[i].index(), *variables[i].var()->tensor()));
   }
 
-  se::DeviceMemoryAllocator* const allocator = node_context->memory_allocator();
-  xla::TransferManager* const transfer_manager =
-      node_context->transfer_manager();
-  const int device_ordinal = node_context->device_ordinal();
+  se::DeviceMemoryAllocator* const allocator = backend->memory_allocator();
+  xla::TransferManager* const transfer_manager = backend->transfer_manager();
 
   auto input_buffers = absl::make_unique<InputBuffers>(
       transfer_manager->HostShapeToDeviceShape(input_host_shape));
@@ -411,7 +409,7 @@ xla::StatusOr<std::unique_ptr<OutputBuffers>> AllocateOutputTensors(
   }
 
   xla::TransferManager* const transfer_manager =
-      node_context->transfer_manager();
+      node_context->backend()->transfer_manager();
 
   std::vector<TensorShape> output_tensor_shapes;
   output_tensor_shapes.reserve(sub_elements);
@@ -434,7 +432,8 @@ xla::StatusOr<std::unique_ptr<OutputBuffers>> AllocateOutputTensors(
   TF_RET_CHECK(scoped_buffers.on_host_shape().IsTuple());
   TF_RET_CHECK(!xla::ShapeUtil::IsNestedTuple(scoped_buffers.on_host_shape()));
 
-  se::DeviceMemoryAllocator* const allocator = node_context->memory_allocator();
+  se::DeviceMemoryAllocator* const allocator =
+      node_context->backend()->memory_allocator();
 
   auto output_buffers =
       absl::make_unique<OutputBuffers>(std::move(scoped_buffers), allocator);
@@ -633,10 +632,11 @@ Status TPUExecuteOp::DoWork(OpKernelContext* context) {
                       TpuNodeContext::Create(device_ordinal));
 
   profiler::TraceMe trace_me(
-      [&, device_ordinal] {
-        return absl::StrCat("TpuExecuteOp#device_ordinal=", device_ordinal,
-                            ",id=", context->step_id(),
-                            ",iter_num=", context->frame_iter().iter_id, "#");
+      [device_ordinal, context] {
+        return profiler::TraceMeEncode(
+            "TpuExecuteOp", {{"device_ordinal", device_ordinal},
+                             {"id", context->step_id()},
+                             {"iter_num", context->frame_iter().iter_id}});
       },
       /*level=*/2);
   profiler::TraceMe trace_me_init("TPUExecuteOp::Init", /*level=*/2);
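
The hunk above swaps the hand-rolled absl::StrCat trace name for profiler::TraceMeEncode, which emits the "name#key=value,...#" form the profiler parses. A standalone sketch of the pattern, assuming tensorflow/core/profiler/lib/traceme_encode.h provides TraceMeEncode (argument values are placeholders):

    #include <cstdint>

    #include "tensorflow/core/profiler/lib/traceme.h"
    #include "tensorflow/core/profiler/lib/traceme_encode.h"

    // The lambda is only invoked when profiling is active at this level, so
    // the annotated string is built lazily.
    void TraceExample(int device_ordinal, int64_t step_id, int64_t iter_id) {
      tensorflow::profiler::TraceMe trace_me(
          [&] {
            return tensorflow::profiler::TraceMeEncode(
                "TpuExecuteOp", {{"device_ordinal", device_ordinal},
                                 {"id", step_id},
                                 {"iter_num", iter_id}});
          },
          /*level=*/2);
      // ... traced work happens while trace_me is in scope ...
    }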
@@ -649,9 +649,9 @@ Status TPUExecuteOp::DoWork(OpKernelContext* context) {
   // Shapes of the inputs and outputs, in xla::Shape form.
   const TPUExecutableInfoProto* proto = entry->get().get_executable_info();
 
-  xla::TransferManager* const transfer_manager =
-      node_context->transfer_manager();
-  CHECK(context->op_device_context());
+  xla::Backend* const backend = node_context->backend();
+  xla::TransferManager* const transfer_manager = backend->transfer_manager();
+  TF_RET_CHECK(context->op_device_context());
   se::Stream* stream = context->op_device_context()->stream();
 
   TF_RET_CHECK(proto->input_shapes_size() == 1);
@@ -666,8 +666,8 @@ Status TPUExecuteOp::DoWork(OpKernelContext* context) {
                              proto->output_tensor_shapes().size()));
   TF_ASSIGN_OR_RETURN(
       std::unique_ptr<InputBuffers> input_buffers,
-      BuildComputationInputs(context, host_shape, variable_update_map,
-                             node_context.get(), stream));
+      BuildComputationInputs(context, host_shape, variable_update_map, backend,
+                             device_ordinal, stream));
 
   // Ideally this should be the host-to-device stream from XlaDeviceContext.
   // The particular anti-dependency this is avoiding (why we need a separate
@@ -680,11 +680,11 @@ Status TPUExecuteOp::DoWork(OpKernelContext* context) {
   // TODO(jmolloy): Add the necessary plumbing to obtain the proper
   // host-to-device stream here.
   TF_ASSIGN_OR_RETURN(auto transfer_stream_ptr,
-                      node_context->BorrowStream(device_ordinal));
+                      backend->BorrowStream(device_ordinal));
 
-  se::DeviceMemoryAllocator* const allocator = node_context->memory_allocator();
-  auto shaped_buffer =
-      input_buffers->ToShapedBuffer(host_shape, allocator, device_ordinal);
+  se::DeviceMemoryAllocator* const allocator = backend->memory_allocator();
+  auto shaped_buffer = input_buffers->ToShapedBuffer(std::move(host_shape),
+                                                     allocator, device_ordinal);
   if (transfer_manager->CanShapedBufferBeAccessedNow(stream->parent(),
                                                      shaped_buffer)) {
     TF_RETURN_IF_ERROR(transfer_manager->WriteRootTupleIndexTable(
@@ -733,8 +733,8 @@ Status TPUExecuteOp::DoWork(OpKernelContext* context) {
           << shaped_buffer.ToString();
 
   std::vector<xla::ExecutionInput> input;
-  input.emplace_back(
-      xla::ExecutionInput(std::move(input_buffers->buffers), host_shape));
+  input.emplace_back(xla::ExecutionInput(std::move(input_buffers->buffers),
+                                         shaped_buffer.on_host_shape()));
 
   // The buffers to be freed are in the `output` and will be automatically
   // freed when it goes out of the scope. In async mode, this means the buffers
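
Note the knock-on fix across the last two hunks: host_shape is now moved into ToShapedBuffer, so the ExecutionInput built later can no longer read it and takes the shape back from shaped_buffer.on_host_shape() instead. The same pitfall in miniature, using only standard types (everything here is illustrative):

    #include <iostream>
    #include <string>
    #include <utility>
    #include <vector>

    int main() {
      std::string host_shape = "f32[2,3]";
      std::vector<std::string> owner;
      owner.push_back(std::move(host_shape));  // host_shape is moved-from now
      // Read the value back from its new owner, not from the moved-from object.
      std::cout << owner.back() << "\n";
      return 0;
    }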

@@ -62,12 +62,12 @@ static bool tpu_cancellation_terminates_process = false;
 static bool tpu_cancellation_closes_chips = true;
 
 // Host-side runtime for transfers between TPU and host.
+// TODO(b/161940519): Implement this class.
 class HostTransferManager {
  public:
-  using HostCommmandHandler = xla::TpuExecutable::HostCommandHandler;
-
-  explicit HostTransferManager(TpuNodeContext* node_context)
-      : node_context_(node_context) {}
+  explicit HostTransferManager(TpuNodeContext*, xla::Backend*) {}
+
+  using HostCommmandHandler = xla::TpuExecutable::HostCommandHandler;
 
   // Returns a function to be called when the TPU triggers a host command
   // interrupt while executing the current program.
@@ -76,8 +76,6 @@ class HostTransferManager {
       const std::string& rendezvous_key_base, OpKernelContext* ctx);
 
  private:
-  TpuNodeContext* node_context_;  // not owned
-
   TF_DISALLOW_COPY_AND_ASSIGN(HostTransferManager);
 };
 
@@ -417,7 +415,9 @@ xla::StatusOr<xla::ExecutionOutput> TPUExecute(
   profiler::TraceMe traceme("TPUExecute", 2);
   TF_RET_CHECK(tpu::TpuPlatformInterface::GetRegisteredPlatform() != nullptr);
   TF_RET_CHECK(tpu_program != nullptr);
-  VLOG(1) << "TPUExecute on device " << node_context->tensor_core_location();
+  VLOG(1) << "TPUExecute on device " << node_context->device_ordinal();
+
+  xla::Backend* backend = node_context->backend();
 
   XlaDevice* device =
       tensorflow::down_cast<XlaDevice*>(ctx->device()->UnderlyingDevice());
@@ -425,19 +425,19 @@ xla::StatusOr<xla::ExecutionOutput> TPUExecute(
 
   // Create a HostTransferManager to handle Send/Recv operations from the TPU.
   std::shared_ptr<HostTransferManager> host_transfer_manager =
-      std::make_shared<HostTransferManager>(node_context);
+      std::make_shared<HostTransferManager>(node_context, backend);
   TF_ASSIGN_OR_RETURN(HostTransferManager::HostCommmandHandler handler,
                       host_transfer_manager->Initialize(
                           host_transfers, rendezvous_key_base, ctx));
 
   VLOG(2) << "Cloud TPU: Executing computation on device "
-          << node_context->index_on_host();
+          << node_context->device_ordinal();
 
   xla::ExecutableRunOptions run_options;
   run_options.set_stream(stream);
   run_options.set_device_assignment(device_assignment);
   run_options.set_rng_seed(rng_seed);
-  run_options.set_allocator(node_context->memory_allocator());
+  run_options.set_allocator(backend->memory_allocator());
   run_options.set_host_to_device_stream(host_to_device_stream);
 
   const xla::ServiceExecutableRunOptions service_run_options(run_options);
@@ -460,7 +460,7 @@ xla::StatusOr<xla::ExecutionOutput> TPUExecute(
   TF_ASSIGN_OR_RETURN(
       module->input_output_alias_config(),
       xla::HloInputOutputAliasConfig::CreateFromProto(
-          node_context->transfer_manager()->HostShapeToDeviceShape(
+          backend->transfer_manager()->HostShapeToDeviceShape(
               module->config().entry_computation_layout().result_shape()),
           hlo_metadata.hlo_module().input_output_alias()));
   TF_RET_CHECK(executable.input_shapes().size() == arguments.size());
@@ -471,11 +471,11 @@ xla::StatusOr<xla::ExecutionOutput> TPUExecute(
         xla::ShapeIndex(prefetch.index().begin(), prefetch.index().end()));
   }
 
-  TF_RETURN_IF_ERROR(UpdateDynamicInputs(
-      stream, node_context->memory_allocator(), &arguments, input_shapes));
+  TF_RETURN_IF_ERROR(UpdateDynamicInputs(stream, backend->memory_allocator(),
+                                         &arguments, input_shapes));
 
   auto tpu_executable = absl::make_unique<xla::TpuExecutable>(
-      tpu_program, std::move(module), handler);
+      tpu_program, std::move(module), /*host_command_handler=*/handler);
 
   const int32 device_ordinal = node_context->device_ordinal();
   CancellationToken token;

@@ -182,13 +182,12 @@ cc_library(
        ":tpu_executor_c_api_hdrs",
        ":tpu_node_context_c_api_hdrs",
        ":tpu_platform_interface",
-       ":tpu_transfer_manager_base",
        "//tensorflow/compiler/xla/service",
        "//tensorflow/compiler/xla/service:backend",
-       "//tensorflow/compiler/xla/service:platform_util",
        "//tensorflow/compiler/xla/service:stream_pool",
        "//tensorflow/compiler/xla/service:transfer_manager",
        "//tensorflow/core:framework",
+       "//tensorflow/core:lib",
        "//tensorflow/core/tpu:tpu_api",
        "//tensorflow/stream_executor:device_memory_allocator",
        "//tensorflow/stream_executor/lib",

@@ -12,13 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
 #include "tensorflow/stream_executor/tpu/tpu_node_context.h"
 
-#include "tensorflow/compiler/xla/service/backend.h"
-#include "tensorflow/compiler/xla/service/platform_util.h"
-#include "tensorflow/compiler/xla/service/transfer_manager.h"
 #include "tensorflow/core/tpu/tpu_api.h"
-#include "tensorflow/stream_executor/device_memory_allocator.h"
 #include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
 #include "tensorflow/stream_executor/tpu/tpu_node_context_c_api.h"
 
@@ -36,6 +33,8 @@ StatusOr<std::unique_ptr<TpuNodeContext>> TpuNodeContext::Create(
   tpu::NodeContextApiFn()->TpuNodeContext_CreateFn(device_ordinal,
                                                    status.c_status);
   if (!status.status().ok()) {
+    // TpuNodeContext_CreateFn allocates a new XLA_TpuNodeContext regardless of
+    // status. It needs to be freed if it's not given to a TpuNodeContext below.
     tpu::NodeContextApiFn()->TpuNodeContext_FreeFn(node_context);
     return status.status();
   }
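
The comment added above documents an ownership quirk of the C API: the create call allocates a handle even when it reports failure, so the error path must free it. A miniature analog with hypothetical names (not the real TpuNodeContext_* signatures):

    #include <memory>

    struct Handle { bool ok; };
    Handle* Api_Create(bool succeed) { return new Handle{succeed}; }  // always allocates
    void Api_Free(Handle* h) { delete h; }

    std::unique_ptr<Handle, void (*)(Handle*)> MakeHandle(bool succeed) {
      Handle* raw = Api_Create(succeed);
      if (!raw->ok) {
        Api_Free(raw);  // without this, the "failed" create would still leak
        return {nullptr, Api_Free};
      }
      return {raw, Api_Free};
    }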
@@ -46,13 +45,6 @@ TpuNodeContext::~TpuNodeContext() {
   tpu::NodeContextApiFn()->TpuNodeContext_FreeFn(node_context_);
 }
 
-/* static */
-Status TpuNodeContext::Initialize(int device_ordinal) {
-  StatusHelper status;
-  TpuNodeContext_Initialize(device_ordinal, status.c_status);
-  return status.status();
-}
-
 /* static */
 Status TpuNodeContext::StopChipHeartbeats() {
   StatusHelper status;
@@ -68,21 +60,20 @@ Status TpuNodeContext::CloseTpuHost() {
 }
 
 /* static */
-tensorflow::tpu::TpuPlatformInterface* TpuNodeContext::platform() {
+Status TpuNodeContext::Initialize(int device_ordinal) {
+  StatusHelper status;
+  TpuNodeContext_Initialize(device_ordinal, status.c_status);
+  return status.status();
+}
+
+/* static */
+TpuPlatformInterface* TpuNodeContext::platform() {
   return TpuPlatformInterface::GetRegisteredPlatform();
 }
 
-/* static */
-stream_executor::DeviceMemoryAllocator* TpuNodeContext::memory_allocator() {
-  static stream_executor::StreamExecutorMemoryAllocator* memory_allocator =
-      new stream_executor::StreamExecutorMemoryAllocator(
-          platform(),
-          xla::PlatformUtil::GetStreamExecutors(platform()).ValueOrDie());
-  return memory_allocator;
-}
+int TpuNodeContext::device_ordinal() const { return device_ordinal_; }
 
-/* static */
-xla::Backend* TpuNodeContext::backend() {
+xla::Backend* TpuNodeContext::backend() const {
   static xla::Backend* backend =
       xla::Backend::CreateBackend(
           xla::BackendOptions().set_platform(platform()))
@@ -91,21 +82,8 @@ xla::Backend* TpuNodeContext::backend() {
   return backend;
 }
 
-/* static */
-StatusOr<xla::StreamPool::Ptr> TpuNodeContext::BorrowStream(
-    int device_ordinal) {
-  return backend()->BorrowStream(device_ordinal);
-}
-
-/* static */
-StatusOr<xla::StreamPool::Ptr> TpuNodeContext::BorrowStream(
-    stream_executor::StreamExecutor* executor) {
-  return backend()->BorrowStream(executor);
-}
-
-/* static */
-xla::TransferManager* TpuNodeContext::transfer_manager() {
-  return xla::TransferManager::GetForPlatform(platform()).ValueOrDie();
+stream_executor::StreamExecutor* TpuNodeContext::stream_executor() const {
+  return backend()->stream_executor(device_ordinal_).ValueOrDie();
 }
 
 }  // namespace tpu
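
backend() (kept by this hunk, shown truncated above) builds the xla::Backend inside a function-local static, so it is constructed once per process with thread-safe initialization and deliberately never destroyed. A sketch of the idiom, assuming CreateBackend returns StatusOr<std::unique_ptr<Backend>> as in the XLA service API:

    // C++11 "magic static": initialization is thread-safe; releasing the
    // unique_ptr leaks the Backend on purpose so it outlives all users.
    xla::Backend* SharedTpuBackend() {
      static xla::Backend* backend =
          xla::Backend::CreateBackend(
              xla::BackendOptions().set_platform(
                  tensorflow::tpu::TpuPlatformInterface::GetRegisteredPlatform()))
              .ValueOrDie()
              .release();
      return backend;
    }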

@@ -23,6 +23,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/stream_pool.h"
 #include "tensorflow/compiler/xla/service/transfer_manager.h"
 #include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/platform/macros.h"
 #include "tensorflow/stream_executor/device_memory_allocator.h"
 #include "tensorflow/stream_executor/lib/status.h"
 #include "tensorflow/stream_executor/lib/statusor.h"
@@ -33,6 +34,11 @@ limitations under the License.
 namespace tensorflow {
 namespace tpu {
 
+// A TpuNodeContext object represents a specific TPU node (core). The static
+// class methods represent host-wide actions.
+//
+// First call Initialize in a freshly reset system. Then call Create to talk to
+// individual nodes.
 class TpuNodeContext final {
  public:
   using Status = stream_executor::port::Status;
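
Taking the new class comment at face value, a plausible startup sequence looks like the sketch below; the wrapper function and error plumbing are assumptions, not code from this commit:

    // Reset the system, initialize the device once, then create a per-node
    // context and pull per-node resources off it.
    Status StartNode(int device_ordinal) {
      TF_RETURN_IF_ERROR(
          tensorflow::tpu::TpuNodeContext::Initialize(device_ordinal));
      TF_ASSIGN_OR_RETURN(auto node_context,
                          tensorflow::tpu::TpuNodeContext::Create(device_ordinal));
      stream_executor::StreamExecutor* executor = node_context->stream_executor();
      (void)executor;
      return Status::OK();
    }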
@@ -47,41 +53,25 @@ class TpuNodeContext final {
   }
   ~TpuNodeContext();
 
-  TpuNodeContext(const TpuNodeContext&) = delete;
-  TpuNodeContext& operator=(const TpuNodeContext&) = delete;
-
-  static Status Initialize(int device_ordinal);
-
   static Status StopChipHeartbeats();
 
   static Status CloseTpuHost();
 
-  static tensorflow::tpu::TpuPlatformInterface* platform();
+  static Status Initialize(int device_ordinal);
 
-  static stream_executor::DeviceMemoryAllocator* memory_allocator();
+  static TpuPlatformInterface* platform();
 
-  static xla::TransferManager* transfer_manager();
+  int device_ordinal() const;
 
-  static xla::Backend* backend();
+  xla::Backend* backend() const;
 
-  static StatusOr<xla::StreamPool::Ptr> BorrowStream(int device_ordinal);
-
-  static StatusOr<xla::StreamPool::Ptr> BorrowStream(
-      stream_executor::StreamExecutor* executor);
-
-  stream_executor::StreamExecutor* stream_executor() {
-    LOG(FATAL) << "Not implemented yet.";
-  }
-
-  std::string tensor_core_location() { LOG(FATAL) << "Not implemented yet."; }
-
-  int index_on_host() { LOG(FATAL) << "Not implemented yet."; }
-
-  int device_ordinal() const { return device_ordinal_; }
+  stream_executor::StreamExecutor* stream_executor() const;
 
  private:
   const int device_ordinal_;
   XLA_TpuNodeContext* const node_context_;
 
+  TF_DISALLOW_COPY_AND_ASSIGN(TpuNodeContext);
 };
 
 }  // namespace tpu

@@ -26,19 +26,20 @@ XLA_TpuNodeContext* TpuNodeContext_Create(int device_ordinal,
                                           SE_Status* status);
 void TpuNodeContext_Free(XLA_TpuNodeContext* node_context);
 
-void TpuNodeContext_Initialize(int device_ordinal, SE_Status* status);
-
 void TpuNodeContext_StopChipHeartbeats(SE_Status* status);
 
 void TpuNodeContext_CloseTpuHost(SE_Status* status);
 
+void TpuNodeContext_Initialize(int device_ordinal, SE_Status* status);
+
 }  // extern "C"
 
 struct TfTpu_NodeContextApiFn {
   TFTPU_ADD_FN_IN_STRUCT(TpuNodeContext_Create);
   TFTPU_ADD_FN_IN_STRUCT(TpuNodeContext_Free);
-  TFTPU_ADD_FN_IN_STRUCT(TpuNodeContext_Initialize);
   TFTPU_ADD_FN_IN_STRUCT(TpuNodeContext_StopChipHeartbeats);
   TFTPU_ADD_FN_IN_STRUCT(TpuNodeContext_CloseTpuHost);
+  TFTPU_ADD_FN_IN_STRUCT(TpuNodeContext_Initialize);
 };
 
 #endif  // TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_NODE_CONTEXT_C_API_H_