Add a TPU execution op.
PiperOrigin-RevId: 321887408
Change-Id: I8f4686189e98c28da00f56f7530a95bce2a2f17c
commit 19448cf8b9 (parent 1cbaa0c469)
@@ -553,3 +553,44 @@ cc_library(
     ],
     alwayslink = 1,
 )
+
+cc_library(
+    name = "tpu_execute_op",
+    srcs = ["tpu_execute_op.cc"],
+    hdrs = ["tpu_execute_op.h"],
+    deps = [
+        ":tpu_compilation_cache_entry",
+        ":tpu_compilation_cache_external",
+        ":tpu_compilation_cache_local_lookup",
+        ":tpu_compilation_cache_lookup",
+        ":tpu_executable_info_proto_cc",
+        ":tpu_op_consts",
+        "//tensorflow/compiler/jit:xla_device",
+        "//tensorflow/compiler/jit:xla_launch_util",
+        "//tensorflow/compiler/jit:xla_tensor",
+        "//tensorflow/compiler/tf2xla:common",
+        "//tensorflow/compiler/tf2xla:tf2xla_util",
+        "//tensorflow/compiler/xla:debug_options_flags",
+        "//tensorflow/compiler/xla:shape_util",
+        "//tensorflow/compiler/xla:statusor",
+        "//tensorflow/compiler/xla:xla_data_proto_cc",
+        "//tensorflow/compiler/xla/service:dump",
+        "//tensorflow/compiler/xla/service:executable",
+        "//tensorflow/compiler/xla/service:maybe_owning_device_memory",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:framework_internal",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:lib_internal",
+        "//tensorflow/core:stream_executor_no_cuda",
+        "//tensorflow/core/profiler/lib:traceme",
+        "//tensorflow/core/tpu:tpu_configuration",
+        "//tensorflow/core/tpu:tpu_defs",
+        "//tensorflow/core/tpu:tpu_execute",
+        "//tensorflow/stream_executor:device_memory_allocator",
+        "//tensorflow/stream_executor/tpu:tpu_node_context",
+        "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/memory",
+        "@com_google_absl//absl/types:span",
+    ],
+    alwayslink = True,
+)
tensorflow/core/tpu/kernels/tpu_execute_op.cc (new file, 809 lines)
@@ -0,0 +1,809 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/tpu/kernels/tpu_execute_op.h"

#include <utility>

#include "absl/container/flat_hash_map.h"
#include "absl/memory/memory.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/compiler/jit/xla_launch_util.h"
#include "tensorflow/compiler/jit/xla_tensor.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/service/dump.h"
#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/resource_var.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_external.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_local_lookup.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.h"
#include "tensorflow/core/tpu/kernels/tpu_executable_info.pb.h"
#include "tensorflow/core/tpu/kernels/tpu_op_consts.h"
#include "tensorflow/core/tpu/tpu_configuration.h"
#include "tensorflow/core/tpu/tpu_defs.h"
#include "tensorflow/core/tpu/tpu_execute.h"
#include "tensorflow/core/util/stream_executor_util.h"
#include "tensorflow/stream_executor/device_memory_allocator.h"
#include "tensorflow/stream_executor/tpu/tpu_node_context.h"

namespace tensorflow {

namespace {

using ::tensorflow::tpu::TpuNodeContext;
using CompilationCacheEntryRef = ::tensorflow::tpu::CompilationCacheEntryRef<
    ::tensorflow::tpu::TpuCompilationCacheEntry>;
using TpuCompilationCacheLookup =
    ::tensorflow::tpu::TpuCompilationCacheLookup<CompilationCacheEntryRef>;

// Looks up the input `key` in the compilation cache, populating
// `*rendezvous_key_base` and `*entry`.
Status GetComputationCacheEntry(
    OpKernelContext* context, string* rendezvous_key_base,
    std::unique_ptr<CompilationCacheEntryRef>* entry) {
  const Tensor* key;
  TF_RETURN_IF_ERROR(context->input("key", &key));
  profiler::TraceMe trace_me("TpuExecuteOp::LookupProto", /*level=*/2);
  if (!TensorShapeUtils::IsVector(key->shape()) ||
      key->shape().dim_size(0) != 2) {
    return errors::InvalidArgument(
        "Key argument to TPUExecute must be a 2-element vector");
  }

  ResourceMgr* rmgr = GetTPUConfigResourceMgr();
  TpuCompilationCacheLookup* proto_lookup;
  TF_RETURN_IF_ERROR(rmgr->Lookup(rmgr->default_container(),
                                  tpu::kCompiledProtoCacheResourceName,
                                  &proto_lookup));
  core::ScopedUnref lookup_unref(proto_lookup);
  TF_RETURN_IF_ERROR(proto_lookup->Lookup(key->vec<tstring>()(0), entry));
  *rendezvous_key_base = key->vec<tstring>()(1);
  return Status::OK();
}

struct VariableUpdateMap {
  // Maps input index to the updated output index. If the variable doesn't have
  // an updated output, the corresponding output is set to -1.
  absl::flat_hash_map<int, int> input_to_output;
  // Maps output index to (the input index, whether the update is generated from
  // compilation).
  absl::flat_hash_map<int, std::pair<int, bool>> output_to_input;
  // Part of the input indices that are from the compilation, in the compiled
  // order.
  std::vector<int> input_in_compiled_update_order;
};

// Creates a VariableUpdateMap from both the compilation and the fused variable
// reads/updates.
xla::StatusOr<VariableUpdateMap> BuildVariableUpdateMap(
    absl::Span<const TPUExecutableInfoProto::UpdateIndexPair* const>
        compiled_variable_updates,
    absl::Span<int const> fused_device_var_reads_in_computation_inputs,
    const std::vector<int>& fused_device_var_updates_in_computation_outputs,
    int64 computation_output_count) {
  VariableUpdateMap map;
  auto add_pair = [&](int input, int output, bool from_compilation) -> Status {
    TF_RET_CHECK(map.input_to_output.emplace(input, output).second)
        << "Duplicate variable input index: " << input;
    if (output >= 0) {
      TF_RET_CHECK(map.output_to_input
                       .emplace(output, std::make_pair(input, from_compilation))
                       .second)
          << "Duplicate variable output index: " << output;
    }
    return Status::OK();
  };

  // First add the updates produced by the compilation. Not all variables are
  // updated, and if not, they do not have an output in the XLA computation. The
  // update output indices in the XLA computation start after the non-variable
  // outputs.
  int num_updated_variables = 0;
  for (int i = 0; i < compiled_variable_updates.size(); ++i) {
    const bool updated = compiled_variable_updates[i]->updated();
    if (updated) ++num_updated_variables;
  }
  TF_RET_CHECK(num_updated_variables <= computation_output_count)
      << num_updated_variables << " <= " << computation_output_count;
  int64 compiled_variable_output_index =
      computation_output_count - num_updated_variables;
  for (auto update : compiled_variable_updates) {
    map.input_in_compiled_update_order.push_back(update->index());
    if (!update->updated()) {
      TF_RETURN_IF_ERROR(add_pair(update->index(), -1, true));
      continue;
    }
    TF_RETURN_IF_ERROR(
        add_pair(update->index(), compiled_variable_output_index, true));
    ++compiled_variable_output_index;
  }

  // Now add the updates from the attributes.
  TF_RET_CHECK(fused_device_var_reads_in_computation_inputs.size() ==
               fused_device_var_updates_in_computation_outputs.size());
  for (int64 i = 0; i < fused_device_var_reads_in_computation_inputs.size();
       ++i) {
    TF_RETURN_IF_ERROR(
        add_pair(fused_device_var_reads_in_computation_inputs[i],
                 fused_device_var_updates_in_computation_outputs[i], false));
  }
  return map;
}

// Buffers representing the inputs to a computation.
struct InputBuffers {
  explicit InputBuffers(xla::Shape device_shape)
      : buffers(std::move(device_shape)) {}

  InputBuffers(const InputBuffers&) = delete;
  InputBuffers& operator=(const InputBuffers&) = delete;

  ~InputBuffers() = default;

  xla::ShapedBuffer ToShapedBuffer(xla::Shape host_shape,
                                   se::DeviceMemoryAllocator* allocator,
                                   int device_ordinal) {
    CHECK_NE(allocator, nullptr);
    xla::ShapedBuffer shaped_buffer(std::move(host_shape), buffers.shape(),
                                    allocator->platform(), device_ordinal);
    shaped_buffer.set_buffers(buffers.Map<se::DeviceMemoryBase>(
        [](xla::MaybeOwningDeviceMemory* buffer) {
          CHECK(buffer);
          return buffer->AsDeviceMemoryBase();
        }));
    return shaped_buffer;
  }

  // Describes the buffer tree.
  xla::ShapeTree<xla::MaybeOwningDeviceMemory> buffers;

  // Information about resource variables passed directly to TPUExecute.
  std::vector<VariableInfo> variables;

  // Mapping from input index to offsets in 'variables'. < 0 if the input does
  // not correspond to a variable in 'variables'.
  std::vector<int> variable_index;
};

// Builds an InputBuffers object that describes the inputs to the computation.
xla::StatusOr<std::unique_ptr<InputBuffers>> BuildComputationInputs(
    OpKernelContext* context, const xla::Shape& input_host_shape,
    const VariableUpdateMap& variable_updates, TpuNodeContext* node_context,
    se::Stream* stream) {
  profiler::TraceMe trace_me("BuildComputationInputs", /*level=*/2);
  OpInputList arg_list;
  TF_RETURN_IF_ERROR(context->input_list("args", &arg_list));

  if (arg_list.size() != xla::ShapeUtil::TupleElementCount(input_host_shape)) {
    return errors::InvalidArgument(
        "Number of parameters (", arg_list.size(),
        ") does not match input shape: ",
        xla::ShapeUtil::TupleElementCount(input_host_shape));
  }

  auto validate_shape = [&](int i, const Tensor& tensor) {
    const xla::Shape& expected =
        xla::ShapeUtil::GetTupleElementShape(input_host_shape, i);
    VLOG(4) << "Input " << i << " TF shape " << tensor.shape().DebugString();
    XlaTensor* xla_tensor = XlaTensor::FromTensor(&tensor);

    if (xla_tensor == nullptr) {
      // FromTensor failed; tensor must be empty.
      if (!xla::ShapeUtil::IsZeroElementArray(expected)) {
        return errors::InvalidArgument(
            "Run-time shape mismatch for TPUExecute argument[", i, "] (",
            context->op_kernel().requested_input(i), "). Expected ",
            expected.DebugString(), "; got empty tensor");
      }
    } else {
      // Compare host shapes, easier than getting the expected device shape.
      const xla::Shape& xla_shape = xla_tensor->shaped_buffer().on_host_shape();
      if (!xla::ShapeUtil::Compatible(expected, xla_shape)) {
        return errors::InvalidArgument(
            "Run-time shape mismatch for TPUExecute argument[", i, "] (",
            context->op_kernel().requested_input(i), "). Expected ",
            expected.DebugString(), "; got ", xla_shape.DebugString());
      }
    }

    return Status::OK();
  };

  // Iterate over the inputs, validating the shapes of non-variable inputs,
  // and creating a VariableInfo object for each variable. We consider variable
  // inputs in a separate phase because we must acquire variable locks in order.
  std::vector<VariableInfo> variables;
  std::vector<int> variable_index(arg_list.size(), -1);
  variables.reserve(arg_list.size());
  for (int i = 0; i < arg_list.size(); ++i) {
    // Arguments are assumed to be variables if they have a resource type.
    // (Non-variable resources are not supported.)
    if (context->input_dtype(i) == DT_RESOURCE) {
      variable_index[i] = variables.size();
      // TODO(phawkins): we may be looking up many variables here; it would be
      // better if we did not repeatedly acquire the resource manager's lock.
      const ResourceHandle& handle = HandleFromInput(context, i);
      Var* variable;
      TF_RETURN_IF_ERROR(LookupResource(context, handle, &variable));
      variables.push_back(VariableInfo(i, handle.name(), variable));
    } else {
      TF_RETURN_IF_ERROR(validate_shape(i, arg_list[i]));
    }
  }

  // Lock the variables, and validate their shapes. We hold the variable locks
  // for the duration of the TPU execution so we can donate the variable buffers
  // to the computation. If we copied the variable's Tensor instead, its
  // reference count would be greater than one due to the reference the Var
  // object holds, and we would never be able to reuse variable buffers.
  // TODO(phawkins): add a 'reuse_buffers' attribute to TPUExecute that allows
  // the user to elect to copy the buffers and permit concurrent access instead.
  TF_RETURN_IF_ERROR(LockVariables(absl::MakeSpan(variables)));
  for (int i = 0; i < variables.size(); ++i) {
    TF_RETURN_IF_ERROR(
        validate_shape(variables[i].index(), *variables[i].var()->tensor()));
  }

  se::DeviceMemoryAllocator* const allocator = node_context->memory_allocator();
  xla::TransferManager* const transfer_manager =
      node_context->transfer_manager();
  const int device_ordinal = node_context->device_ordinal();

  auto input_buffers = absl::make_unique<InputBuffers>(
      transfer_manager->HostShapeToDeviceShape(input_host_shape));

  // Allocates a buffer for the root tuple.
  const int64 root_size =
      transfer_manager->GetByteSizeRequirement(input_buffers->buffers.shape());
  TF_ASSIGN_OR_RETURN(*input_buffers->buffers.mutable_element({}),
                      allocator->Allocate(device_ordinal, root_size));

  // Helper function that sets the input buffers for 'arg_index' to 'buffers'.
  // If 'donate_buffers' is true, donates ownership of the buffers in 'buffers'
  // to the computation and overwrites the entries in 'buffers' with nulls.
  auto set_input_buffers_helper = [&](int arg_index, bool donate_buffers,
                                      xla::ShapedBuffer* buffers) {
    buffers->buffers().ForEachMutableElement([&](const xla::ShapeIndex& index,
                                                 se::DeviceMemoryBase* buffer) {
      xla::ShapeIndex in_index = {arg_index};
      for (int64 j : index) {
        in_index.push_back(j);
      }
      auto* in_buffer = input_buffers->buffers.mutable_element(in_index);
      if (donate_buffers) {
        *in_buffer = se::OwningDeviceMemory(*buffer, device_ordinal, allocator);
        *buffer = se::DeviceMemoryBase();
      } else {
        *in_buffer = *buffer;
      }
    });
  };

  // Assigns the buffers of 'tensor' as computation input 'i'. Allocates fresh
  // buffers for zero-element tensors where required.
  auto assign_input = [&](int i, const Tensor& tensor,
                          bool may_reuse) -> xla::Status {
    XlaTensor* xla_tensor = XlaTensor::FromTensor(&tensor);

    // Size 0 tensors have no backing XlaTensor, but may still need to have
    // tuple buffers allocated.
    if (xla_tensor == nullptr) {
      CHECK_EQ(tensor.NumElements(), 0);
      const xla::Shape& host_shape =
          xla::ShapeUtil::GetSubshape(input_host_shape, {i});
      TF_ASSIGN_OR_RETURN(xla::ScopedShapedBuffer buffers,
                          transfer_manager->AllocateScopedShapedBuffer(
                              host_shape, allocator, device_ordinal));
      set_input_buffers_helper(/*arg_index=*/i, /*donate_buffers=*/true,
                               &buffers);
    } else {
      bool can_reuse_buffers = tensor.RefCountIsOne() && may_reuse;
      set_input_buffers_helper(/*arg_index=*/i,
                               /*donate_buffers=*/can_reuse_buffers,
                               &xla_tensor->shaped_buffer());
      xla_tensor->WaitForDefinitionEventOnStream(stream);
    }
    return Status::OK();
  };

  for (int i = 0; i < arg_list.size(); ++i) {
    auto it = variable_updates.input_to_output.find(i);
    if (it == variable_updates.input_to_output.end()) {
      TF_RETURN_IF_ERROR(assign_input(i, arg_list[i], /*may_reuse=*/true));
      continue;
    }
    // input i is a variable
    bool updated = it->second >= 0;
    if (arg_list[i].dtype() != DT_RESOURCE) {
      TF_RETURN_IF_ERROR(assign_input(i, arg_list[i], updated));
    } else {
      int vi = variable_index[i];
      TF_RETURN_IF_ERROR(
          assign_input(i, *variables[vi].var()->tensor(), updated));
    }
  }

  input_buffers->variables = std::move(variables);
  input_buffers->variable_index = std::move(variable_index);

  return std::move(input_buffers);
}

struct OutputBuffers {
  OutputBuffers(xla::ScopedShapedBuffer b, se::DeviceMemoryAllocator* allocator)
      : owned_buffers(b.on_device_shape(), true),
        buffers(b.release()),
        memory_allocator(allocator) {}

  ~OutputBuffers() {
    buffers.buffers().ForEachElement([&](const xla::ShapeIndex& index,
                                         const se::DeviceMemoryBase& buffer) {
      if (owned_buffers.element(index) && !buffer.is_null()) {
        Status status =
            memory_allocator->Deallocate(buffers.device_ordinal(), buffer);
        if (!status.ok()) {
          LOG(ERROR) << "Error deallocating buffer " << status;
        }
      }
    });
  }

  // Which of the buffers do we own?
  xla::ShapeTree<bool> owned_buffers;

  xla::ShapedBuffer buffers;

  se::DeviceMemoryAllocator* const memory_allocator;
};

// Allocates Tensors for the outputs of the computation. Ownership of most
// output buffers is passed to the output Tensors. Returns an OutputBuffer that
// owns the root buffer that should be passed to the XLA computation, as well as
// any output buffers that do not have corresponding output tensors. The latter
// may happen for zero-element tensors of type int64 or complex64 which still
// require a tuple buffer but do not have a corresponding XlaTensor.
xla::StatusOr<std::unique_ptr<OutputBuffers>> AllocateOutputTensors(
    OpKernelContext* context, xla::ScopedShapedBuffer scoped_buffers,
    absl::Span<const TensorShapeProto* const> output_tensor_shape_protos,
    const VariableUpdateMap& variable_updates, TpuNodeContext* node_context,
    se::Stream* stream, int device_ordinal, InputBuffers* input_buffers,
    const std::shared_ptr<se::Event>& definition_event) {
  VLOG(4) << "Output buffers: " << scoped_buffers.ToString();

  profiler::TraceMe trace_me("AllocateOutputTensors", /*level=*/2);
  // Shapes of the outputs, in TensorShape form.
  const int64 sub_elements =
      xla::ShapeUtil::TupleElementCount(scoped_buffers.on_host_shape());
  if (sub_elements != output_tensor_shape_protos.size()) {
    return errors::InvalidArgument(
        "Mismatched numbers of output shapes: ", sub_elements, " vs. ",
        output_tensor_shape_protos.size());
  }

  xla::TransferManager* const transfer_manager =
      node_context->transfer_manager();

  std::vector<TensorShape> output_tensor_shapes;
  output_tensor_shapes.reserve(sub_elements);
  for (int64 i = 0; i < sub_elements; ++i) {
    TF_RETURN_IF_ERROR(
        TensorShape::IsValidShape(*output_tensor_shape_protos[i]));
    TensorShape shape(*output_tensor_shape_protos[i]);
    const xla::Shape& xla_shape =
        xla::ShapeUtil::GetSubshape(scoped_buffers.on_host_shape(), {i});
    if (!xla_shape.IsArray() ||
        xla::ShapeUtil::ElementsIn(xla_shape) != shape.num_elements()) {
      return errors::InvalidArgument(
          "Mismatched number of elements in output shape: ",
          xla::ShapeUtil::HumanString(xla_shape), " vs ", shape.DebugString());
    }
    output_tensor_shapes.push_back(shape);
  }

  // Builds a shaped buffer for the outputs.
  TF_RET_CHECK(scoped_buffers.on_host_shape().IsTuple());
  TF_RET_CHECK(!xla::ShapeUtil::IsNestedTuple(scoped_buffers.on_host_shape()));

  se::DeviceMemoryAllocator* const allocator = node_context->memory_allocator();

  auto output_buffers =
      absl::make_unique<OutputBuffers>(std::move(scoped_buffers), allocator);

  xla::Shape output_host_shape = output_buffers->buffers.on_host_shape();
  xla::Shape output_device_shape = output_buffers->buffers.on_device_shape();

  if (!output_host_shape.is_static()) {
    TF_RETURN_IF_ERROR(transfer_manager->ReadDynamicShapes(
        stream, &output_buffers->buffers, &output_host_shape,
        &output_device_shape));
    for (int64 i = 0; i < sub_elements; ++i) {
      const xla::Shape& subshape =
          xla::ShapeUtil::GetSubshape(output_host_shape, {i});
      TensorShape shape;
      TF_RETURN_IF_ERROR(XLAShapeToTensorShape(subshape, &shape));
      output_tensor_shapes[i] = shape;
    }
  }

  // Transfers ownership of the buffers that back XLA computation output 'i'
  // to 'output_tensor'.
  auto transfer_buffers = [&](int i, Tensor* output_tensor) {
    const xla::Shape& host_shape =
        xla::ShapeUtil::GetTupleElementShape(output_host_shape, i);
    const xla::Shape& device_shape =
        xla::ShapeUtil::GetTupleElementShape(output_device_shape, i);

    // Transfers ownership of the output buffers to the output Tensor, if
    // the tensor is backed by an XlaTensor. Tensors of size 0 have no
    // backing XlaTensor, so we let 'output_buffers' retain ownership of any
    // buffers in that case.
    if (output_tensor->NumElements() > 0) {
      xla::ScopedShapedBuffer shaped_buffer(host_shape, device_shape, allocator,
                                            device_ordinal);
      shaped_buffer.buffers().ForEachMutableElement(
          [&](const xla::ShapeIndex& index, se::DeviceMemoryBase* buffer) {
            xla::ShapeIndex out_index = {i};
            for (int64 j : index) {
              out_index.push_back(j);
            }
            *buffer = output_buffers->buffers.buffers().element(out_index);
            *output_buffers->owned_buffers.mutable_element(out_index) = false;
          });

      XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor);
      xla_tensor->set_shaped_buffer(std::move(shaped_buffer));
      xla_tensor->ResetDefinitionEvent(definition_event, stream);
    }
  };

  const int num_updated_variables = variable_updates.output_to_input.size();
  TF_RET_CHECK(num_updated_variables <= output_tensor_shapes.size())
      << num_updated_variables << " <= " << output_tensor_shapes.size();

  OpInputList arg_list;
  TF_RETURN_IF_ERROR(context->input_list("args", &arg_list));

  // The TPU program outputs the updated variables including DT_RESOURCE and
  // non-DT_RESOURCE. The TPUExecuteOp needs to output all non-DT_RESOURCE
  // variables (updated or not).
  //
  //                        updated          not_updated
  //                 |------------------|------------------|
  //   DT_RESOURCE   | allocate persist |    do nothing    |
  //                 |------------------|------------------|
  //                 |     allocate     | forward Op input |
  // not DT_RESOURCE |      output      |   to Op output   |  Op output
  //                 |------------------|------------------|
  //                     program output

  // Allocates a fresh tensor for each updated variable. While the variable
  // inputs need not come in any particular order, the variable values are
  // always added last by the XlaCompiler class, in the same order as the
  // corresponding input variables.
  int op_output_index = 0;
  int compiled_update_index = 0;
  auto process_non_updated_variable = [&](int input_index) {
    const int variable_index = input_buffers->variable_index.at(input_index);
    // If a DT_RESOURCE input is not updated, nothing needs to be done
    // because there is no corresponding output. If a non-resource input
    // is not updated, forward the input to the output.
    if (variable_index < 0) {
      context->set_output(op_output_index, arg_list[input_index]);
      ++op_output_index;
    }
  };
  for (int i = 0; i < output_tensor_shapes.size(); ++i) {
    auto it = variable_updates.output_to_input.find(i);
    if (it == variable_updates.output_to_input.end()) {
      // Not a variable update.
      // Allocates a fresh tensor for each output of the operator. We always
      // allocate a new host-side tensor, but the on-device buffers that back
      // that tensor may be aliases of input buffers.
      Tensor* output_tensor;
      TF_RETURN_IF_ERROR(context->allocate_output(
          op_output_index, output_tensor_shapes[i], &output_tensor));
      transfer_buffers(i, output_tensor);
      ++op_output_index;
      continue;
    }
    const int input_index = it->second.first;
    // We must process the compiled updates in order, which includes the
    // non-updated variables, i.e., those without an XLA output.
    const bool from_compilation = it->second.second;
    while (from_compilation &&
           variable_updates
                   .input_in_compiled_update_order[compiled_update_index] !=
               input_index) {
      process_non_updated_variable(
          variable_updates
              .input_in_compiled_update_order[compiled_update_index]);
      ++compiled_update_index;
    }
    ++compiled_update_index;
    const int variable_index = input_buffers->variable_index.at(input_index);
    PersistentTensor unused;
    Tensor* output_tensor;
    if (variable_index >= 0) {
      // This output corresponds to a DT_RESOURCE input to the TPUExecute
      // operator. Update the corresponding variable.
      VariableInfo& var = input_buffers->variables[variable_index];
      // TODO(b/35625933): the correct thing to do would be to transfer
      // ownership of the PersistentTensor into the Var object. However, Var
      // contains a Tensor so we can't.
      TF_RETURN_IF_ERROR(context->allocate_persistent(
          var.var()->tensor()->dtype(), output_tensor_shapes[i], &unused,
          &output_tensor));
      *var.var()->tensor() = *output_tensor;
    } else {
      // This output corresponds to a non-resource input to the TPUExecute
      // operator. This case occurs for the distributed TPU rewrite which
      // adds variable values as inputs and outputs rather than passing the
      // variables themselves; reading and writing the variable is handled
      // outside the op.
      // TODO(phawkins): remove this case when placement of variables on TPU
      // devices is well supported and we no longer need to place "remote"
      // variables on CPU devices.
      TF_RETURN_IF_ERROR(context->allocate_output(
          op_output_index, output_tensor_shapes[i], &output_tensor));
      ++op_output_index;
    }
    transfer_buffers(i, output_tensor);
  }

  // Process any remaining non-updated variables.
  for (; compiled_update_index <
         variable_updates.input_in_compiled_update_order.size();
       ++compiled_update_index) {
    process_non_updated_variable(
        variable_updates.input_in_compiled_update_order[compiled_update_index]);
  }
  return std::move(output_buffers);
}

}  // namespace

// TPUExecuteOp

TPUExecuteOp::TPUExecuteOp(OpKernelConstruction* context)
    : AsyncOpKernel(context, /* is_deferred = */ true) {}

AsyncOpKernel* TPUExecuteOp::AsAsync() {
  // If TPU launches are asynchronous, we can perform the launch without
  // blocking the calling thread, and so the executor may treat this kernel as
  // a regular (synchronous) OpKernel.
  return nullptr;
}

void TPUExecuteOp::Compute(OpKernelContext* context) {
  Status s = DoWork(context);
  // NOTE: We can't use `OP_REQUIRES_OK()` here because that macro includes
  // a dynamic check that we are not in an AsyncOpKernel.
  if (TF_PREDICT_FALSE(!s.ok())) {
    context->SetStatus(s);
  }
}

void TPUExecuteOp::ComputeAsync(OpKernelContext* context, DoneCallback done) {
  // If TPU launches are asynchronous, then perform the launch on this
  // thread to avoid a thread hop, which has an observable latency cost.
  OP_REQUIRES_OK_ASYNC(context, DoWork(context), done);
  done();
}

Status TPUExecuteOp::DoWork(OpKernelContext* context) {
  VLOG(1) << "Cloud TPU: TPUExecuteOp::Compute";

  const XlaDevice::Metadata* metadata;
  TF_RETURN_IF_ERROR(XlaDevice::GetMetadata(context, &metadata));
  const int device_ordinal = metadata->device_ordinal();

  // We are guaranteed that the object underlying TpuNodeContext won't be
  // deleted out from under us, while node_context is alive.
  TF_ASSIGN_OR_RETURN(std::unique_ptr<TpuNodeContext> node_context,
                      TpuNodeContext::Create(device_ordinal));

  profiler::TraceMe trace_me(
      [&, device_ordinal] {
        return absl::StrCat("TpuExecuteOp#device_ordinal=", device_ordinal,
                            ",id=", context->step_id(),
                            ",iter_num=", context->frame_iter().iter_id, "#");
      },
      /*level=*/2);
  profiler::TraceMe trace_me_init("TPUExecuteOp::Init", /*level=*/2);

  string rendezvous_key_base;
  std::unique_ptr<CompilationCacheEntryRef> entry;
  TF_RETURN_IF_ERROR(
      GetComputationCacheEntry(context, &rendezvous_key_base, &entry));

  // Shapes of the inputs and outputs, in xla::Shape form.
  const TPUExecutableInfoProto* proto = entry->get().get_executable_info();

  xla::TransferManager* const transfer_manager =
      node_context->transfer_manager();
  CHECK(context->op_device_context());
  se::Stream* stream = context->op_device_context()->stream();

  TF_RET_CHECK(proto->input_shapes_size() == 1);

  xla::Shape host_shape(proto->input_shapes(0));

  TF_ASSIGN_OR_RETURN(
      auto variable_update_map,
      BuildVariableUpdateMap(proto->variable_indices(),
                             fused_device_var_reads_in_computation_inputs_,
                             fused_device_var_updates_in_computation_outputs_,
                             proto->output_tensor_shapes().size()));
  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<InputBuffers> input_buffers,
      BuildComputationInputs(context, host_shape, variable_update_map,
                             node_context.get(), stream));

  // Ideally this should be the host-to-device stream from XlaDeviceContext.
  // The particular anti-dependency this is avoiding (why we need a separate
  // transfer stream) is between the executable writing tuple tables and
  // TPUExecute()'s deregister_stream; if they come from the same stream pool
  // antidependencies will occur. XlaBackend has a different pool of streams
  // to the stream->GetOrCreateSubStream() that TPUExecute() uses, so these
  // will never refer to the same stream.
  //
  // TODO(jmolloy): Add the necessary plumbing to obtain the proper
  // host-to-device stream here.
  TF_ASSIGN_OR_RETURN(auto transfer_stream_ptr,
                      node_context->BorrowStream(device_ordinal));

  se::DeviceMemoryAllocator* const allocator = node_context->memory_allocator();
  auto shaped_buffer =
      input_buffers->ToShapedBuffer(host_shape, allocator, device_ordinal);
  if (transfer_manager->CanShapedBufferBeAccessedNow(stream->parent(),
                                                     shaped_buffer)) {
    TF_RETURN_IF_ERROR(transfer_manager->WriteRootTupleIndexTable(
        transfer_stream_ptr.get(), shaped_buffer));
    stream->ThenWaitFor(transfer_stream_ptr.get());
  } else {
    TF_RETURN_IF_ERROR(
        transfer_manager->WriteRootTupleIndexTable(stream, shaped_buffer));
  }
  VLOG(4) << "Input buffers: " << shaped_buffer.ToString();

  // Snapshot the inputs, if a snapshot was requested.
  std::shared_ptr<xla::HloSnapshot> hlo_snapshot;
  if (proto->has_session_module()) {
    hlo_snapshot = std::make_shared<xla::HloSnapshot>(proto->session_module());
    auto literal =
        std::make_shared<xla::Literal>(shaped_buffer.on_host_shape());
    transfer_manager->TransferLiteralFromDevice(
        stream, shaped_buffer, literal.get(),
        [hlo_snapshot, literal](Status status) {
          if (!status.ok()) {
            LOG(ERROR) << "TransferLiteralFromDevice for HLO snapshot inputs "
                          "failed: "
                       << status;
            return;
          }
          *hlo_snapshot->add_arguments() = literal->ToProto();
        });
  }

  auto definition_event = std::make_shared<se::Event>(stream->parent());
  TF_RET_CHECK(definition_event->Init())
      << "TPU definition event initialization failed";

  trace_me_init.Stop();

  const uint32 rng_seed = GetXLARandomSeed();

  std::unique_ptr<xla::DeviceAssignment> device_assignment;
  if (proto->has_device_assignment()) {
    TF_ASSIGN_OR_RETURN(device_assignment, xla::DeviceAssignment::Deserialize(
                                               proto->device_assignment()));
  }

  VLOG(4) << "Input buffers after alias resolution: "
          << shaped_buffer.ToString();

  std::vector<xla::ExecutionInput> input;
  input.emplace_back(
      xla::ExecutionInput(std::move(input_buffers->buffers), host_shape));

  // The buffers to be freed are in the `output` and will be automatically
  // freed when it goes out of the scope. In async mode, this means the buffers
  // will be freed before anyone calls "BlockHostUntilDone", which indicates
  // that some of the (input) buffers will be freed while the program is running
  // and looks scary. However, this turns out to be not a problem since although
  // we free a memory and reassign it to other users while a program is running,
  // all subsequent writes to the program that could possibly clobber the memory
  // will depend on the program to finish.
  const TPUHostTransferInfoProto* host_transfer_info =
      entry->get().get_host_transfer_info();
  const xla::HloProto* hlo_metadata = entry->get().get_hlo_metadata();
  TF_ASSIGN_OR_RETURN(
      xla::ExecutionOutput output,
      TPUExecute(*proto, *host_transfer_info, *hlo_metadata, std::move(input),
                 rendezvous_key_base, rng_seed, node_context.get(),
                 device_assignment.get(), context->cancellation_manager(),
                 context, stream, transfer_stream_ptr.get(),
                 entry->get().get_tpu_program()));
  stream->ThenRecordEvent(definition_event.get());

  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<OutputBuffers> output_buffers,
      AllocateOutputTensors(context, output.ConsumeResult(),
                            proto->output_tensor_shapes(), variable_update_map,
                            node_context.get(), stream, device_ordinal,
                            input_buffers.get(), definition_event));

  // Transfer the outputs and save the snapshot to disk.
  if (hlo_snapshot) {
    auto literal =
        std::make_shared<xla::Literal>(output_buffers->buffers.on_host_shape());
    transfer_manager->TransferLiteralFromDevice(
        stream, output_buffers->buffers, literal.get(),
        [hlo_snapshot, literal](Status status) {
          if (status.ok()) {
            *hlo_snapshot->mutable_result() = literal->ToProto();
          } else {
            LOG(ERROR) << "TransferLiteralFromDevice for HLO snapshot "
                          "outputs failed: "
                       << status;
          }
          DumpHloSnapshotIfEnabled(*hlo_snapshot,
                                   xla::GetDebugOptionsFromFlags());
        });
  }
  return Status::OK();
}

TPUExecuteOp::~TPUExecuteOp() = default;

TPUExecuteAndUpdateVariablesOp::TPUExecuteAndUpdateVariablesOp(
    OpKernelConstruction* context)
    : TPUExecuteOp(context) {
  OP_REQUIRES_OK(context, context->GetAttr(
                              "device_var_reads_indices",
                              &fused_device_var_reads_in_computation_inputs_));
  OP_REQUIRES_OK(
      context,
      context->GetAttr("device_var_updates_indices",
                       &fused_device_var_updates_in_computation_outputs_));
}

REGISTER_KERNEL_BUILDER(
    Name("TPUExecute").Device(DEVICE_TPU_NODE).HostMemory("key"), TPUExecuteOp);

REGISTER_KERNEL_BUILDER(Name("TPUExecuteAndUpdateVariables")
                            .Device(DEVICE_TPU_NODE)
                            .HostMemory("key"),
                        TPUExecuteAndUpdateVariablesOp);

}  // namespace tensorflow
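
Note on the variable-update bookkeeping: the interaction between BuildVariableUpdateMap and AllocateOutputTensors above is the densest part of the kernel. The following is a minimal, self-contained C++ sketch of the same index bookkeeping written with plain STL containers; it is illustrative only and is not part of this commit. The names CompiledUpdate, VariableUpdateMapSketch, and BuildVariableUpdateMapSketch are invented for the example.

// Illustrative sketch (not part of the commit): the index bookkeeping
// performed by BuildVariableUpdateMap, written with plain STL types.
#include <cassert>
#include <cstdint>
#include <unordered_map>
#include <utility>
#include <vector>

struct CompiledUpdate {
  int index;     // Input index of the variable in the computation.
  bool updated;  // Whether the program produces an updated value for it.
};

struct VariableUpdateMapSketch {
  // Input index -> output index, or -1 if the variable is not updated.
  std::unordered_map<int, int> input_to_output;
  // Output index -> (input index, whether the update came from compilation).
  std::unordered_map<int, std::pair<int, bool>> output_to_input;
  // Input indices that come from the compilation, in compiled order.
  std::vector<int> input_in_compiled_update_order;
};

VariableUpdateMapSketch BuildVariableUpdateMapSketch(
    const std::vector<CompiledUpdate>& compiled_updates,
    const std::vector<int>& fused_reads, const std::vector<int>& fused_updates,
    int64_t computation_output_count) {
  VariableUpdateMapSketch map;
  // Updated variables occupy the trailing outputs of the XLA computation,
  // after all non-variable outputs.
  int64_t num_updated = 0;
  for (const CompiledUpdate& u : compiled_updates) {
    if (u.updated) ++num_updated;
  }
  int64_t next_output = computation_output_count - num_updated;
  for (const CompiledUpdate& u : compiled_updates) {
    map.input_in_compiled_update_order.push_back(u.index);
    const int output = u.updated ? static_cast<int>(next_output++) : -1;
    map.input_to_output.emplace(u.index, output);
    if (output >= 0) {
      map.output_to_input.emplace(output, std::make_pair(u.index, true));
    }
  }
  // Fused device-variable reads/updates come from the op attributes; a -1
  // update index again means "read only, no update".
  assert(fused_reads.size() == fused_updates.size());
  for (size_t i = 0; i < fused_reads.size(); ++i) {
    map.input_to_output.emplace(fused_reads[i], fused_updates[i]);
    if (fused_updates[i] >= 0) {
      map.output_to_input.emplace(fused_updates[i],
                                  std::make_pair(fused_reads[i], false));
    }
  }
  return map;
}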
tensorflow/core/tpu/kernels/tpu_execute_op.h (new file, 67 lines)
@@ -0,0 +1,67 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_EXECUTE_OP_H_
#define TENSORFLOW_CORE_TPU_KERNELS_TPU_EXECUTE_OP_H_

#include <memory>
#include <vector>

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"

namespace tensorflow {

// Op that executes a precompiled TPU computation.
class TPUExecuteOp : public AsyncOpKernel {
 public:
  explicit TPUExecuteOp(OpKernelConstruction* context);
  ~TPUExecuteOp() override;

  AsyncOpKernel* AsAsync() override;

  void Compute(OpKernelContext* context) override;
  void ComputeAsync(OpKernelContext* context, DoneCallback done) override;

 protected:
  // Used by TPUExecuteAndUpdateVariablesOp to set the fused variable reads and
  // updates indices in the XLA computation. The two vectors must have the same
  // size, and a pair of read index and write index represents a variable's
  // input to the program and its updated value from the program. If the
  // variable is not updated, use -1 as the output index.
  std::vector<int> fused_device_var_reads_in_computation_inputs_;
  std::vector<int> fused_device_var_updates_in_computation_outputs_;

 private:
  Status DoWork(OpKernelContext* context);

  TF_DISALLOW_COPY_AND_ASSIGN(TPUExecuteOp);
};

// A variant of TPUExecuteOp that contains fused device variable reads and
// updates.
class TPUExecuteAndUpdateVariablesOp : public TPUExecuteOp {
 public:
  explicit TPUExecuteAndUpdateVariablesOp(OpKernelConstruction* context);
  ~TPUExecuteAndUpdateVariablesOp() override = default;

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(TPUExecuteAndUpdateVariablesOp);
};

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_TPU_KERNELS_TPU_EXECUTE_OP_H_
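
For reference, the output-handling policy documented by the comment table in AllocateOutputTensors (DT_RESOURCE vs. non-resource inputs, updated vs. not updated, with -1 marking "not updated" as in the header above) can be summarized in a few lines of C++. This is an illustrative sketch only, not part of the commit; the enum and function names are invented for the example.

// Illustrative sketch (not part of the commit): the per-output policy
// described by the comment table in AllocateOutputTensors.
enum class OutputAction {
  kAllocatePersistentAndUpdateVar,  // DT_RESOURCE input, updated by the program.
  kDoNothing,                       // DT_RESOURCE input, not updated.
  kAllocateOpOutput,                // Non-resource input, updated by the program.
  kForwardInputToOpOutput,          // Non-resource input, not updated.
};

inline OutputAction DecideOutputAction(bool is_resource_input,
                                       bool updated_by_program) {
  if (is_resource_input) {
    return updated_by_program ? OutputAction::kAllocatePersistentAndUpdateVar
                              : OutputAction::kDoNothing;
  }
  return updated_by_program ? OutputAction::kAllocateOpOutput
                            : OutputAction::kForwardInputToOpOutput;
}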