diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc
index 209220938ed..ec5a372875c 100644
--- a/tensorflow/compiler/jit/xla_launch_util.cc
+++ b/tensorflow/compiler/jit/xla_launch_util.cc
@@ -198,50 +198,41 @@ void XlaComputationLaunchContext::PopulateInputs(
     OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel,
     const std::map<int, OptionalTensor>& variables,
     int missing_ctx_input_prefix) {
-  se::Stream* stream =
-      ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
   // Build ShapedBuffers that point directly to the Tensor buffers.
   arg_ptrs_ = std::vector<ShapedBuffer*>(kernel->xla_input_shapes.size());
-  // Pass remaining parameters.
-  const Tensor* t;
+  xla::TransferManager* transfer_manager =
+      client_->backend().transfer_manager();
   for (int i = 0; i < kernel->xla_input_shapes.size(); ++i) {
     int arg_num = kernel->input_mapping[i];
-    DCHECK_GE(arg_num, missing_ctx_input_prefix);
+    CHECK_GE(arg_num, missing_ctx_input_prefix);
     const xla::Shape& shape = kernel->xla_input_shapes[i];
-    if (variables.count(arg_num)) {
-      t = &(variables.at(arg_num).value);
-      CHECK(t);
-    } else {
-      t = &(ctx->input(arg_num - missing_ctx_input_prefix));
-    }
+    const Tensor* t = variables.count(arg_num)
+                          ? &(variables.at(arg_num).value)
+                          : &(ctx->input(arg_num - missing_ctx_input_prefix));
+    CHECK(t);

     if (use_multiple_streams_) {
-      CHECK(stream) << "Must have a stream available when using XLA tensors!";
+      CHECK(ctx->op_device_context() && ctx->op_device_context()->stream())
+          << "Must have a stream available when using XLA tensors!";
       XlaTensor* xla_tensor = XlaTensor::FromTensor(t);
       CHECK(xla_tensor);
-      xla_tensor->WaitForDefinitionEventOnStream(stream);
+      xla_tensor->WaitForDefinitionEventOnStream(
+          ctx->op_device_context()->stream());
     }

-    const xla::Shape on_device_shape =
-        client_->backend().transfer_manager()->HostShapeToDeviceShape(shape);
-    if (on_device_shape.IsTuple()) {
-      const XlaTensor* xla_tensor = XlaTensor::FromTensor(t);
-      CHECK(xla_tensor && xla_tensor->has_shaped_buffer());
-      arg_ptrs_[i] = const_cast<ShapedBuffer*>(&xla_tensor->shaped_buffer());
-    } else {
-      CHECK(xla::Shape::Equal().MinorToMajorOnlyInLayout()(shape,
-                                                           on_device_shape))
-          << "On-device shape "
-          << xla::ShapeUtil::HumanStringWithLayout(on_device_shape)
-          << " not the same as on-host shape "
-          << xla::ShapeUtil::HumanStringWithLayout(shape);
+    if (xla::Shape::Equal().MinorToMajorOnlyInLayout()(
+            shape, transfer_manager->HostShapeToDeviceShape(shape))) {
       se::DeviceMemoryBase dmem = XlaTensor::DeviceMemoryFromTensor(*t);
       arg_buffers_.emplace_back(
           /*on_host_shape=*/shape, /*on_device_shape=*/shape,
           client_->platform(), client_->default_device_ordinal());
       arg_buffers_.back().set_buffer(dmem, /*index=*/{});
       arg_ptrs_[i] = &arg_buffers_.back();
+    } else {
+      const XlaTensor* xla_tensor = XlaTensor::FromTensor(t);
+      CHECK(xla_tensor && xla_tensor->has_shaped_buffer());
+      arg_ptrs_[i] = const_cast<ShapedBuffer*>(&xla_tensor->shaped_buffer());
     }
   }
 }
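
The rewritten condition takes the direct path (wrapping the kernel's Tensor buffer in a fresh ShapedBuffer) only when the host shape and HostShapeToDeviceShape(shape) agree up to minor-to-major dimension order; any mismatch, including the old IsTuple() special case, now falls through to the XlaTensor's pre-built ShapedBuffer. Below is a minimal standalone sketch of that comparator, not part of the patch: it assumes the XLA headers of this source tree are on the include path, and the shapes and the helper SameUpToMinorToMajor are illustrative.

    #include "tensorflow/compiler/xla/shape.h"
    #include "tensorflow/compiler/xla/shape_util.h"
    #include "tensorflow/compiler/xla/xla_data.pb.h"

    // Illustrative helper (not in the patch): true when two shapes match up
    // to minor-to-major dimension order, ignoring other layout details.
    bool SameUpToMinorToMajor(const xla::Shape& host, const xla::Shape& device) {
      return xla::Shape::Equal().MinorToMajorOnlyInLayout()(host, device);
    }

    int main() {
      // Same dims and element type, same {1, 0} minor-to-major order:
      // the comparator returns true, so PopulateInputs would wrap the
      // tensor's raw device memory directly.
      xla::Shape a = xla::ShapeUtil::MakeShapeWithLayout(xla::F32, {2, 3}, {1, 0});
      xla::Shape b = xla::ShapeUtil::MakeShapeWithLayout(xla::F32, {2, 3}, {1, 0});
      // Transposed {0, 1} order: the comparator returns false, so the code
      // falls back to the XlaTensor's existing ShapedBuffer.
      xla::Shape c = xla::ShapeUtil::MakeShapeWithLayout(xla::F32, {2, 3}, {0, 1});
      return SameUpToMinorToMajor(a, b) && !SameUpToMinorToMajor(a, c) ? 0 : 1;
    }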