Introduce TPUReshardVariables op to open source TensorFlow

PiperOrigin-RevId: 359089840
Change-Id: I7dd20d202e7881cb5c3ee0e5a52146d34e3df058
Frank Chen 2021-02-23 11:17:50 -08:00 committed by TensorFlower Gardener
parent 31a8ae453c
commit 7c4e09623d
14 changed files with 835 additions and 3 deletions
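
The 14 files below add the kernel, its helper utilities, and the TPU executor plumbing; the op definition itself is not among them. As a rough orientation only, based on the inputs the kernel reads ("vars", "new_format_key", "format_state_var") and its "N" attribute, the corresponding registration would look roughly like the sketch below. This is an assumption for illustration, not the registration shipped with TensorFlow.

#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"

namespace tensorflow {

// Sketch only: input/attr names are taken from the kernel in this commit;
// the shape function and constraints are assumptions.
REGISTER_OP("TPUReshardVariables")
    .Input("vars: N * resource")          // Variables to reshard.
    .Input("new_format_key: string")      // 3-element vector naming the target sharding.
    .Input("format_state_var: resource")  // Variable holding the current sharding state.
    .Attr("N: int >= 0")
    .SetIsStateful()
    .SetShapeFn(shape_inference::UnknownShape);

}  // namespace tensorflow

Note that the kernel itself is registered for DEVICE_TPU_NODE only when PLATFORM_GOOGLE is not defined (see the REGISTER_KERNEL_BUILDER block at the end of tpu_reshard_variables_op.cc below), so the open-source build is the one that picks up this registration.
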

View File

@@ -35,6 +35,7 @@ tf_kernel_library(
":tpu_configuration_ops",
":tpu_execute_op",
":tpu_handle_to_key_op",
":tpu_reshard_variables_op",
":transfer_ops",
],
)
@@ -278,9 +279,11 @@ cc_library(
hdrs = ["tpu_program_group_interface.h"],
deps = [
":tpu_compilation_cache_key",
":tpu_executable_info_proto_cc",
"//tensorflow/compiler/tf2xla:host_compute_metadata_proto_cc",
"//tensorflow/compiler/xla/service:hlo_proto_cc",
"//tensorflow/core/lib/core:status",
"//tensorflow/core/tpu:tpu_ops_c_api_hdrs",
"@com_google_absl//absl/time",
"@com_google_absl//absl/types:span",
],
@@ -883,3 +886,67 @@ cc_library(
"//tensorflow/core:framework",
],
)
cc_library(
name = "tpu_reshard_variables_op",
srcs = ["tpu_reshard_variables_op.cc"],
hdrs = ["tpu_reshard_variables_op.h"],
deps = [
":tpu_compilation_cache_common_proto_cc",
":tpu_compilation_cache_lookup",
":tpu_op_consts",
":tpu_program_group",
":tpu_reshard_variables_op_util",
"//tensorflow/compiler/jit:xla_device_no_jit_rewrite_registration",
"//tensorflow/compiler/jit:xla_launch_util",
"//tensorflow/compiler/jit:xla_tensor",
"//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/tf2xla:tf2xla_util",
"//tensorflow/compiler/xla/service:maybe_owning_device_memory",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core/profiler/lib:traceme",
"//tensorflow/core/tpu:tpu_configuration",
"//tensorflow/core/tpu:tpu_defs",
"//tensorflow/core/tpu:tpu_execute",
"//tensorflow/stream_executor:device_memory_allocator",
"//tensorflow/stream_executor/tpu:tpu_executor_hdrs",
"//tensorflow/stream_executor/tpu:tpu_executor_interface",
"//tensorflow/stream_executor/tpu:tpu_node_context",
],
alwayslink = 1,
)
cc_library(
name = "tpu_reshard_variables_op_util",
srcs = ["tpu_reshard_variables_op_util.cc"],
hdrs = ["tpu_reshard_variables_op_util.h"],
deps = [
":tpu_compilation_cache_common_proto_cc",
":tpu_compilation_cache_interface",
":tpu_compilation_cache_lookup",
":tpu_op_consts",
":tpu_program_group",
"//tensorflow/compiler/jit:xla_device_no_jit_rewrite_registration",
"//tensorflow/compiler/jit:xla_launch_util",
"//tensorflow/compiler/jit:xla_tensor",
"//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/tf2xla:tf2xla_util",
"//tensorflow/compiler/xla/service:maybe_owning_device_memory",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core/profiler/lib:traceme",
"//tensorflow/core/tpu:tpu_configuration",
"//tensorflow/core/tpu:tpu_defs",
"//tensorflow/core/tpu:tpu_execute",
"//tensorflow/stream_executor:device_memory_allocator",
"//tensorflow/stream_executor/tpu:tpu_executor_hdrs",
"//tensorflow/stream_executor/tpu:tpu_executor_interface",
"//tensorflow/stream_executor/tpu:tpu_node_context",
],
alwayslink = 1,
)

View File

@@ -127,12 +127,12 @@ class TpuProgramGroup : public TpuProgramGroupInterface {
const std::vector<XLA_TpuProgram*>& tpu_programs() const;
std::vector<XLA_TpuProgram*> tpu_programs(TpuProgramShardingType type) const;
const XLA_TpuProgram* tpu_program(int index) const;
const XLA_TpuProgram* tpu_program(int index) const override;
void set_tpu_programs(absl::Span<XLA_TpuProgram* const> tpu_programs);
const TPUExecutableInfoProto& executable_info(int index) const;
const TPUExecutableInfoProto& executable_info(int index) const override;
const TPUHostTransferInfoProto& host_transfer_info(int index) const;
const TPUHostTransferInfoProto& host_transfer_info(int index) const override;
void set_hlo_metadatas(absl::Span<const xla::HloProto> hlo_metadatas);
const xla::HloProto* hlo_metadata(int index) const;
absl::Span<const xla::HloProto* const> hlo_metadatas() const override;

View File

@@ -26,6 +26,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_key.h"
#include "tensorflow/core/tpu/kernels/tpu_executable_info.pb.h"
#include "tensorflow/core/tpu/tpu_ops_c_api.h"
namespace tensorflow {
namespace tpu {
@@ -66,6 +68,16 @@ class TpuProgramGroupInterface {
// Gets whether the TPU program for the given core `index` may modify
// variables.
virtual bool may_modify_variables(int index) const = 0;
// Gets the TPUExecutableInfoProto for the program on core `index`.
virtual const TPUExecutableInfoProto& executable_info(int index) const = 0;
// Gets the TPUHostTransferInfoProto for the program on core `index`.
virtual const TPUHostTransferInfoProto& host_transfer_info(
int index) const = 0;
// Gets the XLA_TpuProgram for core `index`.
virtual const XLA_TpuProgram* tpu_program(int index) const = 0;
};
} // namespace tpu

View File

@@ -0,0 +1,283 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/tpu/kernels/tpu_reshard_variables_op.h"
#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/compiler/jit/xla_launch_util.h"
#include "tensorflow/compiler/jit/xla_tensor.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/resource_var.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.h"
#include "tensorflow/core/tpu/kernels/tpu_op_consts.h"
#include "tensorflow/core/tpu/kernels/tpu_program_group.h"
#include "tensorflow/core/tpu/kernels/tpu_reshard_variables_op_util.h"
#include "tensorflow/core/tpu/tpu_configuration.h"
#include "tensorflow/core/tpu/tpu_defs.h"
#include "tensorflow/core/tpu/tpu_execute.h"
#include "tensorflow/core/util/stream_executor_util.h"
#include "tensorflow/stream_executor/device_memory_allocator.h"
#include "tensorflow/stream_executor/tpu/tpu_executor.h"
#include "tensorflow/stream_executor/tpu/tpu_executor_interface.h"
#include "tensorflow/stream_executor/tpu/tpu_node_context.h"
namespace tensorflow {
namespace reshard_util = ::tensorflow::tpu::reshard_variables;
TPUReshardVariablesOpKernel::TPUReshardVariablesOpKernel(
OpKernelConstruction* context)
: AsyncOpKernel(context, /* is_deferred = */ true) {
OP_REQUIRES_OK(context, context->GetAttr("N", &num_vars_));
}
void TPUReshardVariablesOpKernel::ComputeAsync(OpKernelContext* context,
DoneCallback done) {
// If TPU launches are asynchronous, then perform the launch on this thread
// to avoid a thread hop, which has an observable latency cost.
OP_REQUIRES_OK_ASYNC(context, DoWork(context), done);
done();
}
Status TPUReshardVariablesOpKernel::DoWork(OpKernelContext* context) {
VLOG(1) << "Cloud TPU: TPUReshardVariablesOpKernel::DoWork";
TF_RET_CHECK(context->input_dtype(num_vars_) == DT_STRING);
const Tensor* new_format_key;
TF_RETURN_IF_ERROR(context->input("new_format_key", &new_format_key));
TF_RETURN_IF_ERROR(reshard_util::CheckIsValidKey(*new_format_key));
TF_RET_CHECK(context->input_dtype(num_vars_ + 1) == DT_RESOURCE);
const ResourceHandle& handle = HandleFromInput(context, num_vars_ + 1);
Var* format_state_var;
TF_RETURN_IF_ERROR(LookupOrCreateResource<Var>(
context, handle, &format_state_var, [new_format_key](Var** ptr) {
*ptr = new Var(new_format_key->dtype());
return Status::OK();
}));
mutex_lock ml(*format_state_var->mu());
const bool initialized = format_state_var->is_initialized;
if (initialized) {
TF_RETURN_IF_ERROR(
reshard_util::CheckIsValidKey(*format_state_var->tensor()));
}
const bool state_is_default =
!initialized || reshard_util::IsDefaultKey(*format_state_var->tensor());
const bool new_format_is_default =
reshard_util::IsDefaultKey(*new_format_key);
if ((state_is_default && new_format_is_default) ||
(initialized && format_state_var->tensor()->vec<tstring>()(2) ==
new_format_key->vec<tstring>()(2))) {
VLOG(1) << "Sharding unchanged, nothing to do.";
return Status::OK();
}
if (!state_is_default) {
// Convert the current format to default (unsharded).
VLOG(1) << "Unsharding with key: "
<< format_state_var->tensor()->vec<tstring>()(2);
TF_RETURN_IF_ERROR(
DoTpuExecute(context, *format_state_var->tensor(),
tpu::CompilationCacheFetchTarget::UNSHARDING));
}
if (!new_format_is_default) {
// Convert the new format.
VLOG(1) << "Sharding with key: " << new_format_key->vec<tstring>()(2);
TF_RETURN_IF_ERROR(DoTpuExecute(
context, *new_format_key, tpu::CompilationCacheFetchTarget::SHARDING));
}
// Change the state.
*format_state_var->tensor() = *new_format_key;
format_state_var->is_initialized = true;
return Status::OK();
}
Status TPUReshardVariablesOpKernel::DoTpuExecute(
OpKernelContext* context, const Tensor& format_key,
tpu::CompilationCacheFetchTarget fetch_target) {
const XlaDevice::Metadata* metadata;
TF_RETURN_IF_ERROR(XlaDevice::GetMetadata(context, &metadata));
const int device_ordinal = metadata->device_ordinal();
// We are guaranteed that the underlying object won't be deleted out from
// under us.
TF_ASSIGN_OR_RETURN(std::unique_ptr<tpu::TpuNodeContext> node_interfaces,
tpu::TpuNodeContext::Create(device_ordinal));
profiler::TraceMe trace_me(
[device_ordinal] {
return profiler::TraceMeEncode("TPUReshardVariablesOpKernel",
{{"device_ordinal", device_ordinal}});
},
/*level=*/2);
profiler::TraceMe trace_me_init("TPUReshardVariablesOpKernel::Init",
/*level=*/2);
string rendezvous_key_base;
std::unique_ptr<tpu::CompilationCacheEntryRef> entry_ref;
TF_RETURN_IF_ERROR(reshard_util::GetComputationCacheEntry(
format_key, &rendezvous_key_base, &entry_ref, fetch_target));
tpu::TpuCompilationCacheEntry entry = entry_ref->get();
if (entry.tpu_program_group() == nullptr) {
VLOG(2) << "Sharding/unsharding program does not exist, so this is default "
"sharding.";
return Status::OK();
}
const tpu::TpuProgramGroupInterface* tpu_program_group =
entry.tpu_program_group();
const int core_index = entry.core_index();
const TPUExecutableInfoProto& executable_info_proto =
tpu_program_group->executable_info(core_index);
const TPUExecutableInfoProto* executable = &executable_info_proto;
xla::Backend* const backend = node_interfaces->backend();
xla::TransferManager* const transfer_manager = backend->transfer_manager();
CHECK(context->op_device_context());
se::Stream* stream = context->op_device_context()->stream();
TF_RET_CHECK(executable->input_shapes_size() == 1);
xla::Shape host_shape(executable->input_shapes(0));
std::vector<VariableInfo> variables;
for (int i = 0; i < num_vars_; ++i) {
TF_RET_CHECK(context->input_dtype(i) == DT_RESOURCE);
const ResourceHandle& handle = HandleFromInput(context, i);
Var* variable;
TF_RETURN_IF_ERROR(LookupResource(context, handle, &variable));
variables.push_back(VariableInfo(i, handle.name(), variable));
}
// Block for previous TPUExecute ops so that the memory used for them could be
// freed.
TF_RETURN_IF_ERROR(stream->BlockHostUntilDone());
// Lock variables to prevent concurrent access.
TF_RETURN_IF_ERROR(LockVariables(absl::MakeSpan(variables)));
// Build input buffers.
TF_ASSIGN_OR_RETURN(auto input_buffers, reshard_util::BuildInputBuffers(
context, variables, host_shape,
backend, device_ordinal, stream));
xla::ShapedBuffer shaped_buffer(std::move(host_shape), input_buffers.shape(),
device_ordinal);
shaped_buffer.set_buffers(input_buffers.Map<se::DeviceMemoryBase>(
[](xla::MaybeOwningDeviceMemory* buffer) {
CHECK(buffer);
return buffer->AsDeviceMemoryBase();
}));
// Write input root tuple.
TF_ASSIGN_OR_RETURN(auto transfer_stream_ptr,
backend->BorrowStream(device_ordinal));
if (transfer_manager->CanShapedBufferBeAccessedNow(stream->parent(),
shaped_buffer)) {
TF_RETURN_IF_ERROR(transfer_manager->WriteRootTupleIndexTable(
transfer_stream_ptr.get(), shaped_buffer));
stream->ThenWaitFor(transfer_stream_ptr.get());
} else {
TF_RETURN_IF_ERROR(
transfer_manager->WriteRootTupleIndexTable(stream, shaped_buffer));
}
VLOG(4) << "Input buffers: " << shaped_buffer.ToString();
TF_RET_CHECK(!executable->has_session_module())
<< "session module not supported in sharding/unsharding program.";
auto definition_event = std::make_shared<se::Event>(stream->parent());
TF_RET_CHECK(definition_event->Init())
<< "TPU definition event initialization failed";
trace_me_init.Stop();
// Execute the program.
std::unique_ptr<xla::DeviceAssignment> device_assignment;
if (executable->has_device_assignment()) {
TF_ASSIGN_OR_RETURN(
device_assignment,
xla::DeviceAssignment::Deserialize(executable->device_assignment()));
}
std::vector<xla::ExecutionInput> input;
input.emplace_back(xla::ExecutionInput(std::move(input_buffers),
shaped_buffer.on_host_shape()));
const TPUHostTransferInfoProto& host_transfer_info =
tpu_program_group->host_transfer_info(core_index);
TF_ASSIGN_OR_RETURN(
xla::ExecutionOutput output,
TPUExecute(*executable, host_transfer_info,
*tpu_program_group->hlo_metadatas()[core_index],
std::move(input), rendezvous_key_base, GetXLARandomSeed(),
node_interfaces.get(), device_assignment.get(),
context->cancellation_manager(), context, stream,
transfer_stream_ptr.get(),
tpu_program_group->tpu_program(core_index)));
stream->ThenRecordEvent(definition_event.get());
// Assign the new buffers to the variables.
xla::ScopedShapedBuffer result = output.ConsumeResult();
// Only perform compaction when sharding.
// NOTE: Compaction is not supported on some TPUs, see b/168322060 for details
if (node_interfaces->CompactionSupported(device_ordinal) &&
fetch_target == tpu::CompilationCacheFetchTarget::SHARDING) {
// Block until program execution is done so that input, output, and program
// cache memory can be actually released.
TF_RETURN_IF_ERROR(transfer_stream_ptr->BlockHostUntilDone());
TF_RETURN_IF_ERROR(stream->BlockHostUntilDone());
{
// Explicitly release any RAII objects owning on-device allocations.
auto unused = output.ConsumeToBeReleased();
}
// Release variables holding inputs.
for (int i = 0; i < variables.size(); ++i) {
*variables[i].var()->tensor() = Tensor();
}
// Flush on-device program memory cache.
TF_RETURN_IF_ERROR(
reshard_util::FlushProgramMemory(backend->platform(), device_ordinal));
TF_RETURN_IF_ERROR(reshard_util::PerformCompaction(stream));
}
return reshard_util::UpdateOutputVariables(
context, std::move(result), executable->output_tensor_shapes(), backend,
stream, device_ordinal, variables, definition_event);
}
TPUReshardVariablesOpKernel::~TPUReshardVariablesOpKernel() = default;
#if !defined(PLATFORM_GOOGLE)
REGISTER_KERNEL_BUILDER(Name("TPUReshardVariables")
.Device(DEVICE_TPU_NODE)
.HostMemory("format_state_var")
.HostMemory("new_format_key"),
TPUReshardVariablesOpKernel);
#endif
} // namespace tensorflow
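
For readers skimming the diff, the layout of the three-element format key is implicit in the checks above: element 0 is the compilation cache key (empty means the default, unsharded format), element 1 is the rendezvous key base passed to TPUExecute, and element 2 identifies the sharding state that is compared to decide whether any work is needed. A minimal illustrative sketch with placeholder strings, not real keys:

#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"

namespace tensorflow {

// Illustration only: builds a key with the shape CheckIsValidKey expects.
Tensor MakeExampleFormatKey() {
  Tensor key(DT_STRING, TensorShape({3}));
  auto vec = key.vec<tstring>();
  vec(0) = "compilation-cache-key";  // Empty string => default (unsharded) format.
  vec(1) = "rendezvous-key-base";    // Forwarded to TPUExecute.
  vec(2) = "sharding-state-id";      // Compared against the stored state variable.
  return key;
}

}  // namespace tensorflow

IsDefaultKey in the utility file below inspects only element 0, which is why an empty compilation key is treated as the unsharded state regardless of the other two elements.
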

View File

@@ -0,0 +1,52 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_RESHARD_VARIABLES_OP_H_
#define TENSORFLOW_CORE_TPU_KERNELS_TPU_RESHARD_VARIABLES_OP_H_
#include <memory>
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_common.pb.h"
namespace tensorflow {
// Op that changes the sharding state for a set of on-device variables. The
// sharding state is represented as the key of the compilation that generated
// the sharding/unsharding programs along with the main program. The op checks
// if the current sharding state matches the desired one, and if not, uses the
// sharding/unsharding programs to transform the variables to the desired state.
class TPUReshardVariablesOpKernel : public AsyncOpKernel {
public:
explicit TPUReshardVariablesOpKernel(OpKernelConstruction* context);
~TPUReshardVariablesOpKernel() override;
void ComputeAsync(OpKernelContext* context, DoneCallback done) override;
private:
Status DoWork(OpKernelContext* context);
Status DoTpuExecute(OpKernelContext* context, const Tensor& format_key,
tpu::CompilationCacheFetchTarget fetch_target);
int64 num_vars_;
DISALLOW_COPY_AND_ASSIGN(TPUReshardVariablesOpKernel);
};
} // namespace tensorflow
#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_RESHARD_VARIABLES_OP_H_

View File

@@ -0,0 +1,314 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/tpu/kernels/tpu_reshard_variables_op_util.h"
#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/compiler/jit/xla_launch_util.h"
#include "tensorflow/compiler/jit/xla_tensor.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/resource_var.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.h"
#include "tensorflow/core/tpu/kernels/tpu_op_consts.h"
#include "tensorflow/core/tpu/kernels/tpu_program_group.h"
#include "tensorflow/core/tpu/tpu_configuration.h"
#include "tensorflow/core/tpu/tpu_defs.h"
#include "tensorflow/core/tpu/tpu_execute.h"
#include "tensorflow/core/util/stream_executor_util.h"
#include "tensorflow/stream_executor/device_memory_allocator.h"
#include "tensorflow/stream_executor/tpu/tpu_executor.h"
#include "tensorflow/stream_executor/tpu/tpu_executor_interface.h"
#include "tensorflow/stream_executor/tpu/tpu_node_context.h"
namespace tensorflow {
namespace tpu {
namespace reshard_variables {
Status FlushProgramMemory(se::Platform* platform, int device_ordinal) {
TF_ASSIGN_OR_RETURN(std::unique_ptr<tpu::TpuNodeContext> node_interfaces,
tpu::TpuNodeContext::Create(device_ordinal));
auto* executor = tensorflow::down_cast<tpu::TpuExecutorInterface*>(
node_interfaces->stream_executor()->implementation());
return executor->UnloadAllPrograms();
}
Status CheckIsValidKey(const Tensor& key) {
if (!TensorShapeUtils::IsVector(key.shape()) ||
key.shape().dim_size(0) != 3) {
return errors::InvalidArgument(
"new_format_key argument to TPUReshardVariables must be a 3-element "
"vector");
}
if (key.dtype() != DT_STRING) {
return errors::InvalidArgument(
"new_format_key argument to TPUReshardVariables must be DT_STRING "
"type");
}
return Status::OK();
}
bool IsDefaultKey(const Tensor& key) { return key.vec<tstring>()(0).empty(); }
// Looks up the input `key` in the compilation cache, populating
// `*rendezvous_key_base` and `*entry`.
Status GetComputationCacheEntry(
const Tensor& key, string* rendezvous_key_base,
std::unique_ptr<tpu::CompilationCacheEntryRef>* entry,
tpu::CompilationCacheFetchTarget fetch_target) {
profiler::TraceMe trace_me("TPUReshardVariablesOpKernel::LookupProto",
/*level=*/2);
TF_RETURN_IF_ERROR(CheckIsValidKey(key));
auto* rmgr = GetTPUConfigResourceMgr();
tpu::TpuCompilationCacheLookup* proto_lookup;
TF_RETURN_IF_ERROR(rmgr->Lookup(rmgr->default_container(),
tpu::kCompiledProtoCacheResourceName,
&proto_lookup));
core::ScopedUnref lookup_unref(proto_lookup);
TF_RETURN_IF_ERROR(
proto_lookup->Lookup(key.vec<tstring>()(0), entry, fetch_target));
*rendezvous_key_base = key.vec<tstring>()(1);
return Status::OK();
}
// Builds an InputBuffers object that describes the inputs to the computation.
xla::StatusOr<xla::ShapeTree<xla::MaybeOwningDeviceMemory>> BuildInputBuffers(
OpKernelContext* context, const std::vector<VariableInfo>& variables,
const xla::Shape& input_host_shape, xla::Backend* backend,
int device_ordinal, se::Stream* stream) {
profiler::TraceMe trace_me("BuildComputationInputs", /*level=*/2);
OpInputList var_list;
TF_RETURN_IF_ERROR(context->input_list("vars", &var_list));
if (var_list.size() != xla::ShapeUtil::TupleElementCount(input_host_shape)) {
return errors::InvalidArgument(
"Number of variables (", var_list.size(),
") does not match input shape: ",
xla::ShapeUtil::TupleElementCount(input_host_shape));
}
auto validate_shape = [&](int i, const Tensor& tensor) {
const xla::Shape& expected =
xla::ShapeUtil::GetTupleElementShape(input_host_shape, i);
VLOG(4) << "Input " << i << " TF shape " << tensor.shape().DebugString();
XlaTensor* xla_tensor = XlaTensor::FromTensor(&tensor);
if (xla_tensor == nullptr) {
// FromTensor failed; tensor must be empty.
if (!xla::ShapeUtil::IsZeroElementArray(expected)) {
return errors::InvalidArgument(
"Run-time shape mismatch for TPUExecute argument[", i, "] (",
context->op_kernel().requested_input(i), "). Expected ",
expected.DebugString(),
"; got empty tensor. If you are running "
"with TF2 TPU, make sure you set `drop_remainder=False` when "
"calling `dataset.batch` on the `tf.data.Dataset` so dynamic batch "
"size can be handled");
}
} else {
const xla::Shape& xla_shape = xla_tensor->shaped_buffer().on_host_shape();
if (!xla::ShapeUtil::Compatible(expected, xla_shape)) {
return errors::InvalidArgument(
"Run-time shape mismatch for TPUReshardVariables argument[", i,
"] (", context->op_kernel().requested_input(i), "). Expected ",
expected.DebugString(), "; got ", xla_shape.DebugString());
}
}
return Status::OK();
};
for (int i = 0; i < variables.size(); ++i) {
TF_RETURN_IF_ERROR(
validate_shape(variables[i].index(), *variables[i].var()->tensor()));
}
se::DeviceMemoryAllocator* const allocator = backend->memory_allocator();
xla::TransferManager* const transfer_manager = backend->transfer_manager();
xla::ShapeTree<xla::MaybeOwningDeviceMemory> input_buffers(
transfer_manager->HostShapeToDeviceShape(input_host_shape));
// Allocates a buffer for the root tuple.
const int64 root_size =
transfer_manager->GetByteSizeRequirement(input_buffers.shape());
TF_ASSIGN_OR_RETURN(*input_buffers.mutable_element({}),
allocator->Allocate(device_ordinal, root_size));
auto set_input_buffers_helper = [&](int arg_index, xla::ShapedBuffer* buffers,
bool owning = false) {
buffers->buffers().ForEachMutableElement(
[&](const xla::ShapeIndex& index, se::DeviceMemoryBase* buffer) {
xla::ShapeIndex in_index = {arg_index};
for (int64 j : index) {
in_index.push_back(j);
}
if (owning) {
*input_buffers.mutable_element(in_index) =
se::OwningDeviceMemory(*buffer, device_ordinal, allocator);
*buffer = se::DeviceMemoryBase();
} else {
*input_buffers.mutable_element(in_index) = *buffer;
}
});
};
// Assigns the buffers of 'tensor' as computation input 'i'. Allocates fresh
// buffers for zero-element tensors where required.
auto assign_input = [&](int i, const Tensor& tensor) -> xla::Status {
XlaTensor* xla_tensor = XlaTensor::FromTensor(&tensor);
// Size 0 tensors have no backing XlaTensor, but may still need to have
// tuple buffers allocated.
if (xla_tensor == nullptr) {
CHECK_EQ(tensor.NumElements(), 0);
const xla::Shape& host_shape =
xla::ShapeUtil::GetSubshape(input_host_shape, {i});
TF_ASSIGN_OR_RETURN(xla::ScopedShapedBuffer buffers,
transfer_manager->AllocateScopedShapedBuffer(
host_shape, allocator, device_ordinal));
set_input_buffers_helper(/*arg_index=*/i, &buffers);
} else {
set_input_buffers_helper(/*arg_index=*/i, &xla_tensor->shaped_buffer(),
tensor.RefCountIsOne());
xla_tensor->WaitForDefinitionEventOnStream(stream);
}
return Status::OK();
};
for (int i = 0; i < var_list.size(); ++i) {
TF_RET_CHECK(var_list[i].dtype() == DT_RESOURCE);
TF_RETURN_IF_ERROR(assign_input(i, *variables[i].var()->tensor()));
}
return std::move(input_buffers);
}
// Perform a compaction to reduce fragmentation.
Status PerformCompaction(stream_executor::Stream* stream) {
profiler::TraceMe trace_me("PerformCompaction", /*level=*/2);
auto* ds_executor =
down_cast<tpu::TpuExecutorInterface*>(stream->parent()->implementation());
TF_RETURN_IF_ERROR(ds_executor->EnqueueCompactionOnStreamForHbm(stream));
// LoadProgram and GetOrCreateConstantHandle are not managed by stream
// dependencies but they write to shared memory, so we need to block here to
// prevent those operations from racing.
return stream->BlockHostUntilDone();
}
// Updates the variables to the execution result's buffers, and deallocates the
// root tuple buffer.
Status UpdateOutputVariables(
OpKernelContext* context, xla::ScopedShapedBuffer result_buffers,
absl::Span<const TensorShapeProto* const> output_tensor_shape_protos,
xla::Backend* backend, se::Stream* stream, int device_ordinal,
const std::vector<VariableInfo>& variables,
const std::shared_ptr<se::Event>& definition_event) {
profiler::TraceMe trace_me("UpdateOutputVariables", /*level=*/2);
// Shapes of the outputs, in TensorShape form.
const int64 sub_elements =
xla::ShapeUtil::TupleElementCount(result_buffers.on_host_shape());
if (sub_elements != output_tensor_shape_protos.size()) {
return errors::InvalidArgument(
"Mismatched numbers of output shapes: ", sub_elements, " vs. ",
output_tensor_shape_protos.size());
}
if (sub_elements != variables.size()) {
return errors::InvalidArgument(
"Output count does not equal varaible count: ", sub_elements, " vs. ",
variables.size());
}
std::vector<TensorShape> output_tensor_shapes;
output_tensor_shapes.reserve(sub_elements);
for (int64 i = 0; i < sub_elements; ++i) {
TF_RETURN_IF_ERROR(
TensorShape::IsValidShape(*output_tensor_shape_protos[i]));
TensorShape shape(*output_tensor_shape_protos[i]);
const xla::Shape& xla_shape =
xla::ShapeUtil::GetSubshape(result_buffers.on_host_shape(), {i});
if (!xla_shape.IsArray() ||
xla::ShapeUtil::ElementsIn(xla_shape) != shape.num_elements()) {
return errors::InvalidArgument(
"Mismatched number of elements in output shape: ",
xla::ShapeUtil::HumanString(xla_shape), " vs ", shape.DebugString());
}
output_tensor_shapes.push_back(shape);
VLOG(2) << "Output " << i << " shape " << shape.DebugString();
}
// Build a shaped buffer for the outputs.
TF_RET_CHECK(result_buffers.on_host_shape().IsTuple());
TF_RET_CHECK(!xla::ShapeUtil::IsNestedTuple(result_buffers.on_host_shape()));
se::DeviceMemoryAllocator* const allocator = backend->memory_allocator();
auto output_buffers = result_buffers.release();
const xla::Shape& output_host_shape = output_buffers.on_host_shape();
const xla::Shape& output_device_shape = output_buffers.on_device_shape();
// Transfers ownership of the buffers that back XLA computation output 'i'
// to 'output_tensor'.
auto transfer_buffers = [&](int i, Tensor* output_tensor) {
const xla::Shape& host_shape =
xla::ShapeUtil::GetTupleElementShape(output_host_shape, i);
const xla::Shape& device_shape =
xla::ShapeUtil::GetTupleElementShape(output_device_shape, i);
if (output_tensor->NumElements() > 0) {
xla::ScopedShapedBuffer shaped_buffer(host_shape, device_shape, allocator,
device_ordinal);
shaped_buffer.buffers().ForEachMutableElement(
[&](const xla::ShapeIndex& index, se::DeviceMemoryBase* buffer) {
xla::ShapeIndex out_index = {i};
for (int64 j : index) {
out_index.push_back(j);
}
*buffer = output_buffers.buffers().element(out_index);
});
XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor);
xla_tensor->set_shaped_buffer(std::move(shaped_buffer));
xla_tensor->ResetDefinitionEvent(definition_event, stream);
}
};
for (int i = 0; i < variables.size(); ++i) {
PersistentTensor unused;
Tensor* output_tensor;
TF_RETURN_IF_ERROR(context->allocate_persistent(
variables[i].var()->tensor()->dtype(), output_tensor_shapes[i], &unused,
&output_tensor));
*variables[i].var()->tensor() = *output_tensor;
transfer_buffers(i, output_tensor);
}
return allocator->Deallocate(output_buffers.device_ordinal(),
output_buffers.buffer({}));
}
} // namespace reshard_variables
} // namespace tpu
} // namespace tensorflow

View File

@@ -0,0 +1,61 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_RESHARD_VARIABLES_OP_UTIL_H_
#define TENSORFLOW_CORE_TPU_KERNELS_TPU_RESHARD_VARIABLES_OP_UTIL_H_
#include <memory>
#include "tensorflow/compiler/jit/xla_launch_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_common.pb.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h"
namespace tensorflow {
namespace tpu {
namespace reshard_variables {
Status FlushProgramMemory(se::Platform* platform, int device_ordinal);
Status CheckIsValidKey(const Tensor& key);
bool IsDefaultKey(const Tensor& key);
Status GetComputationCacheEntry(
const Tensor& key, string* rendezvous_key_base,
std::unique_ptr<tpu::CompilationCacheEntryRef>* entry,
tpu::CompilationCacheFetchTarget fetch_target);
xla::StatusOr<xla::ShapeTree<xla::MaybeOwningDeviceMemory>> BuildInputBuffers(
OpKernelContext* context, const std::vector<VariableInfo>& variables,
const xla::Shape& input_host_shape, xla::Backend* backend,
int device_ordinal, se::Stream* stream);
Status PerformCompaction(stream_executor::Stream* stream);
Status UpdateOutputVariables(
OpKernelContext* context, xla::ScopedShapedBuffer result_buffers,
absl::Span<const TensorShapeProto* const> output_tensor_shape_protos,
xla::Backend* backend, se::Stream* stream, int device_ordinal,
const std::vector<VariableInfo>& variables,
const std::shared_ptr<se::Event>& definition_event);
} // namespace reshard_variables
} // namespace tpu
} // namespace tensorflow
#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_RESHARD_VARIABLES_OP_UTIL_H_
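
These helpers are consumed in a fixed order by TPUReshardVariablesOpKernel::DoTpuExecute above. The sketch below condenses that call order under simplifying assumptions: the TPUExecute invocation, device assignment, and event handling are omitted, and ReshardSketch is an illustrative name rather than a function added by this commit.

#include <memory>
#include <vector>

#include "tensorflow/core/tpu/kernels/tpu_reshard_variables_op_util.h"

namespace tensorflow {
namespace tpu {
namespace reshard_variables {

// Sketch of the helper call order used by the kernel; not part of this commit.
Status ReshardSketch(OpKernelContext* context, const Tensor& format_key,
                     const std::vector<VariableInfo>& variables,
                     const xla::Shape& input_host_shape, xla::Backend* backend,
                     int device_ordinal, se::Stream* stream,
                     CompilationCacheFetchTarget fetch_target) {
  // 1. Resolve the sharding/unsharding program from the compilation cache.
  string rendezvous_key_base;
  std::unique_ptr<CompilationCacheEntryRef> entry;
  TF_RETURN_IF_ERROR(GetComputationCacheEntry(format_key, &rendezvous_key_base,
                                              &entry, fetch_target));

  // 2. Wrap the variables' device buffers as the program's tuple input.
  auto input_buffers = BuildInputBuffers(context, variables, input_host_shape,
                                         backend, device_ordinal, stream);
  TF_RETURN_IF_ERROR(input_buffers.status());

  // 3. The real kernel hands the input buffers to TPUExecute here (omitted),
  //    then, when compaction is supported and the target is SHARDING,
  //    reclaims device memory before rewriting the variables.
  TF_RETURN_IF_ERROR(FlushProgramMemory(backend->platform(), device_ordinal));
  TF_RETURN_IF_ERROR(PerformCompaction(stream));

  // 4. UpdateOutputVariables(...) then points the variables at the program's
  //    output buffers; it needs the execution result, so it is skipped here.
  return Status::OK();
}

}  // namespace reshard_variables
}  // namespace tpu
}  // namespace tensorflow

The real kernel additionally blocks on both the compute and transfer streams before compaction so that input, output, and program-cache memory are actually released, as shown in DoTpuExecute above.
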

View File

@@ -426,6 +426,8 @@ void TpuNodeContext_CloseTpuHost(TF_Status* status);
void TpuNodeContext_Initialize(int device_ordinal, TF_Status* status);
bool TpuNodeContext_CompactionSupported(int device_ordinal);
// Globally initialize the TPU system for inference.
TFTPU_CAPI_EXPORT void TfTpu_InitializeTpuModelServer();
@@ -495,6 +497,7 @@ struct TfTpu_OpsApiFn {
TFTPU_ADD_FN_IN_STRUCT(TpuNodeContext_StopChipHeartbeats);
TFTPU_ADD_FN_IN_STRUCT(TpuNodeContext_CloseTpuHost);
TFTPU_ADD_FN_IN_STRUCT(TpuNodeContext_Initialize);
TFTPU_ADD_FN_IN_STRUCT(TpuNodeContext_CompactionSupported);
TFTPU_ADD_FN_IN_STRUCT(TfTpu_InitializeTpuModelServer);
};

View File

@@ -327,6 +327,21 @@ bool TpuExecutor::MemcpyDeviceToDevice(
LOG(FATAL) << __func__ << " not supported on TpuExecutor";
}
Status TpuExecutor::UnloadAllPrograms() {
StatusHelper status;
tpu::ExecutorApiFn()->TpuExecutor_UnloadAllProgramsFn(executor_,
status.c_status);
return status.status();
}
Status TpuExecutor::EnqueueCompactionOnStreamForHbm(Stream* compaction_stream) {
StatusHelper status;
tpu::ExecutorApiFn()->TpuExecutor_EnqueueCompactionOnStreamForHbmFn(
executor_, get_stream(compaction_stream->implementation()),
status.c_status);
return status.status();
}
struct HostCallbackContext {
std::function<Status()> callback;
};

View File

@@ -155,6 +155,10 @@ class TpuExecutor : public tensorflow::tpu::TpuExecutorInterface {
Status WaitForOutfeedReady(int32 outfeed_queue_index);
Status UnloadAllPrograms() override;
Status EnqueueCompactionOnStreamForHbm(Stream* compaction_stream) override;
const ::tensorflow::tpu::TpuPlatformInterface& platform() const override {
return *platform_;
}

View File

@@ -122,6 +122,12 @@ void TpuExecutor_BlockUntilDoneOrFailed(SE_StreamExecutor* executor,
void TpuExecutor_SyncAndForgetFailedStreams(SE_StreamExecutor* executor);
bool TpuExecutor_SynchronizeAllActivity(SE_StreamExecutor* executor);
void TpuExecutor_UnloadAllPrograms(SE_StreamExecutor* executor,
TF_Status* status);
void TpuExecutor_EnqueueCompactionOnStreamForHbm(SE_StreamExecutor* executor,
SE_Stream* compaction_stream,
TF_Status* status);
SE_Stream* TpuStream_New(SE_StreamExecutor* parent);
void TpuStream_Free(SE_Stream*);
void* TpuStream_Stream(SE_Stream*);
@@ -389,6 +395,8 @@ struct TfTpu_ExecutorApiFn {
TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_BlockUntilDoneOrFailed);
TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_SyncAndForgetFailedStreams);
TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_SynchronizeAllActivity);
TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_UnloadAllPrograms);
TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_EnqueueCompactionOnStreamForHbm);
TFTPU_ADD_FN_IN_STRUCT(TpuStream_New);
TFTPU_ADD_FN_IN_STRUCT(TpuStream_Free);

View File

@@ -58,6 +58,13 @@ class TpuExecutorInterface
virtual TpuCoreLocationExternal GetCoreLocationExternal() const {
LOG(FATAL) << "Unimplemented.";
}
virtual Status UnloadAllPrograms() { LOG(FATAL) << "Unimplemented."; }
virtual Status EnqueueCompactionOnStreamForHbm(
stream_executor::Stream* compaction_stream) {
LOG(FATAL) << "Unimplemented.";
}
};
} // namespace tpu

View File

@@ -85,5 +85,9 @@ stream_executor::StreamExecutor* TpuNodeContext::stream_executor() const {
return backend()->stream_executor(device_ordinal_).ValueOrDie();
}
bool TpuNodeContext::CompactionSupported(int device_ordinal) const {
return tpu::OpsApiFn()->TpuNodeContext_CompactionSupportedFn(device_ordinal);
}
} // namespace tpu
} // namespace tensorflow

View File

@@ -67,6 +67,8 @@ class TpuNodeContext final {
stream_executor::StreamExecutor* stream_executor() const;
bool CompactionSupported(int device_ordinal) const;
private:
const int device_ordinal_;
XLA_TpuNodeContext* const node_context_;