[TF/XLA Bridge] [NFC] Reduce the amount of boilerplate for getting-or-creating output tensor

PiperOrigin-RevId: 272083303
This commit is contained in:
George Karpenkov 2019-09-30 15:48:39 -07:00 committed by TensorFlower Gardener
parent 964ee5e31d
commit 86902a8ada

View File

@ -269,7 +269,7 @@ static bool MustAliasOutput(
static const Tensor* FindAliasedTensorForOutput(
int output_num, OpKernelContext* ctx, int missing_ctx_input_prefix,
const xla::HloInputOutputAliasConfig& input_output_alias,
const std::vector<int>& input_mapping,
absl::Span<const int> input_mapping,
const std::map<int, OptionalTensor>& resource_var_snapshots) {
if (MustAliasOutput(input_output_alias, output_num)) {
int xla_param = input_output_alias.GetAliasedParameter({output_num})
@ -301,6 +301,23 @@ static Tensor MakeTensor(DataType dtype, const TensorShape& shape,
return t;
}
// Returns the output tensor for `output_num`: either the input tensor that
// XLA's input/output aliasing says the output must alias, or, when no alias
// exists, a freshly wrapped tensor backed by `output_buffer`.
static Tensor GetOrCreateTensorForOutput(
    int output_num, OpKernelContext* ctx, int missing_ctx_input_prefix,
    const xla::HloInputOutputAliasConfig& input_output_alias,
    absl::Span<const int> input_mapping,
    const std::map<int, OptionalTensor>& resource_var_snapshots,
    DataType output_dtype, const TensorShape& output_shape,
    se::DeviceMemoryBase output_buffer, Allocator* output_allocator) {
  // First try to reuse an aliased input tensor; FindAliasedTensorForOutput
  // returns nullptr when the output is not required to alias any input.
  const Tensor* aliased = FindAliasedTensorForOutput(
      output_num, ctx, missing_ctx_input_prefix, input_output_alias,
      input_mapping, resource_var_snapshots);
  return aliased != nullptr ? *aliased
                            : MakeTensor(output_dtype, output_shape,
                                         output_buffer, output_allocator);
}
Status XlaComputationLaunchContext::PopulateOutputs(
OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel,
ScopedShapedBuffer output, int missing_ctx_input_prefix,
@ -423,21 +440,12 @@ Status XlaComputationLaunchContext::PopulateOutputs(
}
} else {
se::DeviceMemoryBase buffer = output.buffer({output_num});
absl::optional<Tensor> output_tensor_storage;
const Tensor* output_tensor;
if (const Tensor* aliased_tensor = FindAliasedTensorForOutput(
output_num, ctx, missing_ctx_input_prefix, input_output_alias,
kernel->input_mapping, resource_var_snapshots)) {
output_tensor = aliased_tensor;
} else {
output_tensor_storage = MakeTensor(ctx->expected_output_dtype(i),
shape, buffer, allocator);
output_tensor = &output_tensor_storage.value();
}
Tensor output_tensor = GetOrCreateTensorForOutput(
output_num, ctx, missing_ctx_input_prefix, input_output_alias,
kernel->input_mapping, resource_var_snapshots,
ctx->expected_output_dtype(i), shape, buffer, allocator);
output.set_buffer(se::OwningDeviceMemory(), {output_num});
ctx->set_output(i, *output_tensor);
ctx->set_output(i, output_tensor);
}
++output_num;
}
@ -502,21 +510,11 @@ Status XlaComputationLaunchContext::PopulateOutputs(
} else {
se::DeviceMemoryBase buffer = output.buffer({output_num});
output.set_buffer(se::OwningDeviceMemory(), {output_num});
absl::optional<Tensor> output_tensor_storage;
const Tensor* output_tensor;
if (const Tensor* aliased_tensor = FindAliasedTensorForOutput(
output_num, ctx, missing_ctx_input_prefix, input_output_alias,
kernel->input_mapping, resource_var_snapshots)) {
output_tensor = aliased_tensor;
} else {
output_tensor_storage =
MakeTensor(write.type, write.shape, buffer, allocator);
output_tensor = &output_tensor_storage.value();
}
*variable_infos[i].var()->tensor() = *output_tensor;
Tensor output_tensor = GetOrCreateTensorForOutput(
output_num, ctx, missing_ctx_input_prefix, input_output_alias,
kernel->input_mapping, resource_var_snapshots, write.type,
write.shape, buffer, allocator);
*variable_infos[i].var()->tensor() = output_tensor;
}
++output_num;
}