[TF/XLA Bridge] [NFC] Reduce the amount of boilerplate for getting-or-creating output tensor
PiperOrigin-RevId: 272083303
This commit is contained in:
parent
964ee5e31d
commit
86902a8ada
@ -269,7 +269,7 @@ static bool MustAliasOutput(
|
||||
static const Tensor* FindAliasedTensorForOutput(
|
||||
int output_num, OpKernelContext* ctx, int missing_ctx_input_prefix,
|
||||
const xla::HloInputOutputAliasConfig& input_output_alias,
|
||||
const std::vector<int>& input_mapping,
|
||||
absl::Span<const int> input_mapping,
|
||||
const std::map<int, OptionalTensor>& resource_var_snapshots) {
|
||||
if (MustAliasOutput(input_output_alias, output_num)) {
|
||||
int xla_param = input_output_alias.GetAliasedParameter({output_num})
|
||||
@ -301,6 +301,23 @@ static Tensor MakeTensor(DataType dtype, const TensorShape& shape,
|
||||
return t;
|
||||
}
|
||||
|
||||
// Get aliased tensor, or make a new one for the corresponding output operation.
|
||||
static Tensor GetOrCreateTensorForOutput(
|
||||
int output_num, OpKernelContext* ctx, int missing_ctx_input_prefix,
|
||||
const xla::HloInputOutputAliasConfig& input_output_alias,
|
||||
absl::Span<const int> input_mapping,
|
||||
const std::map<int, OptionalTensor>& resource_var_snapshots,
|
||||
DataType output_dtype, const TensorShape& output_shape,
|
||||
se::DeviceMemoryBase output_buffer, Allocator* output_allocator) {
|
||||
if (const Tensor* aliased_tensor = FindAliasedTensorForOutput(
|
||||
output_num, ctx, missing_ctx_input_prefix, input_output_alias,
|
||||
input_mapping, resource_var_snapshots)) {
|
||||
return *aliased_tensor;
|
||||
}
|
||||
return MakeTensor(output_dtype, output_shape, output_buffer,
|
||||
output_allocator);
|
||||
}
|
||||
|
||||
Status XlaComputationLaunchContext::PopulateOutputs(
|
||||
OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel,
|
||||
ScopedShapedBuffer output, int missing_ctx_input_prefix,
|
||||
@ -423,21 +440,12 @@ Status XlaComputationLaunchContext::PopulateOutputs(
|
||||
}
|
||||
} else {
|
||||
se::DeviceMemoryBase buffer = output.buffer({output_num});
|
||||
absl::optional<Tensor> output_tensor_storage;
|
||||
const Tensor* output_tensor;
|
||||
|
||||
if (const Tensor* aliased_tensor = FindAliasedTensorForOutput(
|
||||
output_num, ctx, missing_ctx_input_prefix, input_output_alias,
|
||||
kernel->input_mapping, resource_var_snapshots)) {
|
||||
output_tensor = aliased_tensor;
|
||||
} else {
|
||||
output_tensor_storage = MakeTensor(ctx->expected_output_dtype(i),
|
||||
shape, buffer, allocator);
|
||||
output_tensor = &output_tensor_storage.value();
|
||||
}
|
||||
|
||||
Tensor output_tensor = GetOrCreateTensorForOutput(
|
||||
output_num, ctx, missing_ctx_input_prefix, input_output_alias,
|
||||
kernel->input_mapping, resource_var_snapshots,
|
||||
ctx->expected_output_dtype(i), shape, buffer, allocator);
|
||||
output.set_buffer(se::OwningDeviceMemory(), {output_num});
|
||||
ctx->set_output(i, *output_tensor);
|
||||
ctx->set_output(i, output_tensor);
|
||||
}
|
||||
++output_num;
|
||||
}
|
||||
@ -502,21 +510,11 @@ Status XlaComputationLaunchContext::PopulateOutputs(
|
||||
} else {
|
||||
se::DeviceMemoryBase buffer = output.buffer({output_num});
|
||||
output.set_buffer(se::OwningDeviceMemory(), {output_num});
|
||||
|
||||
absl::optional<Tensor> output_tensor_storage;
|
||||
const Tensor* output_tensor;
|
||||
|
||||
if (const Tensor* aliased_tensor = FindAliasedTensorForOutput(
|
||||
output_num, ctx, missing_ctx_input_prefix, input_output_alias,
|
||||
kernel->input_mapping, resource_var_snapshots)) {
|
||||
output_tensor = aliased_tensor;
|
||||
} else {
|
||||
output_tensor_storage =
|
||||
MakeTensor(write.type, write.shape, buffer, allocator);
|
||||
output_tensor = &output_tensor_storage.value();
|
||||
}
|
||||
|
||||
*variable_infos[i].var()->tensor() = *output_tensor;
|
||||
Tensor output_tensor = GetOrCreateTensorForOutput(
|
||||
output_num, ctx, missing_ctx_input_prefix, input_output_alias,
|
||||
kernel->input_mapping, resource_var_snapshots, write.type,
|
||||
write.shape, buffer, allocator);
|
||||
*variable_infos[i].var()->tensor() = output_tensor;
|
||||
}
|
||||
++output_num;
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user