[TF/XLA] [NFC] Simplify XlaComputationLaunchContext::PopulateInputs
Try to make the logic more transparent PiperOrigin-RevId: 316713452 Change-Id: I41e8d691e6cab9c7a6b5bc40d50d660fcbe05906
This commit is contained in:
parent
c42314ef70
commit
08e445e37f
@ -198,50 +198,41 @@ void XlaComputationLaunchContext::PopulateInputs(
|
||||
OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel,
|
||||
const std::map<int, OptionalTensor>& variables,
|
||||
int missing_ctx_input_prefix) {
|
||||
se::Stream* stream =
|
||||
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
|
||||
// Build ShapedBuffers that point directly to the Tensor buffers.
|
||||
arg_ptrs_ = std::vector<ShapedBuffer*>(kernel->xla_input_shapes.size());
|
||||
|
||||
// Pass remaining parameters.
|
||||
const Tensor* t;
|
||||
xla::TransferManager* transfer_manager =
|
||||
client_->backend().transfer_manager();
|
||||
for (int i = 0; i < kernel->xla_input_shapes.size(); ++i) {
|
||||
int arg_num = kernel->input_mapping[i];
|
||||
DCHECK_GE(arg_num, missing_ctx_input_prefix);
|
||||
CHECK_GE(arg_num, missing_ctx_input_prefix);
|
||||
const xla::Shape& shape = kernel->xla_input_shapes[i];
|
||||
if (variables.count(arg_num)) {
|
||||
t = &(variables.at(arg_num).value);
|
||||
CHECK(t);
|
||||
} else {
|
||||
t = &(ctx->input(arg_num - missing_ctx_input_prefix));
|
||||
}
|
||||
const Tensor* t = variables.count(arg_num)
|
||||
? &(variables.at(arg_num).value)
|
||||
: &(ctx->input(arg_num - missing_ctx_input_prefix));
|
||||
CHECK(t);
|
||||
|
||||
if (use_multiple_streams_) {
|
||||
CHECK(stream) << "Must have a stream available when using XLA tensors!";
|
||||
CHECK(ctx->op_device_context() && ctx->op_device_context()->stream())
|
||||
<< "Must have a stream available when using XLA tensors!";
|
||||
XlaTensor* xla_tensor = XlaTensor::FromTensor(t);
|
||||
CHECK(xla_tensor);
|
||||
xla_tensor->WaitForDefinitionEventOnStream(stream);
|
||||
xla_tensor->WaitForDefinitionEventOnStream(
|
||||
ctx->op_device_context()->stream());
|
||||
}
|
||||
|
||||
const xla::Shape on_device_shape =
|
||||
client_->backend().transfer_manager()->HostShapeToDeviceShape(shape);
|
||||
if (on_device_shape.IsTuple()) {
|
||||
const XlaTensor* xla_tensor = XlaTensor::FromTensor(t);
|
||||
CHECK(xla_tensor && xla_tensor->has_shaped_buffer());
|
||||
arg_ptrs_[i] = const_cast<ShapedBuffer*>(&xla_tensor->shaped_buffer());
|
||||
} else {
|
||||
CHECK(xla::Shape::Equal().MinorToMajorOnlyInLayout()(shape,
|
||||
on_device_shape))
|
||||
<< "On-device shape "
|
||||
<< xla::ShapeUtil::HumanStringWithLayout(on_device_shape)
|
||||
<< " not the same as on-host shape "
|
||||
<< xla::ShapeUtil::HumanStringWithLayout(shape);
|
||||
if (xla::Shape::Equal().MinorToMajorOnlyInLayout()(
|
||||
shape, transfer_manager->HostShapeToDeviceShape(shape))) {
|
||||
se::DeviceMemoryBase dmem = XlaTensor::DeviceMemoryFromTensor(*t);
|
||||
arg_buffers_.emplace_back(
|
||||
/*on_host_shape=*/shape, /*on_device_shape=*/shape,
|
||||
client_->platform(), client_->default_device_ordinal());
|
||||
arg_buffers_.back().set_buffer(dmem, /*index=*/{});
|
||||
arg_ptrs_[i] = &arg_buffers_.back();
|
||||
} else {
|
||||
const XlaTensor* xla_tensor = XlaTensor::FromTensor(t);
|
||||
CHECK(xla_tensor && xla_tensor->has_shaped_buffer());
|
||||
arg_ptrs_[i] = const_cast<ShapedBuffer*>(&xla_tensor->shaped_buffer());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user