Allows showing informative error messages, indicating where the resource var
was created.

PiperOrigin-RevId: 353729467
Change-Id: I90c19bfecc2d2df19523b28a0caeea2abe331fae
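With this change, an attempt to access a resource variable from the wrong
device reports where the variable was defined. The message looks roughly like
the example below (illustrative only; the name, file, line, and device strings
come from the handle's definition stack trace at runtime):

    Trying to access resource my_var (defined @ train.py:12) located in device
    /job:localhost/replica:0/task:0/device:CPU:0 from device
    /job:localhost/replica:0/task:0/device:GPU:0
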
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/jit/xla_launch_util.h"

#include <memory>

#include "absl/algorithm/container.h"
#include "absl/memory/memory.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/gpu_device_context.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/util/stream_executor_util.h"

namespace tensorflow {
namespace {
using xla::ScopedShapedBuffer;
using xla::ShapedBuffer;

}  // anonymous namespace

VariableInfo::VariableInfo(int index, absl::string_view name, Var* var)
    : index_(index), name_(name), var_(var) {}
VariableInfo::VariableInfo(VariableInfo&& other)
    : index_(other.index_), var_(other.var_), lock_held_(other.lock_held_) {
  other.index_ = -1;
  other.var_ = nullptr;
}

VariableInfo& VariableInfo::operator=(VariableInfo&& other) {
  index_ = other.index_;
  var_ = other.var_;
  lock_held_ = other.lock_held_;

  other.index_ = -1;
  other.var_ = nullptr;

  return *this;
}

VariableInfo::~VariableInfo() {
  // Release the variable's lock if we hold it. Ensures that the lock is
  // released even on error. It does not matter in what order we release the
  // locks.
  if (var()) {
    if (lock_held()) {
      var()->mu()->unlock();
    }

    // Unref the variable so it can be released by ResourceManager.
    var()->Unref();
  }
}

Status GetVariableInfosFromInputs(ResourceMgr* rm, DeviceBase* dev,
                                  absl::Span<const Tensor* const> inputs,
                                  absl::Span<const int> variable_indices,
                                  std::vector<VariableInfo>* result) {
  result->clear();
  result->reserve(variable_indices.size());
  for (int var_idx : variable_indices) {
    Var* variable = nullptr;
    ResourceHandle handle = inputs[var_idx]->flat<ResourceHandle>()(0);
    if (handle.device() != dev->attributes().name()) {
      std::string definition_location = [&]() -> std::string {
        if (handle.definition_stack_trace()) {
          std::vector<StackFrame> stack_frames =
              handle.definition_stack_trace()->ToStackFrames(
                  {}, IsInternalFrameForFilename,
                  /*reverse_traversal=*/true,
                  /*limit=*/1);
          if (!stack_frames.empty()) {
            const StackFrame& last_frame = stack_frames[0];
            return absl::StrCat(" (defined @ ", last_frame.file_name, ":",
                                last_frame.line_number, ")");
          }
        }
        return "";
      }();
      return errors::InvalidArgument("Trying to access resource ",
                                     handle.name(), definition_location,
                                     " located in device ", handle.device(),
                                     " from device ", dev->attributes().name());
    }
    TF_RETURN_IF_ERROR(rm->LookupOrCreate<Var>(
        handle.container(), handle.name(), &variable, [](Var** ptr) {
          // This var is uninitialized for now.
          *ptr = new Var(DT_INVALID);
          return Status::OK();
        }));
    result->emplace_back(var_idx, handle.name(), variable);
  }
  return Status::OK();
}
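
// Note on the LookupOrCreate above: variables that do not yet exist on this
// device are created as placeholder Vars with dtype DT_INVALID. They stay
// uninitialized until the cluster (or a later op) assigns to them;
// BuildXlaCompilerArguments below treats such variables as uninitialized
// arguments rather than reading their (meaningless) contents.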

std::vector<const Tensor*> InputsFromContext(OpKernelContext* ctx) {
  std::vector<const Tensor*> inputs;
  inputs.reserve(ctx->num_inputs());
  for (int input_idx = 0; input_idx < ctx->num_inputs(); input_idx++) {
    inputs.push_back(&ctx->input(input_idx));
  }
  return inputs;
}

Status LockVariables(absl::Span<VariableInfo> variables) {
  std::vector<int> lock_order(variables.size());
  std::iota(lock_order.begin(), lock_order.end(), 0);

  // VariableInfoComparator orders all empty VariableInfo instances as
  // equivalent so it looks like we may want to stable sort these to maintain a
  // deterministic order between the empty VariableInfo instances. However
  // since we're sorting by pointer value the sort is pretty non-deterministic
  // anyway so we don't bother using std::stable_sort for now.
  absl::c_sort(lock_order, [&](int a, int b) {
    if (variables[a].var() && variables[b].var()) {
      return variables[a].var()->mu() < variables[b].var()->mu();
    }

    // Move all the empty VariableInfo instances to the end.
    return variables[a].var() != nullptr;
  });

  mutex* prev = nullptr;
  for (int i : lock_order) {
    Var* variable = variables[i].var();
    if (variable == nullptr) {
      // All empty VariableInfo instances are at the end of the order
      // so we're done.
      break;
    }
    mutex* mu = variable->mu();
    if (prev == mu) {
      // It is an error to pass the same variable handle twice to the same XLA
      // cluster because we would not handle variable updates correctly. Any
      // locks we have already acquired will be released when the VariableInfo
      // objects are destroyed.
      // TODO(b/128495870) Add support for passing aliased resource variables.
      return errors::Unimplemented("Duplicate variable passed to XLA cluster");
    }
    VLOG(4) << "Acquiring lock for variable "
            << reinterpret_cast<void*>(variable);
    mu->lock();
    variables[i].set_lock_held();
    prev = mu;
  }
  VLOG(4) << "Finished acquiring variable locks.";
  return Status::OK();
}
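
// Locking in a single, globally consistent order (the address of each
// variable's mutex) means two kernels that lock overlapping sets of variables
// can never each hold a lock the other is waiting for, which is what rules
// out deadlock here.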

Status SnapshotResourceVariables(OpKernelContext* ctx,
                                 absl::Span<const int> variable_indices,
                                 absl::Span<VariableInfo const> variable_infos,
                                 ResourceVarsSnapshot* result) {
  for (int i = 0, end = variable_indices.size(); i < end; i++) {
    Var* var = variable_infos[i].var();
    (*result)[variable_indices[i]] =
        var ? absl::make_optional(*var->tensor()) : absl::nullopt;
  }
  return Status::OK();
}
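
// Illustrative call pattern for the helpers above, roughly as a launch kernel
// would drive them (the local variable names here are made up; the signatures
// are the ones defined in this file):
//
//   std::vector<const Tensor*> inputs = InputsFromContext(ctx);
//   std::vector<VariableInfo> variable_infos;
//   TF_RETURN_IF_ERROR(GetVariableInfosFromInputs(
//       ctx->resource_manager(), ctx->device(), inputs, variable_indices,
//       &variable_infos));
//   TF_RETURN_IF_ERROR(LockVariables(absl::MakeSpan(variable_infos)));
//   ResourceVarsSnapshot snapshot;
//   TF_RETURN_IF_ERROR(SnapshotResourceVariables(ctx, variable_indices,
//                                                variable_infos, &snapshot));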

XlaComputationLaunchContext::XlaComputationLaunchContext(
    xla::LocalClient* client, se::DeviceMemoryAllocator* xla_allocator,
    int device_ordinal, bool allocate_xla_tensors, bool use_multiple_streams)
    : client_(client),
      xla_allocator_(xla_allocator),
      allocate_xla_tensors_(allocate_xla_tensors),
      use_multiple_streams_(use_multiple_streams),
      device_ordinal_(device_ordinal) {
  if (use_multiple_streams_) {
    CHECK(allocate_xla_tensors_) << "To use multiple streams correctly we must "
                                    "be allocating XLA tensors!";
  }
}

// Fills in `execution_input` with `buffer` for `index`.
static void PopulateExecutionInputBuffer(xla::ExecutionInput& execution_input,
                                         xla::ShapeIndex index,
                                         se::DeviceMemoryBase& buffer,
                                         bool donate_buffer, int device_ordinal,
                                         se::DeviceMemoryAllocator* allocator) {
  xla::MaybeOwningDeviceMemory* in_buffer =
      execution_input.MutableBuffer(index);
  if (donate_buffer) {
    *in_buffer = se::OwningDeviceMemory(buffer, device_ordinal, allocator);
    buffer = se::DeviceMemoryBase();
  } else {
    *in_buffer = buffer;
  }
}
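
// When `donate_buffer` is true the memory is handed to XLA as an
// se::OwningDeviceMemory and the caller's handle is cleared, so the runtime
// may reuse (alias) it for an output; otherwise the input is passed as a
// non-owning reference and remains owned by the TensorFlow tensor.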

xla::StatusOr<std::vector<xla::ExecutionInput>>
XlaComputationLaunchContext::PopulateInputs(
    OpKernelContext* ctx,
    const XlaCompiler::CompilationResult* compilation_result,
    const std::map<int, const Tensor*>& resource_vars,
    int missing_ctx_input_prefix,
    const xla::HloInputOutputAliasConfig& input_output_alias) {
  std::vector<xla::ExecutionInput> arguments;
  arguments.reserve(compilation_result->xla_input_shapes.size());

  xla::TransferManager* transfer_manager =
      client_->backend().transfer_manager();
  for (int i = 0, end = compilation_result->xla_input_shapes.size(); i < end;
       ++i) {
    int arg_num = compilation_result->input_mapping[i];
    CHECK_GE(arg_num, missing_ctx_input_prefix);
    const xla::Shape& shape = compilation_result->xla_input_shapes[i];
    const xla::Shape& device_shape =
        transfer_manager->HostShapeToDeviceShape(shape);

    bool is_resource_variable = resource_vars.count(arg_num);
    bool is_updated_resource_variable =
        is_resource_variable &&
        absl::c_any_of(compilation_result->resource_updates,
                       [&](const XlaCompiler::ResourceUpdate& update) {
                         return update.input_index == i && update.modified;
                       });

    const Tensor* t = is_resource_variable
                          ? resource_vars.at(arg_num)
                          : &(ctx->input(arg_num - missing_ctx_input_prefix));
    CHECK(t);
    bool donate_buffer =
        t->RefCountIsOne() && is_updated_resource_variable &&
        input_output_alias.ParameterHasAlias(i, xla::ShapeIndex{});
    VLOG(3) << "Processing input: " << i
            << "; is_resource_variable=" << is_resource_variable
            << "; is_updated_resource_variable=" << is_updated_resource_variable
            << "; donate_buffer=" << donate_buffer;

    if (use_multiple_streams_) {
      CHECK(ctx->op_device_context() && ctx->op_device_context()->stream())
          << "Must have a stream available when using XLA tensors!";
      XlaTensor* xla_tensor = XlaTensor::FromTensor(t);
      CHECK(xla_tensor);
      xla_tensor->WaitForDefinitionEventOnStream(
          ctx->op_device_context()->stream());
    }

    arguments.emplace_back(device_shape, shape);
    xla::ExecutionInput& execution_input = arguments.back();
    if (xla::Shape::Equal().MinorToMajorOnlyInLayout()(shape, device_shape)) {
      se::DeviceMemoryBase dmem = XlaTensor::DeviceMemoryFromTensor(*t);
      PopulateExecutionInputBuffer(execution_input, xla::ShapeIndex{}, dmem,
                                   donate_buffer, device_ordinal_,
                                   xla_allocator_);
    } else {
      XlaTensor* xla_tensor = XlaTensor::FromTensor(t);
      CHECK(xla_tensor && xla_tensor->has_shaped_buffer());
      xla_tensor->shaped_buffer().buffers().ForEachMutableElement(
          [&](const xla::ShapeIndex& index, se::DeviceMemoryBase* buffer) {
            PopulateExecutionInputBuffer(execution_input, index, *buffer,
                                         donate_buffer, device_ordinal_,
                                         xla_allocator_);
          });
    }
  }
  return std::move(arguments);
}
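
// The layout check above picks between two representations of the same input:
// if the host and device layouts agree, the tensor's flat device buffer is
// passed directly; otherwise the input must already be backed by an XlaTensor
// whose ShapedBuffer carries the device layout, and its buffers are forwarded
// element by element.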

// Construct the tensor for the given type and buffer.
static Tensor MakeTensor(DataType dtype, const TensorShape& shape,
                         se::DeviceMemoryBase buffer, Allocator* allocator) {
  size_t expected_size = shape.num_elements() * DataTypeSize(dtype);
  auto* tensor_buffer = new XlaTensorBuffer(buffer.opaque(), expected_size,
                                            buffer.size(), allocator);
  Tensor t(dtype, shape, tensor_buffer);
  tensor_buffer->Unref();
  return t;
}
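
// The Tensor constructor takes its own reference on `tensor_buffer`, so the
// Unref above simply drops the local reference and leaves the returned Tensor
// as the owner of the wrapped device memory.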

// Get aliased tensor, or make a new one for the corresponding output operation.
static Tensor GetOrCreateTensorForOutput(
    int output_num, OpKernelContext* ctx, int missing_ctx_input_prefix,
    const xla::HloInputOutputAliasConfig& input_output_alias,
    absl::Span<const int> input_mapping,
    const std::map<int, const Tensor*>& resource_vars_snapshots,
    DataType output_dtype, const TensorShape& output_shape,
    se::DeviceMemoryBase output_buffer, Allocator* output_allocator) {
  xla::ShapeIndex output_index = input_output_alias.shape().IsTuple()
                                     ? xla::ShapeIndex({output_num})
                                     : xla::ShapeIndex({});

  CHECK(input_output_alias.shape().IsTuple() || output_num == 0);
  if (absl::optional<xla::HloInputOutputAliasConfig::Alias> alias =
          input_output_alias.GetAliasedParameter(output_index)) {
    VLOG(3) << "Found alias: " << alias->ToString();
    int tf_param =
        input_mapping[alias->parameter_number] - missing_ctx_input_prefix;
    const Tensor input_tensor =
        ctx->input(tf_param).dtype() != DT_RESOURCE
            ? ctx->input(tf_param)
            : *resource_vars_snapshots.at(missing_ctx_input_prefix + tf_param);
    if (output_buffer.opaque() == input_tensor.data()) {
      return input_tensor;
    }
  }
  return MakeTensor(output_dtype, output_shape, output_buffer,
                    output_allocator);
}
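
// If XLA aliased this output to one of its parameters and the output buffer
// is in fact the same memory as the corresponding input tensor, the input
// tensor is forwarded instead of wrapping the same buffer a second time;
// otherwise a fresh Tensor is created around `output_buffer`.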

static void PopulateXlaTensor(Tensor* output_tensor,
                              xla::ScopedShapedBuffer* output, int output_num,
                              se::Stream* stream, bool use_multiple_streams,
                              std::shared_ptr<se::Event> definition_event) {
  XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor);
  CHECK(xla_tensor);
  xla_tensor->set_shaped_buffer(output->TakeSubTree({output_num}));
  if (use_multiple_streams) {
    xla_tensor->ResetDefinitionEvent(definition_event, stream);
  }
}

// Sets output `output_num` for `ctx` provided it is known at a compile time.
static Status SetOutputForConstant(
    OpKernelContext* ctx, se::Stream* stream,
    const XlaCompiler::CompilationResult* compilation_result, int output_num) {
  CHECK(compilation_result->outputs[output_num].is_constant);
  const Tensor& const_tensor =
      compilation_result->outputs[output_num].constant_value;
  Tensor* output_tensor;
  if (stream && const_tensor.TotalBytes() > 0) {
    // Copy host -> device. (Empty tensors don't have backing buffers.)
    // Manually allocate memory using an XlaTensorBuffer so we can allocate
    // as much memory as the device requires (as given by
    // GetByteSizeRequirement). This avoids XlaTransferManager having to
    // reallocate the device buffer later.
    VLOG(1) << "Constant output tensor on device";

    TF_RETURN_IF_ERROR(
        ctx->allocate_output(output_num, const_tensor.shape(), &output_tensor));
    Device* device = dynamic_cast<Device*>(ctx->device());
    if (device == nullptr) {
      return errors::Internal("DeviceBase was not a Device.");
    }
    ctx->op_device_context()->CopyCPUTensorToDevice(
        &const_tensor, device, output_tensor,
        [&](Status status) { TF_CHECK_OK(status); });

    if (device->device_type() == DEVICE_GPU) {
      // The GPUDeviceContext enqueues the host->device transfer in a
      // separate stream from the main compute stream. We must ensure the
      // compute stream is synchronized with the host->device transfer
      // stream now otherwise we will create a race condition.
      auto* gpu_device_context =
          static_cast<GPUDeviceContext*>(ctx->op_device_context());
      gpu_device_context->stream()->ThenWaitFor(
          gpu_device_context->host_to_device_stream());
    }
  } else {
    // No copy required.
    ctx->set_output(output_num, const_tensor);
    output_tensor = ctx->mutable_output(output_num);
  }
  return Status::OK();
}

static xla::StatusOr<Var*> GetOrCreateResourceVar(
    OpKernelContext* ctx, const ResourceHandle& handle,
    const XlaCompiler::ResourceUpdate& write) {
  Var* variable = nullptr;
  TF_RETURN_IF_ERROR(
      LookupOrCreateResource<Var>(ctx, handle, &variable, [&write](Var** ptr) {
        *ptr = new Var(write.type);
        return Status::OK();
      }));
  return variable;
}

xla::StatusOr<std::vector<VariableInfo>> GatherVariableInfo(
    OpKernelContext* ctx,
    const XlaCompiler::CompilationResult& compilation_result,
    int missing_ctx_input_prefix) {
  std::vector<VariableInfo> out;
  out.reserve(compilation_result.resource_updates.size());
  for (int i = 0; i < compilation_result.resource_updates.size(); ++i) {
    const XlaCompiler::ResourceUpdate& write =
        compilation_result.resource_updates[i];
    int actual_input_index = write.input_index - missing_ctx_input_prefix;
    if (actual_input_index < 0 || actual_input_index >= ctx->num_inputs()) {
      return errors::Internal("Invalid input index for variable write.");
    }

    const ResourceHandle handle = HandleFromInput(ctx, actual_input_index);
    TF_ASSIGN_OR_RETURN(Var * variable,
                        GetOrCreateResourceVar(ctx, handle, write));
    out.emplace_back(actual_input_index, handle.name(), variable);
  }
  return std::move(out);
}
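
// The VariableInfo instances returned here hold references to the variables
// but not their locks; callers are expected to lock them (e.g. via
// LockVariables above) before PopulateOutputs writes through var->tensor().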

Status XlaComputationLaunchContext::PopulateOutputs(
    OpKernelContext* ctx,
    const XlaCompiler::CompilationResult* compilation_result,
    ScopedShapedBuffer output, int missing_ctx_input_prefix,
    absl::Span<VariableInfo> variable_infos,
    const xla::HloInputOutputAliasConfig& input_output_alias,
    const std::map<int, const Tensor*>& resource_vars) {
  se::Stream* stream =
      ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
  Allocator* allocator = ctx->device()->GetAllocator({});

  // Computation output should always be a tuple.
  VLOG(2) << "Result tuple shape: " << output.on_host_shape().DebugString();
  VLOG(2) << "Result tuple shape (on device): "
          << output.on_device_shape().DebugString();
  CHECK_EQ(ctx->num_outputs(), compilation_result->outputs.size());

  // If the on-host-shape isn't a tuple, create a new single-element tuple
  // buffer with a nullptr root index table. This allows the code below to
  // treat output as a tuple unconditionally.
  if (!output.on_host_shape().IsTuple()) {
    ShapedBuffer nontuple_buffer = output.release();
    ShapedBuffer buffer(
        xla::ShapeUtil::MakeTupleShape({nontuple_buffer.on_host_shape()}),
        xla::ShapeUtil::MakeTupleShape({nontuple_buffer.on_device_shape()}),
        output.device_ordinal());
    buffer.buffers().CopySubtreeFrom(nontuple_buffer.buffers(),
                                     /*source_base_index=*/{},
                                     /*target_base_index=*/{0});
    output = ScopedShapedBuffer(std::move(buffer), output.memory_allocator());
  }

  std::shared_ptr<se::Event> definition_event;
  if (use_multiple_streams_) {
    definition_event = std::make_shared<se::Event>(stream->parent());
    if (!definition_event->Init()) {
      return errors::Internal("Failed to initialize tensor definition event.");
    }
    stream->ThenRecordEvent(definition_event.get());
  }

  std::vector<TensorShape> output_tensor_shapes;
  output_tensor_shapes.reserve(ctx->num_outputs());
  if (output.on_host_shape().is_dynamic()) {
    TF_ASSIGN_OR_RETURN(
        auto transfer_manager,
        xla::TransferManager::GetForPlatform(stream->parent()->platform()));

    xla::Shape output_device_shape = output.on_device_shape();
    TF_RETURN_IF_ERROR(transfer_manager->ReadDynamicShapes(
        stream, &output, &output_device_shape));

    output.set_shapes(output_device_shape, output_device_shape);
    for (int i = 0; i < ctx->num_outputs(); ++i) {
      const xla::Shape& subshape =
          xla::ShapeUtil::GetSubshape(output_device_shape, {i});
      TensorShape shape;
      TF_RETURN_IF_ERROR(XLAShapeToTensorShape(subshape, &shape));
      output_tensor_shapes.push_back(shape);
    }
  } else {
    for (int i = 0; i < ctx->num_outputs(); ++i) {
      output_tensor_shapes.push_back(compilation_result->outputs[i].shape);
    }
  }

  // Copy XLA results to the OpOutputList.
  int output_num = 0;
  for (int i = 0, end = ctx->num_outputs(); i < end; ++i) {
    const TensorShape& shape = output_tensor_shapes[i];
    const DataType& type = compilation_result->outputs[i].type;
    VLOG(2) << "Populating output for retval " << i << " shape "
            << shape.DebugString() << " type " << DataTypeString(type);
    if (type == DT_VARIANT) {
      return errors::Unimplemented(
          "Support for TensorList crossing the XLA/TF boundary "
          "is not implemented");
    }

    if (compilation_result->outputs[i].is_constant) {
      TF_RETURN_IF_ERROR(
          SetOutputForConstant(ctx, stream, compilation_result, i));
    } else if (type == DT_RESOURCE) {
      int input_index =
          compilation_result->outputs[i].input_index - missing_ctx_input_prefix;
      TF_RET_CHECK(input_index >= 0 && input_index < ctx->num_inputs())
          << "Invalid input for outputs " << i << ": " << input_index;
      ctx->set_output(i, ctx->input(input_index));
    } else {
      if (allocate_xla_tensors_) {
        Tensor* output_tensor;
        TF_RETURN_IF_ERROR(ctx->allocate_output(i, shape, &output_tensor));
        if (output_tensor->TotalBytes() > 0) {
          PopulateXlaTensor(output_tensor, &output, output_num, stream,
                            use_multiple_streams_, definition_event);
        }
      } else {
        se::DeviceMemoryBase buffer = output.buffer({output_num});
        Tensor output_tensor = GetOrCreateTensorForOutput(
            output_num, ctx, missing_ctx_input_prefix, input_output_alias,
            compilation_result->input_mapping, resource_vars,
            ctx->expected_output_dtype(i), shape, buffer, allocator);
        ctx->set_output(i, output_tensor);
      }
      output.set_buffer(se::OwningDeviceMemory(), {output_num});
      ++output_num;
    }
  }

  // input_index -> index into variable_infos.
  absl::flat_hash_map<int, int> variable_info_lookup;
  for (int i = 0; i < variable_infos.size(); i++) {
    variable_info_lookup.emplace(variable_infos[i].index(), i);
  }

  // Apply variable updates, if any.
  for (int i = 0, end = compilation_result->resource_updates.size(); i < end;
       ++i) {
    const XlaCompiler::ResourceUpdate& write =
        compilation_result->resource_updates[i];
    int actual_input_index = write.input_index - missing_ctx_input_prefix;
    CHECK_GE(actual_input_index, 0);
    CHECK_LT(actual_input_index, ctx->num_inputs());
    Var* var = variable_infos[variable_info_lookup[actual_input_index]].var();
    CHECK(var);

    VLOG(2) << "Updating variable #" << i
            << " at input index: " << actual_input_index << " with shape "
            << write.shape.DebugString() << "; variable tensor has shape: "
            << var->tensor()->shape().DebugString();

    if (var->is_initialized && var->tensor()->dtype() != write.type) {
      return errors::Internal("Mismatched type in variable write");
    }

    Tensor output_tensor;
    if (allocate_xla_tensors_) {
      TF_RETURN_IF_ERROR(
          ctx->allocate_temp(write.type, write.shape, &output_tensor));
      if (write.shape.num_elements() > 0) {
        PopulateXlaTensor(&output_tensor, &output, output_num, stream,
                          use_multiple_streams_, definition_event);
      }
    } else {
      se::DeviceMemoryBase buffer = output.buffer({output_num});
      output_tensor = GetOrCreateTensorForOutput(
          output_num, ctx, missing_ctx_input_prefix, input_output_alias,
          compilation_result->input_mapping, resource_vars, write.type,
          write.shape, buffer, allocator);
    }
    output.set_buffer(se::OwningDeviceMemory(), {output_num});
    var->is_initialized |= write.modified;
    *var->tensor() = output_tensor;
    ++output_num;
  }
  return Status::OK();
}
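
// Note that `output_num` above indexes elements of the XLA result tuple, not
// TF retvals: compile-time constant and DT_RESOURCE outputs are produced
// without consuming a tuple element, so `output_num` advances independently of
// the retval index `i`, and the trailing tuple elements are the variable
// updates applied in the final loop.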

xla::StatusOr<std::vector<XlaCompiler::Argument>>
XlaComputationLaunchContext::BuildXlaCompilerArguments(
    absl::Span<int const> must_be_constant_idxs,
    absl::Span<const Tensor* const> inputs,
    absl::Span<VariableInfo const> variable_args, Device* device) {
  CHECK(absl::c_is_sorted(must_be_constant_idxs));
  std::vector<XlaCompiler::Argument> out;
  out.resize(inputs.size());

  // TODO(cheshire): Avoid duplication with framework/op_kernel.h
  DeviceContext* device_context = nullptr;
  TF_RETURN_IF_ERROR(device->TryGetDeviceContext(&device_context));
  bool using_default_context = false;
  auto cleanup = xla::MakeCleanup([&] {
    if (device_context != nullptr && !using_default_context) {
      device_context->Unref();
    }
  });
  if (device_context == nullptr) {
    using_default_context = true;
    auto* dev_info = device->tensorflow_gpu_device_info();
    if (dev_info) device_context = dev_info->default_context;
  }

  absl::flat_hash_map<int, const VariableInfo*> variable_info_lookup;
  for (const VariableInfo& info : variable_args) {
    CHECK(!info.var() || info.lock_held())
        << "Need to hold the lock on resource variables "
           "before calling BuildXlaCompilerArguments";
    variable_info_lookup.emplace(info.index(), &info);
  }

  for (int64 input_num = 0; input_num < inputs.size(); ++input_num) {
    const Tensor* input = inputs[input_num];

    XlaCompiler::Argument& arg = out[input_num];
    if (variable_info_lookup.count(input_num)) {
      // Handles resource variables.
      TF_RET_CHECK(input->dtype() == DT_RESOURCE);
      const VariableInfo& variable = *variable_info_lookup[input_num];
      arg.name = std::string(variable.name());
      arg.kind = XlaCompiler::Argument::kResource;
      arg.resource_kind = XlaResource::kVariable;
      if (variable.var() && variable.var()->is_initialized) {
        const Tensor* value = variable.var()->tensor();
        arg.type = value->dtype();
        arg.shape = value->shape();
        arg.initialized = true;
      } else {
        // The values of uninitialized variables are not passed as inputs, since
        // they are meaningless. However, it is legal to assign to a resource
        // variable for the first time inside the XLA computation, so we do
        // permit uninitialized variables.
        arg.initialized = false;
        arg.type = DT_INVALID;
        arg.shape = TensorShape();
      }

      if (absl::c_binary_search(must_be_constant_idxs, input_num)) {
        TF_RET_CHECK(variable.var() && variable.var()->is_initialized);
        const Tensor* value = variable.var()->tensor();
        Tensor value_on_host(value->dtype(), value->shape());
        if (!device_context) {
          value_on_host = *value;
        } else {
          TF_RETURN_IF_ERROR(device_context->CopyDeviceTensorToCPUSync(
              value, "", device, &value_on_host));
        }
        arg.kind = XlaCompiler::Argument::kConstantResource;
        arg.constant_value = value_on_host;
      }
    } else if (absl::c_binary_search(must_be_constant_idxs, input_num)) {
      arg.kind = XlaCompiler::Argument::kConstant;
      arg.type = input->dtype();
      arg.shape = input->shape();
      arg.constant_value = *input;
    } else {
      // Normal inputs.
      TF_RET_CHECK(input->dtype() != DT_RESOURCE);
      if (input->NumElements() > 0) {
        arg.kind = XlaCompiler::Argument::kParameter;
      } else {
        arg.kind = XlaCompiler::Argument::kConstant;
        arg.constant_value = *input;
      }
      arg.type = input->dtype();
      arg.shape = input->shape();
    }
  }

  return out;
}
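
// The returned arguments correspond one-to-one with `inputs`. Must-be-constant
// arguments (including constant resource variables) are materialized on the
// host, copying device-resident values with CopyDeviceTensorToCPUSync where a
// DeviceContext is available, so that XLA can treat them as compile-time
// constants.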

}  // namespace tensorflow