[TF/XLA Bridge] [NFC] Reduce the amount of boilerplate for getting-or-creating output tensor
PiperOrigin-RevId: 272083303
This commit is contained in:
parent
964ee5e31d
commit
86902a8ada
@ -269,7 +269,7 @@ static bool MustAliasOutput(
|
|||||||
static const Tensor* FindAliasedTensorForOutput(
|
static const Tensor* FindAliasedTensorForOutput(
|
||||||
int output_num, OpKernelContext* ctx, int missing_ctx_input_prefix,
|
int output_num, OpKernelContext* ctx, int missing_ctx_input_prefix,
|
||||||
const xla::HloInputOutputAliasConfig& input_output_alias,
|
const xla::HloInputOutputAliasConfig& input_output_alias,
|
||||||
const std::vector<int>& input_mapping,
|
absl::Span<const int> input_mapping,
|
||||||
const std::map<int, OptionalTensor>& resource_var_snapshots) {
|
const std::map<int, OptionalTensor>& resource_var_snapshots) {
|
||||||
if (MustAliasOutput(input_output_alias, output_num)) {
|
if (MustAliasOutput(input_output_alias, output_num)) {
|
||||||
int xla_param = input_output_alias.GetAliasedParameter({output_num})
|
int xla_param = input_output_alias.GetAliasedParameter({output_num})
|
||||||
@ -301,6 +301,23 @@ static Tensor MakeTensor(DataType dtype, const TensorShape& shape,
|
|||||||
return t;
|
return t;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Get aliased tensor, or make a new one for the corresponding output operation.
//
// First asks FindAliasedTensorForOutput() whether output `output_num` has an
// aliased tensor, forwarding `ctx`, `missing_ctx_input_prefix`,
// `input_output_alias`, `input_mapping` and `resource_var_snapshots` unchanged.
// If that lookup yields a non-null pointer, the pointed-to Tensor is returned
// by value. Otherwise a new Tensor is built with MakeTensor() from
// `output_dtype`, `output_shape`, `output_buffer` and `output_allocator`.
//
// The last four parameters are only used on the MakeTensor() path; they are
// ignored when an aliased tensor is found.
|
||||||
|
static Tensor GetOrCreateTensorForOutput(
|
||||||
|
int output_num, OpKernelContext* ctx, int missing_ctx_input_prefix,
|
||||||
|
const xla::HloInputOutputAliasConfig& input_output_alias,
|
||||||
|
absl::Span<const int> input_mapping,
|
||||||
|
const std::map<int, OptionalTensor>& resource_var_snapshots,
|
||||||
|
DataType output_dtype, const TensorShape& output_shape,
|
||||||
|
se::DeviceMemoryBase output_buffer, Allocator* output_allocator) {
|
||||||
|
if (const Tensor* aliased_tensor = FindAliasedTensorForOutput(
|
||||||
|
output_num, ctx, missing_ctx_input_prefix, input_output_alias,
|
||||||
|
input_mapping, resource_var_snapshots)) {
|
||||||
|
// Returned by value so callers get a Tensor object of their own rather than
// a pointer whose lifetime they would have to reason about.
return *aliased_tensor;
|
||||||
|
}
|
||||||
|
return MakeTensor(output_dtype, output_shape, output_buffer,
|
||||||
|
output_allocator);
|
||||||
|
}
|
||||||
|
|
||||||
Status XlaComputationLaunchContext::PopulateOutputs(
|
Status XlaComputationLaunchContext::PopulateOutputs(
|
||||||
OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel,
|
OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel,
|
||||||
ScopedShapedBuffer output, int missing_ctx_input_prefix,
|
ScopedShapedBuffer output, int missing_ctx_input_prefix,
|
||||||
@ -423,21 +440,12 @@ Status XlaComputationLaunchContext::PopulateOutputs(
|
|||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
se::DeviceMemoryBase buffer = output.buffer({output_num});
|
se::DeviceMemoryBase buffer = output.buffer({output_num});
|
||||||
absl::optional<Tensor> output_tensor_storage;
|
Tensor output_tensor = GetOrCreateTensorForOutput(
|
||||||
const Tensor* output_tensor;
|
|
||||||
|
|
||||||
if (const Tensor* aliased_tensor = FindAliasedTensorForOutput(
|
|
||||||
output_num, ctx, missing_ctx_input_prefix, input_output_alias,
|
output_num, ctx, missing_ctx_input_prefix, input_output_alias,
|
||||||
kernel->input_mapping, resource_var_snapshots)) {
|
kernel->input_mapping, resource_var_snapshots,
|
||||||
output_tensor = aliased_tensor;
|
ctx->expected_output_dtype(i), shape, buffer, allocator);
|
||||||
} else {
|
|
||||||
output_tensor_storage = MakeTensor(ctx->expected_output_dtype(i),
|
|
||||||
shape, buffer, allocator);
|
|
||||||
output_tensor = &output_tensor_storage.value();
|
|
||||||
}
|
|
||||||
|
|
||||||
output.set_buffer(se::OwningDeviceMemory(), {output_num});
|
output.set_buffer(se::OwningDeviceMemory(), {output_num});
|
||||||
ctx->set_output(i, *output_tensor);
|
ctx->set_output(i, output_tensor);
|
||||||
}
|
}
|
||||||
++output_num;
|
++output_num;
|
||||||
}
|
}
|
||||||
@ -502,21 +510,11 @@ Status XlaComputationLaunchContext::PopulateOutputs(
|
|||||||
} else {
|
} else {
|
||||||
se::DeviceMemoryBase buffer = output.buffer({output_num});
|
se::DeviceMemoryBase buffer = output.buffer({output_num});
|
||||||
output.set_buffer(se::OwningDeviceMemory(), {output_num});
|
output.set_buffer(se::OwningDeviceMemory(), {output_num});
|
||||||
|
Tensor output_tensor = GetOrCreateTensorForOutput(
|
||||||
absl::optional<Tensor> output_tensor_storage;
|
|
||||||
const Tensor* output_tensor;
|
|
||||||
|
|
||||||
if (const Tensor* aliased_tensor = FindAliasedTensorForOutput(
|
|
||||||
output_num, ctx, missing_ctx_input_prefix, input_output_alias,
|
output_num, ctx, missing_ctx_input_prefix, input_output_alias,
|
||||||
kernel->input_mapping, resource_var_snapshots)) {
|
kernel->input_mapping, resource_var_snapshots, write.type,
|
||||||
output_tensor = aliased_tensor;
|
write.shape, buffer, allocator);
|
||||||
} else {
|
*variable_infos[i].var()->tensor() = output_tensor;
|
||||||
output_tensor_storage =
|
|
||||||
MakeTensor(write.type, write.shape, buffer, allocator);
|
|
||||||
output_tensor = &output_tensor_storage.value();
|
|
||||||
}
|
|
||||||
|
|
||||||
*variable_infos[i].var()->tensor() = *output_tensor;
|
|
||||||
}
|
}
|
||||||
++output_num;
|
++output_num;
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user