Introduce TPUReshardVariables op to open source TensorFlow
PiperOrigin-RevId: 359089840
Change-Id: I7dd20d202e7881cb5c3ee0e5a52146d34e3df058
This commit is contained in:
parent 31a8ae453c
commit 7c4e09623d
@@ -35,6 +35,7 @@ tf_kernel_library(
":tpu_configuration_ops",
":tpu_execute_op",
":tpu_handle_to_key_op",
":tpu_reshard_variables_op",
":transfer_ops",
],
)
@@ -278,9 +279,11 @@ cc_library(
hdrs = ["tpu_program_group_interface.h"],
deps = [
":tpu_compilation_cache_key",
":tpu_executable_info_proto_cc",
"//tensorflow/compiler/tf2xla:host_compute_metadata_proto_cc",
"//tensorflow/compiler/xla/service:hlo_proto_cc",
"//tensorflow/core/lib/core:status",
"//tensorflow/core/tpu:tpu_ops_c_api_hdrs",
"@com_google_absl//absl/time",
"@com_google_absl//absl/types:span",
],
@@ -883,3 +886,67 @@ cc_library(
"//tensorflow/core:framework",
],
)

cc_library(
name = "tpu_reshard_variables_op",
srcs = ["tpu_reshard_variables_op.cc"],
hdrs = ["tpu_reshard_variables_op.h"],
deps = [
":tpu_compilation_cache_common_proto_cc",
":tpu_compilation_cache_lookup",
":tpu_op_consts",
":tpu_program_group",
":tpu_reshard_variables_op_util",
"//tensorflow/compiler/jit:xla_device_no_jit_rewrite_registration",
"//tensorflow/compiler/jit:xla_launch_util",
"//tensorflow/compiler/jit:xla_tensor",
"//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/tf2xla:tf2xla_util",
"//tensorflow/compiler/xla/service:maybe_owning_device_memory",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core/profiler/lib:traceme",
"//tensorflow/core/tpu:tpu_configuration",
"//tensorflow/core/tpu:tpu_defs",
"//tensorflow/core/tpu:tpu_execute",
"//tensorflow/stream_executor:device_memory_allocator",
"//tensorflow/stream_executor/tpu:tpu_executor_hdrs",
"//tensorflow/stream_executor/tpu:tpu_executor_interface",
"//tensorflow/stream_executor/tpu:tpu_node_context",
],
alwayslink = 1,
)

cc_library(
name = "tpu_reshard_variables_op_util",
srcs = ["tpu_reshard_variables_op_util.cc"],
hdrs = ["tpu_reshard_variables_op_util.h"],
deps = [
":tpu_compilation_cache_common_proto_cc",
":tpu_compilation_cache_interface",
":tpu_compilation_cache_lookup",
":tpu_op_consts",
":tpu_program_group",
"//tensorflow/compiler/jit:xla_device_no_jit_rewrite_registration",
"//tensorflow/compiler/jit:xla_launch_util",
"//tensorflow/compiler/jit:xla_tensor",
"//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/tf2xla:tf2xla_util",
"//tensorflow/compiler/xla/service:maybe_owning_device_memory",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core/profiler/lib:traceme",
"//tensorflow/core/tpu:tpu_configuration",
"//tensorflow/core/tpu:tpu_defs",
"//tensorflow/core/tpu:tpu_execute",
"//tensorflow/stream_executor:device_memory_allocator",
"//tensorflow/stream_executor/tpu:tpu_executor_hdrs",
"//tensorflow/stream_executor/tpu:tpu_executor_interface",
"//tensorflow/stream_executor/tpu:tpu_node_context",
],
alwayslink = 1,
)
@@ -127,12 +127,12 @@ class TpuProgramGroup : public TpuProgramGroupInterface {

const std::vector<XLA_TpuProgram*>& tpu_programs() const;
std::vector<XLA_TpuProgram*> tpu_programs(TpuProgramShardingType type) const;
const XLA_TpuProgram* tpu_program(int index) const;
const XLA_TpuProgram* tpu_program(int index) const override;
void set_tpu_programs(absl::Span<XLA_TpuProgram* const> tpu_programs);

const TPUExecutableInfoProto& executable_info(int index) const;
const TPUExecutableInfoProto& executable_info(int index) const override;

const TPUHostTransferInfoProto& host_transfer_info(int index) const;
const TPUHostTransferInfoProto& host_transfer_info(int index) const override;
void set_hlo_metadatas(absl::Span<const xla::HloProto> hlo_metadatas);
const xla::HloProto* hlo_metadata(int index) const;
absl::Span<const xla::HloProto* const> hlo_metadatas() const override;
@@ -26,6 +26,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_key.h"
#include "tensorflow/core/tpu/kernels/tpu_executable_info.pb.h"
#include "tensorflow/core/tpu/tpu_ops_c_api.h"

namespace tensorflow {
namespace tpu {
@@ -66,6 +68,16 @@ class TpuProgramGroupInterface {
// Returns whether the TPU program for the given core `index` may modify
// variable values.
virtual bool may_modify_variables(int index) const = 0;

// Gets the executable info proto for the given core `index`.
virtual const TPUExecutableInfoProto& executable_info(int index) const = 0;

// Gets the host transfer info proto for the given core `index`.
virtual const TPUHostTransferInfoProto& host_transfer_info(
int index) const = 0;

// Gets the XLA_TpuProgram for the given core `index`.
virtual const XLA_TpuProgram* tpu_program(int index) const = 0;
};

} // namespace tpu
tensorflow/core/tpu/kernels/tpu_reshard_variables_op.cc (new file, 283 lines)
@@ -0,0 +1,283 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/tpu/kernels/tpu_reshard_variables_op.h"

#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/compiler/jit/xla_launch_util.h"
#include "tensorflow/compiler/jit/xla_tensor.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/resource_var.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.h"
#include "tensorflow/core/tpu/kernels/tpu_op_consts.h"
#include "tensorflow/core/tpu/kernels/tpu_program_group.h"
#include "tensorflow/core/tpu/kernels/tpu_reshard_variables_op_util.h"
#include "tensorflow/core/tpu/tpu_configuration.h"
#include "tensorflow/core/tpu/tpu_defs.h"
#include "tensorflow/core/tpu/tpu_execute.h"
#include "tensorflow/core/util/stream_executor_util.h"
#include "tensorflow/stream_executor/device_memory_allocator.h"
#include "tensorflow/stream_executor/tpu/tpu_executor.h"
#include "tensorflow/stream_executor/tpu/tpu_executor_interface.h"
#include "tensorflow/stream_executor/tpu/tpu_node_context.h"

namespace tensorflow {

namespace reshard_util = ::tensorflow::tpu::reshard_variables;

TPUReshardVariablesOpKernel::TPUReshardVariablesOpKernel(
OpKernelConstruction* context)
: AsyncOpKernel(context, /* is_deferred = */ true) {
OP_REQUIRES_OK(context, context->GetAttr("N", &num_vars_));
}

void TPUReshardVariablesOpKernel::ComputeAsync(OpKernelContext* context,
DoneCallback done) {
// If TPU launches are asynchronous, then perform the launch on this thread
// to avoid a thread hop, which has an observable latency cost.
OP_REQUIRES_OK_ASYNC(context, DoWork(context), done);
done();
}

Status TPUReshardVariablesOpKernel::DoWork(OpKernelContext* context) {
VLOG(1) << "Cloud TPU: TPUReshardVariablesOpKernel::DoWork";
TF_RET_CHECK(context->input_dtype(num_vars_) == DT_STRING);
const Tensor* new_format_key;
TF_RETURN_IF_ERROR(context->input("new_format_key", &new_format_key));
TF_RETURN_IF_ERROR(reshard_util::CheckIsValidKey(*new_format_key));

TF_RET_CHECK(context->input_dtype(num_vars_ + 1) == DT_RESOURCE);
const ResourceHandle& handle = HandleFromInput(context, num_vars_ + 1);
Var* format_state_var;
TF_RETURN_IF_ERROR(LookupOrCreateResource<Var>(
context, handle, &format_state_var, [new_format_key](Var** ptr) {
*ptr = new Var(new_format_key->dtype());
return Status::OK();
}));
mutex_lock ml(*format_state_var->mu());
const bool initialized = format_state_var->is_initialized;
if (initialized) {
TF_RETURN_IF_ERROR(
reshard_util::CheckIsValidKey(*format_state_var->tensor()));
}

const bool state_is_default =
!initialized || reshard_util::IsDefaultKey(*format_state_var->tensor());
const bool new_format_is_default =
reshard_util::IsDefaultKey(*new_format_key);

if ((state_is_default && new_format_is_default) ||
(initialized && format_state_var->tensor()->vec<tstring>()(2) ==
new_format_key->vec<tstring>()(2))) {
VLOG(1) << "Sharding unchanged, nothing to do.";
return Status::OK();
}

if (!state_is_default) {
// Convert the current format to default (unsharded).
VLOG(1) << "Unsharding with key: "
<< format_state_var->tensor()->vec<tstring>()(2);
TF_RETURN_IF_ERROR(
DoTpuExecute(context, *format_state_var->tensor(),
tpu::CompilationCacheFetchTarget::UNSHARDING));
}

if (!new_format_is_default) {
// Convert the new format.
VLOG(1) << "Sharding with key: " << new_format_key->vec<tstring>()(2);
TF_RETURN_IF_ERROR(DoTpuExecute(
context, *new_format_key, tpu::CompilationCacheFetchTarget::SHARDING));
}

// Change the state.
*format_state_var->tensor() = *new_format_key;
format_state_var->is_initialized = true;
return Status::OK();
}

Status TPUReshardVariablesOpKernel::DoTpuExecute(
OpKernelContext* context, const Tensor& format_key,
tpu::CompilationCacheFetchTarget fetch_target) {
const XlaDevice::Metadata* metadata;
TF_RETURN_IF_ERROR(XlaDevice::GetMetadata(context, &metadata));
const int device_ordinal = metadata->device_ordinal();

// We are guaranteed that the underlying object won't be deleted out from
// under us
TF_ASSIGN_OR_RETURN(std::unique_ptr<tpu::TpuNodeContext> node_interfaces,
tpu::TpuNodeContext::Create(device_ordinal));

profiler::TraceMe trace_me(
[device_ordinal] {
return profiler::TraceMeEncode("TPUReshardVariablesOpKernel",
{{"device_ordinal", device_ordinal}});
},
/*level=*/2);
profiler::TraceMe trace_me_init("TPUReshardVariablesOpKernel::Init",
/*level=*/2);

string rendezvous_key_base;
std::unique_ptr<tpu::CompilationCacheEntryRef> entry_ref;
TF_RETURN_IF_ERROR(reshard_util::GetComputationCacheEntry(
format_key, &rendezvous_key_base, &entry_ref, fetch_target));
tpu::TpuCompilationCacheEntry entry = entry_ref->get();
if (entry.tpu_program_group() == nullptr) {
VLOG(2) << "Sharding/unsharding program does not exist, so this is default "
"sharding.";
return Status::OK();
}

const tpu::TpuProgramGroupInterface* tpu_program_group =
entry.tpu_program_group();
const int core_index = entry.core_index();
const TPUExecutableInfoProto& executable_info_proto =
tpu_program_group->executable_info(core_index);
const TPUExecutableInfoProto* executable = &executable_info_proto;

xla::Backend* const backend = node_interfaces->backend();
xla::TransferManager* const transfer_manager = backend->transfer_manager();

CHECK(context->op_device_context());
se::Stream* stream = context->op_device_context()->stream();

TF_RET_CHECK(executable->input_shapes_size() == 1);
xla::Shape host_shape(executable->input_shapes(0));
std::vector<VariableInfo> variables;
for (int i = 0; i < num_vars_; ++i) {
TF_RET_CHECK(context->input_dtype(i) == DT_RESOURCE);
const ResourceHandle& handle = HandleFromInput(context, i);
Var* variable;
TF_RETURN_IF_ERROR(LookupResource(context, handle, &variable));
variables.push_back(VariableInfo(i, handle.name(), variable));
}

// Block for previous TPUExecute ops so that the memory used for them could be
// freed.
TF_RETURN_IF_ERROR(stream->BlockHostUntilDone());
// Lock variables to prevent concurrent access.
TF_RETURN_IF_ERROR(LockVariables(absl::MakeSpan(variables)));

// Build input buffers.
TF_ASSIGN_OR_RETURN(auto input_buffers, reshard_util::BuildInputBuffers(
context, variables, host_shape,
backend, device_ordinal, stream));
xla::ShapedBuffer shaped_buffer(std::move(host_shape), input_buffers.shape(),
device_ordinal);
shaped_buffer.set_buffers(input_buffers.Map<se::DeviceMemoryBase>(
[](xla::MaybeOwningDeviceMemory* buffer) {
CHECK(buffer);
return buffer->AsDeviceMemoryBase();
}));

// Write input root tuple.
TF_ASSIGN_OR_RETURN(auto transfer_stream_ptr,
backend->BorrowStream(device_ordinal));
if (transfer_manager->CanShapedBufferBeAccessedNow(stream->parent(),
shaped_buffer)) {
TF_RETURN_IF_ERROR(transfer_manager->WriteRootTupleIndexTable(
transfer_stream_ptr.get(), shaped_buffer));
stream->ThenWaitFor(transfer_stream_ptr.get());
} else {
TF_RETURN_IF_ERROR(
transfer_manager->WriteRootTupleIndexTable(stream, shaped_buffer));
}
VLOG(4) << "Input buffers: " << shaped_buffer.ToString();

TF_RET_CHECK(!executable->has_session_module())
<< "session module not supported in sharding/unsharding program.";

auto definition_event = std::make_shared<se::Event>(stream->parent());
TF_RET_CHECK(definition_event->Init())
<< "TPU definition event initialization failed";

trace_me_init.Stop();

// Execute the program.
std::unique_ptr<xla::DeviceAssignment> device_assignment;
if (executable->has_device_assignment()) {
TF_ASSIGN_OR_RETURN(
device_assignment,
xla::DeviceAssignment::Deserialize(executable->device_assignment()));
}
std::vector<xla::ExecutionInput> input;
input.emplace_back(xla::ExecutionInput(std::move(input_buffers),
shaped_buffer.on_host_shape()));

const TPUHostTransferInfoProto& host_transfer_info =
tpu_program_group->host_transfer_info(core_index);

TF_ASSIGN_OR_RETURN(
xla::ExecutionOutput output,
TPUExecute(*executable, host_transfer_info,
*tpu_program_group->hlo_metadatas()[core_index],
std::move(input), rendezvous_key_base, GetXLARandomSeed(),
node_interfaces.get(), device_assignment.get(),
context->cancellation_manager(), context, stream,
transfer_stream_ptr.get(),
tpu_program_group->tpu_program(core_index)));

stream->ThenRecordEvent(definition_event.get());

// Assign the new buffers to the variables.
xla::ScopedShapedBuffer result = output.ConsumeResult();

// Only perform compaction when sharding.
// NOTE: Compaction is not supported on some TPUs, see b/168322060 for details
if (node_interfaces->CompactionSupported(device_ordinal) &&
fetch_target == tpu::CompilationCacheFetchTarget::SHARDING) {
// Block until program execution is done so that input, output, and program
// cache memory can be actually released.
TF_RETURN_IF_ERROR(transfer_stream_ptr->BlockHostUntilDone());
TF_RETURN_IF_ERROR(stream->BlockHostUntilDone());
{
// Explicitly release any RAII objects owning on-device allocations.
auto unused = output.ConsumeToBeReleased();
}
// Release variables holding inputs.
for (int i = 0; i < variables.size(); ++i) {
*variables[i].var()->tensor() = Tensor();
}
// Flush on-device program memory cache.
TF_RETURN_IF_ERROR(
reshard_util::FlushProgramMemory(backend->platform(), device_ordinal));
TF_RETURN_IF_ERROR(reshard_util::PerformCompaction(stream));
}
return reshard_util::UpdateOutputVariables(
context, std::move(result), executable->output_tensor_shapes(), backend,
stream, device_ordinal, variables, definition_event);
}

TPUReshardVariablesOpKernel::~TPUReshardVariablesOpKernel() = default;

#if !defined(PLATFORM_GOOGLE)
REGISTER_KERNEL_BUILDER(Name("TPUReshardVariables")
.Device(DEVICE_TPU_NODE)
.HostMemory("format_state_var")
.HostMemory("new_format_key"),
TPUReshardVariablesOpKernel);
#endif

} // namespace tensorflow
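
Aside (not part of the commit): the kernel above reduces to a small state machine over sharding keys. The standalone sketch below restates DoWork's decision logic with plain strings in place of TF tensors; the names and key values are illustrative, and it assumes an IsDefaultKey helper like the one in tpu_reshard_variables_op_util.

// Illustrative restatement of TPUReshardVariablesOpKernel::DoWork's control
// flow, using plain strings instead of TF tensors. Not part of the commit.
#include <iostream>
#include <string>

// A key is "default" (unsharded) when its identifier is empty, mirroring
// reshard_variables::IsDefaultKey, which checks element 0 of the key tensor.
bool IsDefaultKey(const std::string& key) { return key.empty(); }

// Prints the actions the kernel would take to move from `current` to `target`.
void Reshard(const std::string& current, const std::string& target) {
  const bool state_is_default = IsDefaultKey(current);
  const bool target_is_default = IsDefaultKey(target);
  if ((state_is_default && target_is_default) || current == target) {
    std::cout << "Sharding unchanged, nothing to do.\n";
    return;
  }
  if (!state_is_default) {
    std::cout << "Run UNSHARDING program for key " << current << "\n";
  }
  if (!target_is_default) {
    std::cout << "Run SHARDING program for key " << target << "\n";
  }
  std::cout << "Record " << target << " as the new sharding state.\n";
}

int main() {
  Reshard("", "key_A");       // shard into layout A
  Reshard("key_A", "key_B");  // unshard from A, then shard into B
  Reshard("key_B", "");       // unshard back to the default layout
  return 0;
}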
tensorflow/core/tpu/kernels/tpu_reshard_variables_op.h (new file, 52 lines)
@@ -0,0 +1,52 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_RESHARD_VARIABLES_OP_H_
#define TENSORFLOW_CORE_TPU_KERNELS_TPU_RESHARD_VARIABLES_OP_H_

#include <memory>

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_common.pb.h"

namespace tensorflow {

// Op that changes the sharding state for a set of on-device variables. The
// sharding state is represented as the key of the compilation that generated
// the sharding/unsharding programs along with the main program. The op checks
// if the current sharding state matches the desired one, and if not, uses the
// sharding/unsharding programs to transform the variables to the desired state.
class TPUReshardVariablesOpKernel : public AsyncOpKernel {
public:
explicit TPUReshardVariablesOpKernel(OpKernelConstruction* context);
~TPUReshardVariablesOpKernel() override;

void ComputeAsync(OpKernelContext* context, DoneCallback done) override;

private:
Status DoWork(OpKernelContext* context);
Status DoTpuExecute(OpKernelContext* context, const Tensor& format_key,
tpu::CompilationCacheFetchTarget fetch_target);

int64 num_vars_;

DISALLOW_COPY_AND_ASSIGN(TPUReshardVariablesOpKernel);
};

} // namespace tensorflow

#endif  // TENSORFLOW_CORE_TPU_KERNELS_TPU_RESHARD_VARIABLES_OP_H_
tensorflow/core/tpu/kernels/tpu_reshard_variables_op_util.cc (new file, 314 lines)
@@ -0,0 +1,314 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/tpu/kernels/tpu_reshard_variables_op_util.h"

#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/compiler/jit/xla_launch_util.h"
#include "tensorflow/compiler/jit/xla_tensor.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/resource_var.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.h"
#include "tensorflow/core/tpu/kernels/tpu_op_consts.h"
#include "tensorflow/core/tpu/kernels/tpu_program_group.h"
#include "tensorflow/core/tpu/tpu_configuration.h"
#include "tensorflow/core/tpu/tpu_defs.h"
#include "tensorflow/core/tpu/tpu_execute.h"
#include "tensorflow/core/util/stream_executor_util.h"
#include "tensorflow/stream_executor/device_memory_allocator.h"
#include "tensorflow/stream_executor/tpu/tpu_executor.h"
#include "tensorflow/stream_executor/tpu/tpu_executor_interface.h"
#include "tensorflow/stream_executor/tpu/tpu_node_context.h"

namespace tensorflow {
namespace tpu {
namespace reshard_variables {

Status FlushProgramMemory(se::Platform* platform, int device_ordinal) {
TF_ASSIGN_OR_RETURN(std::unique_ptr<tpu::TpuNodeContext> node_interfaces,
tpu::TpuNodeContext::Create(device_ordinal));

auto* executor = tensorflow::down_cast<tpu::TpuExecutorInterface*>(
node_interfaces->stream_executor()->implementation());
return executor->UnloadAllPrograms();
}

Status CheckIsValidKey(const Tensor& key) {
if (!TensorShapeUtils::IsVector(key.shape()) ||
key.shape().dim_size(0) != 3) {
return errors::InvalidArgument(
"new_format_key argument to TPUReshardVariables must be a 3-element "
"vector");
}
if (key.dtype() != DT_STRING) {
return errors::InvalidArgument(
"new_format_key argument to TPUReshardVariables must be DT_STRING "
"type");
}
return Status::OK();
}

bool IsDefaultKey(const Tensor& key) { return key.vec<tstring>()(0).empty(); }

// Looks up the input `key` in the compilation cache, populating
// `*rendezvous_key_base` and `*entry`.
Status GetComputationCacheEntry(
const Tensor& key, string* rendezvous_key_base,
std::unique_ptr<tpu::CompilationCacheEntryRef>* entry,
tpu::CompilationCacheFetchTarget fetch_target) {
profiler::TraceMe trace_me("TPUReshardVariablesOpKernel::LookupProto",
/*level=*/2);
TF_RETURN_IF_ERROR(CheckIsValidKey(key));
auto* rmgr = GetTPUConfigResourceMgr();
tpu::TpuCompilationCacheLookup* proto_lookup;
TF_RETURN_IF_ERROR(rmgr->Lookup(rmgr->default_container(),
tpu::kCompiledProtoCacheResourceName,
&proto_lookup));
core::ScopedUnref lookup_unref(proto_lookup);
TF_RETURN_IF_ERROR(
proto_lookup->Lookup(key.vec<tstring>()(0), entry, fetch_target));
*rendezvous_key_base = key.vec<tstring>()(1);
return Status::OK();
}

// Builds an InputBuffers object that describes the inputs to the computation.
xla::StatusOr<xla::ShapeTree<xla::MaybeOwningDeviceMemory>> BuildInputBuffers(
OpKernelContext* context, const std::vector<VariableInfo>& variables,
const xla::Shape& input_host_shape, xla::Backend* backend,
int device_ordinal, se::Stream* stream) {
profiler::TraceMe trace_me("BuildComputationInputs", /*level=*/2);
OpInputList var_list;
TF_RETURN_IF_ERROR(context->input_list("vars", &var_list));

if (var_list.size() != xla::ShapeUtil::TupleElementCount(input_host_shape)) {
return errors::InvalidArgument(
"Number of variables (", var_list.size(),
") does not match input shape: ",
xla::ShapeUtil::TupleElementCount(input_host_shape));
}

auto validate_shape = [&](int i, const Tensor& tensor) {
const xla::Shape& expected =
xla::ShapeUtil::GetTupleElementShape(input_host_shape, i);
VLOG(4) << "Input " << i << " TF shape " << tensor.shape().DebugString();
XlaTensor* xla_tensor = XlaTensor::FromTensor(&tensor);

if (xla_tensor == nullptr) {
// FromTensor failed; tensor must be empty.
if (!xla::ShapeUtil::IsZeroElementArray(expected)) {
return errors::InvalidArgument(
"Run-time shape mismatch for TPUExecute argument[", i, "] (",
context->op_kernel().requested_input(i), "). Expected ",
expected.DebugString(),
"; got empty tensor. If you are running "
"with TF2 TPU, make sure you set `drop_remainder=False` when "
"calling `dataset.batch` on the `tf.data.Dataset` so dynamic batch "
"size can be handled");
}
} else {
const xla::Shape& xla_shape = xla_tensor->shaped_buffer().on_host_shape();
if (!xla::ShapeUtil::Compatible(expected, xla_shape)) {
return errors::InvalidArgument(
"Run-time shape mismatch for TPUReshardVariables argument[", i,
"] (", context->op_kernel().requested_input(i), "). Expected ",
expected.DebugString(), "; got ", xla_shape.DebugString());
}
}

return Status::OK();
};

for (int i = 0; i < variables.size(); ++i) {
TF_RETURN_IF_ERROR(
validate_shape(variables[i].index(), *variables[i].var()->tensor()));
}

se::DeviceMemoryAllocator* const allocator = backend->memory_allocator();
xla::TransferManager* const transfer_manager = backend->transfer_manager();

xla::ShapeTree<xla::MaybeOwningDeviceMemory> input_buffers(
transfer_manager->HostShapeToDeviceShape(input_host_shape));

// Allocates a buffer for the root tuple.
const int64 root_size =
transfer_manager->GetByteSizeRequirement(input_buffers.shape());
TF_ASSIGN_OR_RETURN(*input_buffers.mutable_element({}),
allocator->Allocate(device_ordinal, root_size));

auto set_input_buffers_helper = [&](int arg_index, xla::ShapedBuffer* buffers,
bool owning = false) {
buffers->buffers().ForEachMutableElement(
[&](const xla::ShapeIndex& index, se::DeviceMemoryBase* buffer) {
xla::ShapeIndex in_index = {arg_index};
for (int64 j : index) {
in_index.push_back(j);
}
if (owning) {
*input_buffers.mutable_element(in_index) =
se::OwningDeviceMemory(*buffer, device_ordinal, allocator);
*buffer = se::DeviceMemoryBase();
} else {
*input_buffers.mutable_element(in_index) = *buffer;
}
});
};

// Assigns the buffers of 'tensor' as computation input 'i'. Allocates fresh
// buffers for zero-element tensors where required.
auto assign_input = [&](int i, const Tensor& tensor) -> xla::Status {
XlaTensor* xla_tensor = XlaTensor::FromTensor(&tensor);

// Size 0 tensors have no backing XlaTensor, but may still need to have
// tuple buffers allocated.
if (xla_tensor == nullptr) {
CHECK_EQ(tensor.NumElements(), 0);
const xla::Shape& host_shape =
xla::ShapeUtil::GetSubshape(input_host_shape, {i});
TF_ASSIGN_OR_RETURN(xla::ScopedShapedBuffer buffers,
transfer_manager->AllocateScopedShapedBuffer(
host_shape, allocator, device_ordinal));
set_input_buffers_helper(/*arg_index=*/i, &buffers);
} else {
set_input_buffers_helper(/*arg_index=*/i, &xla_tensor->shaped_buffer(),
tensor.RefCountIsOne());
xla_tensor->WaitForDefinitionEventOnStream(stream);
}
return Status::OK();
};

for (int i = 0; i < var_list.size(); ++i) {
TF_RET_CHECK(var_list[i].dtype() == DT_RESOURCE);
TF_RETURN_IF_ERROR(assign_input(i, *variables[i].var()->tensor()));
}

return std::move(input_buffers);
}

// Perform a compaction to reduce fragmentation.
Status PerformCompaction(stream_executor::Stream* stream) {
profiler::TraceMe trace_me("PerformCompaction", /*level=*/2);
auto* ds_executor =
down_cast<tpu::TpuExecutorInterface*>(stream->parent()->implementation());
TF_RETURN_IF_ERROR(ds_executor->EnqueueCompactionOnStreamForHbm(stream));
// LoadProgram and GetOrCreateConstantHandle are not managed by stream
// dependencies but they write to shared memory, so we need to block here to
// prevent those operations from racing.
return stream->BlockHostUntilDone();
}

// Updates the variables to the execution result's buffers, and deallocates the
// root tuple buffer.
Status UpdateOutputVariables(
OpKernelContext* context, xla::ScopedShapedBuffer result_buffers,
absl::Span<const TensorShapeProto* const> output_tensor_shape_protos,
xla::Backend* backend, se::Stream* stream, int device_ordinal,
const std::vector<VariableInfo>& variables,
const std::shared_ptr<se::Event>& definition_event) {
profiler::TraceMe trace_me("UpdateOutputVariables", /*level=*/2);
// Shapes of the outputs, in TensorShape form.
const int64 sub_elements =
xla::ShapeUtil::TupleElementCount(result_buffers.on_host_shape());
if (sub_elements != output_tensor_shape_protos.size()) {
return errors::InvalidArgument(
"Mismatched numbers of output shapes: ", sub_elements, " vs. ",
output_tensor_shape_protos.size());
}

if (sub_elements != variables.size()) {
return errors::InvalidArgument(
"Output count does not equal variable count: ", sub_elements, " vs. ",
variables.size());
}

std::vector<TensorShape> output_tensor_shapes;
output_tensor_shapes.reserve(sub_elements);
for (int64 i = 0; i < sub_elements; ++i) {
TF_RETURN_IF_ERROR(
TensorShape::IsValidShape(*output_tensor_shape_protos[i]));
TensorShape shape(*output_tensor_shape_protos[i]);
const xla::Shape& xla_shape =
xla::ShapeUtil::GetSubshape(result_buffers.on_host_shape(), {i});
if (!xla_shape.IsArray() ||
xla::ShapeUtil::ElementsIn(xla_shape) != shape.num_elements()) {
return errors::InvalidArgument(
"Mismatched number of elements in output shape: ",
xla::ShapeUtil::HumanString(xla_shape), " vs ", shape.DebugString());
}
output_tensor_shapes.push_back(shape);
VLOG(2) << "Output " << i << " shape " << shape.DebugString();
}

// Build a shaped buffer for the outputs.
TF_RET_CHECK(result_buffers.on_host_shape().IsTuple());
TF_RET_CHECK(!xla::ShapeUtil::IsNestedTuple(result_buffers.on_host_shape()));

se::DeviceMemoryAllocator* const allocator = backend->memory_allocator();

auto output_buffers = result_buffers.release();
const xla::Shape& output_host_shape = output_buffers.on_host_shape();
const xla::Shape& output_device_shape = output_buffers.on_device_shape();

// Transfers ownership of the buffers that back XLA computation output 'i'
// to 'output_tensor'.
auto transfer_buffers = [&](int i, Tensor* output_tensor) {
const xla::Shape& host_shape =
xla::ShapeUtil::GetTupleElementShape(output_host_shape, i);
const xla::Shape& device_shape =
xla::ShapeUtil::GetTupleElementShape(output_device_shape, i);
if (output_tensor->NumElements() > 0) {
xla::ScopedShapedBuffer shaped_buffer(host_shape, device_shape, allocator,
device_ordinal);
shaped_buffer.buffers().ForEachMutableElement(
[&](const xla::ShapeIndex& index, se::DeviceMemoryBase* buffer) {
xla::ShapeIndex out_index = {i};
for (int64 j : index) {
out_index.push_back(j);
}
*buffer = output_buffers.buffers().element(out_index);
});

XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor);
xla_tensor->set_shaped_buffer(std::move(shaped_buffer));
xla_tensor->ResetDefinitionEvent(definition_event, stream);
}
};

for (int i = 0; i < variables.size(); ++i) {
PersistentTensor unused;
Tensor* output_tensor;
TF_RETURN_IF_ERROR(context->allocate_persistent(
variables[i].var()->tensor()->dtype(), output_tensor_shapes[i], &unused,
&output_tensor));
*variables[i].var()->tensor() = *output_tensor;
transfer_buffers(i, output_tensor);
}
return allocator->Deallocate(output_buffers.device_ordinal(),
output_buffers.buffer({}));
}

} // namespace reshard_variables
} // namespace tpu
} // namespace tensorflow
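
Aside (not part of the commit): as CheckIsValidKey, GetComputationCacheEntry, and the kernel's DoWork use it, a format key is a 3-element DT_STRING vector: element 0 is the compilation-cache key (empty means the default, unsharded state), element 1 is the rendezvous key base, and element 2 is the sharding identifier compared against the stored state. The sketch below builds such a tensor; the string values are invented, since real keys are produced by the TPU compilation cache, and it assumes linking against TensorFlow core.

// Sketch only: builds a "format key" tensor with the layout expected by
// reshard_variables::CheckIsValidKey. Values are illustrative placeholders.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"

namespace tensorflow {

Tensor MakeExampleFormatKey() {
  Tensor key(DT_STRING, TensorShape({3}));
  key.vec<tstring>()(0) = "compilation_cache_key";  // looked up in the proto cache
  key.vec<tstring>()(1) = "rendezvous_key_base";    // passed through to TPUExecute
  key.vec<tstring>()(2) = "sharding_identifier";    // compared against the stored state
  return key;  // an empty element 0 would mean the default (unsharded) state
}

}  // namespace tensorflow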
tensorflow/core/tpu/kernels/tpu_reshard_variables_op_util.h (new file, 61 lines)
@@ -0,0 +1,61 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_RESHARD_VARIABLES_OP_UTIL_H_
#define TENSORFLOW_CORE_TPU_KERNELS_TPU_RESHARD_VARIABLES_OP_UTIL_H_

#include <memory>

#include "tensorflow/compiler/jit/xla_launch_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_common.pb.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h"

namespace tensorflow {
namespace tpu {
namespace reshard_variables {

Status FlushProgramMemory(se::Platform* platform, int device_ordinal);

Status CheckIsValidKey(const Tensor& key);

bool IsDefaultKey(const Tensor& key);

Status GetComputationCacheEntry(
const Tensor& key, string* rendezvous_key_base,
std::unique_ptr<tpu::CompilationCacheEntryRef>* entry,
tpu::CompilationCacheFetchTarget fetch_target);

xla::StatusOr<xla::ShapeTree<xla::MaybeOwningDeviceMemory>> BuildInputBuffers(
OpKernelContext* context, const std::vector<VariableInfo>& variables,
const xla::Shape& input_host_shape, xla::Backend* backend,
int device_ordinal, se::Stream* stream);

Status PerformCompaction(stream_executor::Stream* stream);

Status UpdateOutputVariables(
OpKernelContext* context, xla::ScopedShapedBuffer result_buffers,
absl::Span<const TensorShapeProto* const> output_tensor_shape_protos,
xla::Backend* backend, se::Stream* stream, int device_ordinal,
const std::vector<VariableInfo>& variables,
const std::shared_ptr<se::Event>& definition_event);

} // namespace reshard_variables
} // namespace tpu
} // namespace tensorflow

#endif  // TENSORFLOW_CORE_TPU_KERNELS_TPU_RESHARD_VARIABLES_OP_UTIL_H_
@@ -426,6 +426,8 @@ void TpuNodeContext_CloseTpuHost(TF_Status* status);

void TpuNodeContext_Initialize(int device_ordinal, TF_Status* status);

bool TpuNodeContext_CompactionSupported(int device_ordinal);

// Globally initialize the TPU system for inference.
TFTPU_CAPI_EXPORT void TfTpu_InitializeTpuModelServer();

@@ -495,6 +497,7 @@ struct TfTpu_OpsApiFn {
TFTPU_ADD_FN_IN_STRUCT(TpuNodeContext_StopChipHeartbeats);
TFTPU_ADD_FN_IN_STRUCT(TpuNodeContext_CloseTpuHost);
TFTPU_ADD_FN_IN_STRUCT(TpuNodeContext_Initialize);
TFTPU_ADD_FN_IN_STRUCT(TpuNodeContext_CompactionSupported);

TFTPU_ADD_FN_IN_STRUCT(TfTpu_InitializeTpuModelServer);
};
@@ -327,6 +327,21 @@ bool TpuExecutor::MemcpyDeviceToDevice(
LOG(FATAL) << __func__ << " not supported on TpuExecutor";
}

Status TpuExecutor::UnloadAllPrograms() {
StatusHelper status;
tpu::ExecutorApiFn()->TpuExecutor_UnloadAllProgramsFn(executor_,
status.c_status);
return status.status();
}

Status TpuExecutor::EnqueueCompactionOnStreamForHbm(Stream* compaction_stream) {
StatusHelper status;
tpu::ExecutorApiFn()->TpuExecutor_EnqueueCompactionOnStreamForHbmFn(
executor_, get_stream(compaction_stream->implementation()),
status.c_status);
return status.status();
}

struct HostCallbackContext {
std::function<Status()> callback;
};
@@ -155,6 +155,10 @@ class TpuExecutor : public tensorflow::tpu::TpuExecutorInterface {

Status WaitForOutfeedReady(int32 outfeed_queue_index);

Status UnloadAllPrograms() override;

Status EnqueueCompactionOnStreamForHbm(Stream* compaction_stream) override;

const ::tensorflow::tpu::TpuPlatformInterface& platform() const override {
return *platform_;
}
@@ -122,6 +122,12 @@ void TpuExecutor_BlockUntilDoneOrFailed(SE_StreamExecutor* executor,
void TpuExecutor_SyncAndForgetFailedStreams(SE_StreamExecutor* executor);
bool TpuExecutor_SynchronizeAllActivity(SE_StreamExecutor* executor);

void TpuExecutor_UnloadAllPrograms(SE_StreamExecutor* executor,
TF_Status* status);
void TpuExecutor_EnqueueCompactionOnStreamForHbm(SE_StreamExecutor* executor,
SE_Stream* compaction_stream,
TF_Status* status);

SE_Stream* TpuStream_New(SE_StreamExecutor* parent);
void TpuStream_Free(SE_Stream*);
void* TpuStream_Stream(SE_Stream*);
@@ -389,6 +395,8 @@ struct TfTpu_ExecutorApiFn {
TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_BlockUntilDoneOrFailed);
TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_SyncAndForgetFailedStreams);
TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_SynchronizeAllActivity);
TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_UnloadAllPrograms);
TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_EnqueueCompactionOnStreamForHbm);

TFTPU_ADD_FN_IN_STRUCT(TpuStream_New);
TFTPU_ADD_FN_IN_STRUCT(TpuStream_Free);
@@ -58,6 +58,13 @@ class TpuExecutorInterface
virtual TpuCoreLocationExternal GetCoreLocationExternal() const {
LOG(FATAL) << "Unimplemented.";
}

virtual Status UnloadAllPrograms() { LOG(FATAL) << "Unimplemented."; }

virtual Status EnqueueCompactionOnStreamForHbm(
stream_executor::Stream* compaction_stream) {
LOG(FATAL) << "Unimplemented.";
}
};

} // namespace tpu
@@ -85,5 +85,9 @@ stream_executor::StreamExecutor* TpuNodeContext::stream_executor() const {
return backend()->stream_executor(device_ordinal_).ValueOrDie();
}

bool TpuNodeContext::CompactionSupported(int device_ordinal) const {
return tpu::OpsApiFn()->TpuNodeContext_CompactionSupportedFn(device_ordinal);
}

} // namespace tpu
} // namespace tensorflow
@@ -67,6 +67,8 @@ class TpuNodeContext final {

stream_executor::StreamExecutor* stream_executor() const;

bool CompactionSupported(int device_ordinal) const;

private:
const int device_ordinal_;
XLA_TpuNodeContext* const node_context_;