diff --git a/tensorflow/compiler/jit/kernels/xla_ops.cc b/tensorflow/compiler/jit/kernels/xla_ops.cc
index 48347a2915f..38e33a60657 100644
--- a/tensorflow/compiler/jit/kernels/xla_ops.cc
+++ b/tensorflow/compiler/jit/kernels/xla_ops.cc
@@ -277,7 +277,8 @@ static Status CompileToLocalExecutable(
     OpKernelContext* ctx, const NameAttrList& function, bool has_ref_vars,
     const XlaPlatformInfo& platform_info,
     absl::Span<VariableInfo const> variable_infos,
-    absl::Span<const int> constants, bool lazy, xla::LocalClient** client,
+    absl::Span<const int> constants, bool lazy, bool may_alias_resource_update,
+    xla::LocalClient** client,
     const XlaCompiler::CompilationResult** compilation_result,
     xla::LocalExecutable** executable) {
   // We store information about the JIT-compiled XLA computation
@@ -332,6 +333,9 @@ static Status CompileToLocalExecutable(
   // Optimization: where possible, have the computation return a naked array
   // rather than a one-element tuple.
   compile_options.always_return_tuple = false;
+  compile_options.alias_resource_update = !has_ref_vars &&
+                                          !platform_info.is_on_xla_device() &&
+                                          may_alias_resource_update;

   std::vector<XlaCompiler::Argument> args;
   TF_RETURN_IF_ERROR(XlaComputationLaunchContext::BuildXlaCompilerArguments(
@@ -350,20 +354,22 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
   const XlaCompiler::CompilationResult* compilation_result;
   xla::LocalExecutable* executable;

-  ResourceVarsSnapshot variables_snapshot;
+  std::vector<VariableInfo> variable_infos;
   {
-    std::vector<VariableInfo> variable_infos;
     OP_REQUIRES_OK(
         ctx, GetVariableInfosFromCtxInputs(ctx, resources_, &variable_infos));
     OP_REQUIRES_OK(ctx, LockVariables(absl::MakeSpan(variable_infos)));
     Status s = CompileToLocalExecutable(
         ctx, function_, /*has_ref_vars=*/has_ref_vars_, platform_info_,
-        variable_infos, constants_, /*lazy=*/false, &client,
-        &compilation_result, &executable);
+        variable_infos, constants_, /*lazy=*/false,
+        /*may_alias_resource_update=*/true, &client, &compilation_result,
+        &executable);
     OP_REQUIRES_OK(ctx, s);
-    OP_REQUIRES_OK(ctx,
-                   SnapshotResourceVariables(ctx, resources_, variable_infos,
-                                             &variables_snapshot));
+  }
+
+  std::map<int, const Tensor*> resource_var_ptrs;
+  for (int i = 0; i < resources_.size(); i++) {
+    resource_var_ptrs[resources_[i]] = variable_infos[i].var()->tensor();
   }

   se::Stream* stream =
@@ -374,12 +380,19 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
   absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter;
   se::DeviceMemoryAllocator* allocator =
       GetAllocator(&tf_allocator_adapter, ctx, platform_info_);
+  int device_ordinal = stream ? stream->parent()->device_ordinal()
+                              : client->default_device_ordinal();
   XlaComputationLaunchContext launch_context(
-      client, allocator,
+      client, allocator, device_ordinal,
       /*allocate_xla_tensors=*/platform_info_.is_on_xla_device(),
       platform_info_.UseMultipleStreams());
-  launch_context.PopulateInputs(ctx, compilation_result, variables_snapshot,
-                                /*missing_ctx_input_prefix=*/0);
+  const xla::HloInputOutputAliasConfig& input_output_alias =
+      executable->executable()->module().input_output_alias_config();
+  xla::StatusOr<std::vector<xla::ExecutionInput>> execution_inputs =
+      launch_context.PopulateInputs(ctx, compilation_result, resource_var_ptrs,
+                                    /*missing_ctx_input_prefix=*/0,
+                                    input_output_alias);
+  OP_REQUIRES_OK(ctx, execution_inputs.status());

   // Execute the computation.
VLOG(2) << "Executing computation."; @@ -403,24 +416,24 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) { Env* env = Env::Default(); auto start_time = env->NowMicros(); - xla::StatusOr run_result; + xla::StatusOr execution_output; if (!stream || platform_info_.platform_id() == se::host::kHostPlatformId) { - run_result = executable->Run(launch_context.arguments(), run_options); + execution_output = + executable->Run(std::move(*execution_inputs), run_options); } else { - run_result = executable->RunAsync(launch_context.arguments(), run_options); + execution_output = + executable->RunAsync(std::move(*execution_inputs), run_options); } - OP_REQUIRES(ctx, run_result.ok(), run_result.status()); + OP_REQUIRES(ctx, execution_output.ok(), execution_output.status()); auto elapsed = env->NowMicros() - start_time; VLOG(2) << "Elapsed time: " << elapsed << "us"; + OP_REQUIRES_OK( + ctx, launch_context.PopulateOutputs( + ctx, compilation_result, execution_output->ConsumeResult(), + /*missing_ctx_input_prefix=*/0, absl::MakeSpan(variable_infos), + input_output_alias, resource_var_ptrs)); - const xla::HloInputOutputAliasConfig& input_output_alias = - executable->executable()->module().input_output_alias_config(); - OP_REQUIRES_OK(ctx, - launch_context.PopulateOutputs( - ctx, compilation_result, run_result.ConsumeValueOrDie(), - /*missing_ctx_input_prefix=*/0, input_output_alias, - variables_snapshot)); VLOG(1) << "Done"; } @@ -516,10 +529,14 @@ void XlaCompileOp::Compute(OpKernelContext* ctx) { OP_REQUIRES_OK( ctx, GetVariableInfosFromCtxInputs(ctx, resources_, &variable_infos)); OP_REQUIRES_OK(ctx, LockVariables(absl::MakeSpan(variable_infos))); + + // Do not alias resource updates as locking variables in XlaCompile and + // unlocking them in XlaRun may lead to deadlocks. Status status = CompileToLocalExecutable( ctx, function_, has_ref_vars_, platform_info_, variable_infos, constants_, - /*lazy=*/!must_compile_, &client, &kernel, &executable); + /*lazy=*/!must_compile_, + /*may_alias_resource_update=*/false, &client, &kernel, &executable); OP_REQUIRES_OK(ctx, SnapshotResourceVariables(ctx, resources_, variable_infos, &variables)); if (must_compile_ || status.code() != error::UNIMPLEMENTED) { @@ -587,14 +604,22 @@ void XlaRunOp::Compute(OpKernelContext* ctx) { absl::optional tf_allocator_adapter; se::DeviceMemoryAllocator* allocator = GetAllocator(&tf_allocator_adapter, ctx, platform_info_); + se::Stream* stream = + ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr; + int device_ordinal = stream ? stream->parent()->device_ordinal() + : closure.client()->default_device_ordinal(); XlaComputationLaunchContext launch_context( - closure.client(), allocator, + closure.client(), allocator, device_ordinal, /*allocate_xla_tensors=*/platform_info_.is_on_xla_device(), /*use_multiple_streams=*/platform_info_.UseMultipleStreams()); // We're missing the must-be-constant inputs, tell `PopulateInputs` // about this. We don't actually need these inputs because they've // already been baked into the compiled kernel. 
+  const xla::HloInputOutputAliasConfig& input_output_alias =
+      closure.executable()->executable()->module().input_output_alias_config();
+  xla::StatusOr<std::vector<xla::ExecutionInput>> execution_inputs;
+  std::map<int, const Tensor*> snapshot_ptrs;
   {
     tensorflow::profiler::TraceMe hlo_module_activity(
         [&] {
@@ -604,13 +629,17 @@ void XlaRunOp::Compute(OpKernelContext* ctx) {
         },
         tensorflow::profiler::TraceMeLevel::kInfo);

-    launch_context.PopulateInputs(
-        ctx, closure.compilation_result(), closure.resource_var_snapshots(),
-        /*missing_ctx_input_prefix=*/closure.num_constant_args());
+    for (auto& p : closure.resource_var_snapshots()) {
+      snapshot_ptrs.emplace(p.first,
+                            p.second.has_value() ? &p.second.value() : nullptr);
+    }
+    execution_inputs = launch_context.PopulateInputs(
+        ctx, closure.compilation_result(), snapshot_ptrs,
+        /*missing_ctx_input_prefix=*/closure.num_constant_args(),
+        input_output_alias);
+    OP_REQUIRES_OK(ctx, execution_inputs.status());
   }

-  se::Stream* stream =
-      ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
   xla::ExecutableRunOptions run_options;
   run_options.set_stream(stream);
   run_options.set_allocator(allocator);
@@ -631,21 +660,19 @@ void XlaRunOp::Compute(OpKernelContext* ctx) {
   Env* env = Env::Default();
   auto start_time = env->NowMicros();

-  xla::StatusOr<xla::ScopedShapedBuffer> run_result;
+  xla::StatusOr<xla::ExecutionOutput> execution_output;
   if (!stream || platform_info_.platform_id() == se::host::kHostPlatformId) {
-    run_result =
-        closure.executable()->Run(launch_context.arguments(), run_options);
+    execution_output =
+        closure.executable()->Run(std::move(*execution_inputs), run_options);
   } else {
-    run_result =
-        closure.executable()->RunAsync(launch_context.arguments(), run_options);
+    execution_output = closure.executable()->RunAsync(
+        std::move(*execution_inputs), run_options);
   }
-  OP_REQUIRES(ctx, run_result.ok(), run_result.status());
+  OP_REQUIRES(ctx, execution_output.ok(), execution_output.status());

   auto elapsed = env->NowMicros() - start_time;
   VLOG(2) << "Elapsed time in computation: " << elapsed << "us";

-  const xla::HloInputOutputAliasConfig& input_output_alias =
-      closure.executable()->executable()->module().input_output_alias_config();

   tensorflow::profiler::TraceMe hlo_module_activity(
       [&] {
@@ -653,12 +680,16 @@ void XlaRunOp::Compute(OpKernelContext* ctx) {
       },
       tensorflow::profiler::TraceMeLevel::kInfo);

+  xla::StatusOr<std::vector<VariableInfo>> variable_infos = GatherVariableInfo(
+      ctx, *closure.compilation_result(), closure.num_constant_args());
+  OP_REQUIRES_OK(ctx, variable_infos.status());
+  OP_REQUIRES_OK(ctx, LockVariables(absl::MakeSpan(*variable_infos)));
   OP_REQUIRES_OK(
       ctx, launch_context.PopulateOutputs(
-               ctx, closure.compilation_result(), run_result.ConsumeValueOrDie(),
+               ctx, closure.compilation_result(), execution_output->ConsumeResult(),
                /*missing_ctx_input_prefix=*/closure.num_constant_args(),
-               input_output_alias, closure.resource_var_snapshots()));
+               absl::MakeSpan(*variable_infos), input_output_alias, snapshot_ptrs));
 }

 XlaMergeOp::XlaMergeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc
index afaee614f02..50813859603 100644
--- a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc
+++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc
@@ -50,35 +50,47 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx,
   // Builds an XLA allocator for the device.
   XlaComputationLaunchContext launch_context(
       client, client->backend().memory_allocator(),
+      client->default_device_ordinal(),
       /*allocate_xla_tensors=*/true,
       /*use_multiple_streams=*/metadata.UseMultipleStreams());

-  launch_context.PopulateInputs(ctx, result, variable_args,
-                                /*missing_ctx_input_prefix=*/0);
+  std::map<int, const Tensor*> snapshot_ptrs;
+  for (auto& p : variable_args) {
+    snapshot_ptrs.emplace(p.first,
+                          p.second.has_value() ? &p.second.value() : nullptr);
+  }
+
+  const xla::HloInputOutputAliasConfig& input_output_alias =
+      executable->executable()->module().input_output_alias_config();
+  xla::StatusOr<std::vector<xla::ExecutionInput>> execution_inputs =
+      launch_context.PopulateInputs(ctx, result, snapshot_ptrs,
+                                    /*missing_ctx_input_prefix=*/0,
+                                    input_output_alias);
+  TF_RETURN_IF_ERROR(execution_inputs.status());

   se::Stream* stream =
       ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
   TF_RET_CHECK(stream);

   VLOG(2) << "Executing computation: " << name();
-  for (const xla::ShapedBuffer* arg : launch_context.arguments()) {
-    VLOG(2) << name() << ": " << *arg;
-  }
   xla::ExecutableRunOptions run_options;
   run_options.set_stream(stream);
   run_options.set_allocator(client->backend().memory_allocator());
   run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device());
   run_options.set_rng_seed(GetXLARandomSeed());

-  xla::StatusOr<xla::ScopedShapedBuffer> run_result =
-      executable->Run(launch_context.arguments(), run_options);
+  xla::StatusOr<xla::ExecutionOutput> run_result =
+      executable->Run(execution_inputs.ConsumeValueOrDie(), run_options);
   TF_RETURN_IF_ERROR(run_result.status());
-
-  const xla::HloInputOutputAliasConfig& input_output_alias =
-      executable->executable()->module().input_output_alias_config();
+  xla::ExecutionOutput execution_output = run_result.ConsumeValueOrDie();
+  xla::StatusOr<std::vector<VariableInfo>> variable_infos =
+      GatherVariableInfo(ctx, *result, 0);
+  TF_RETURN_IF_ERROR(variable_infos.status());
+  TF_RETURN_IF_ERROR(LockVariables(absl::MakeSpan(*variable_infos)));
   TF_RETURN_IF_ERROR(launch_context.PopulateOutputs(
-      ctx, result, run_result.ConsumeValueOrDie(),
-      /*missing_ctx_input_prefix=*/0, input_output_alias, variable_args));
+      ctx, result, execution_output.ConsumeResult(),
+      /*missing_ctx_input_prefix=*/0, absl::MakeSpan(*variable_infos),
+      input_output_alias, snapshot_ptrs));

   return Status::OK();
 }

diff --git a/tensorflow/compiler/jit/xla_device_ops.cc b/tensorflow/compiler/jit/xla_device_ops.cc
index 8126059262b..f0555ae32e5 100644
--- a/tensorflow/compiler/jit/xla_device_ops.cc
+++ b/tensorflow/compiler/jit/xla_device_ops.cc
@@ -59,11 +59,13 @@ void XlaAssignVariableOp::Compute(OpKernelContext* context) {
                             return Status::OK();
                           }));
   mutex_lock ml(*variable->mu());
-  OP_REQUIRES(context, variable->tensor()->dtype() == dtype_,
-              errors::InvalidArgument(
-                  "Trying to assign variable with wrong dtype. Expected ",
-                  DataTypeString(variable->tensor()->dtype()), " got ",
-                  DataTypeString(dtype_)));
+  OP_REQUIRES(
+      context,
+      !variable->is_initialized || variable->tensor()->dtype() == dtype_,
+      errors::InvalidArgument(
+          "Trying to assign variable with wrong dtype. Expected ",
+          DataTypeString(variable->tensor()->dtype()), " got ",
+          DataTypeString(dtype_)));
   variable->is_initialized = true;
   *variable->tensor() = value;
 }

diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc
index 7f107aaef11..41abe86df6e 100644
--- a/tensorflow/compiler/jit/xla_launch_util.cc
+++ b/tensorflow/compiler/jit/xla_launch_util.cc
@@ -91,29 +91,19 @@ VariableInfo::~VariableInfo() {
 Status GetVariableInfosFromCtxInputs(OpKernelContext* ctx,
                                      absl::Span<const int> variable_indices,
                                      std::vector<VariableInfo>* result) {
-  std::vector<const ResourceHandle*> resource_handles;
-  absl::c_transform(
-      variable_indices, std::back_inserter(resource_handles),
-      [&](int variable_idx) { return &HandleFromInput(ctx, variable_idx); });
-
-  std::vector<core::RefCountPtr<Var>> variables;
-  Status s = LookupResources(ctx, resource_handles, &variables);
-  if (!s.ok()) {
-    errors::AppendToMessage(&s, kPossibleNonVariableResourceHintMessage);
-    return s;
-  }
-
   result->clear();
   result->reserve(variable_indices.size());
-  for (int i = 0; i < variable_indices.size(); i++) {
-    // *Release* the variable because we're going to unref it later in
-    // ~VariableInfo.
-    Var* variable = variables[i].release();
-    int input_idx = variable_indices[i];
-    std::string var_name = HandleFromInput(ctx, input_idx).name();
-    result->emplace_back(input_idx, var_name, variable);
+  for (int var_idx : variable_indices) {
+    Var* variable = nullptr;
+    ResourceHandle handle = HandleFromInput(ctx, var_idx);
+    TF_RETURN_IF_ERROR(
+        LookupOrCreateResource<Var>(ctx, handle, &variable, [&](Var** ptr) {
+          // This var is uninitialized for now.
+          *ptr = new Var(DT_INVALID);
+          return Status::OK();
+        }));
+    result->emplace_back(var_idx, handle.name(), variable);
   }
-
   return Status::OK();
 }
@@ -176,24 +166,43 @@ Status SnapshotResourceVariables(OpKernelContext* ctx,

 XlaComputationLaunchContext::XlaComputationLaunchContext(
     xla::LocalClient* client, se::DeviceMemoryAllocator* xla_allocator,
-    bool allocate_xla_tensors, bool use_multiple_streams)
+    int device_ordinal, bool allocate_xla_tensors, bool use_multiple_streams)
     : client_(client),
       xla_allocator_(xla_allocator),
       allocate_xla_tensors_(allocate_xla_tensors),
-      use_multiple_streams_(use_multiple_streams) {
+      use_multiple_streams_(use_multiple_streams),
+      device_ordinal_(device_ordinal) {
   if (use_multiple_streams_) {
     CHECK(allocate_xla_tensors_) << "To use multiple streams correctly we must "
                                     "be allocating XLA tensors!";
   }
 }

-void XlaComputationLaunchContext::PopulateInputs(
+// Fills in `execution_input` with `buffer` for `index`.
+static void PopulateExecutionInputBuffer(xla::ExecutionInput& execution_input,
+                                         xla::ShapeIndex index,
+                                         se::DeviceMemoryBase& buffer,
+                                         bool donate_buffer, int device_ordinal,
+                                         se::DeviceMemoryAllocator* allocator) {
+  xla::MaybeOwningDeviceMemory* in_buffer =
+      execution_input.MutableBuffer(index);
+  if (donate_buffer) {
+    *in_buffer = se::OwningDeviceMemory(buffer, device_ordinal, allocator);
+    buffer = se::DeviceMemoryBase();
+  } else {
+    *in_buffer = buffer;
+  }
+}
+
+xla::StatusOr<std::vector<xla::ExecutionInput>>
+XlaComputationLaunchContext::PopulateInputs(
     OpKernelContext* ctx,
     const XlaCompiler::CompilationResult* compilation_result,
-    const ResourceVarsSnapshot& variables, int missing_ctx_input_prefix) {
-  // Build ShapedBuffers that point directly to the Tensor buffers.
-  arg_ptrs_ =
-      std::vector<xla::ShapedBuffer*>(compilation_result->xla_input_shapes.size());
+    const std::map<int, const Tensor*>& resource_vars,
+    int missing_ctx_input_prefix,
+    const xla::HloInputOutputAliasConfig& input_output_alias) {
+  std::vector<xla::ExecutionInput> arguments;
+  arguments.reserve(compilation_result->xla_input_shapes.size());

   xla::TransferManager* transfer_manager =
       client_->backend().transfer_manager();
@@ -201,10 +210,28 @@ void XlaComputationLaunchContext::PopulateInputs(
     int arg_num = compilation_result->input_mapping[i];
     CHECK_GE(arg_num, missing_ctx_input_prefix);
     const xla::Shape& shape = compilation_result->xla_input_shapes[i];
-    const Tensor* t = variables.count(arg_num)
-                          ? &(variables.at(arg_num).value())
+    const xla::Shape& device_shape =
+        transfer_manager->HostShapeToDeviceShape(shape);
+
+    bool is_resource_variable = resource_vars.count(arg_num);
+    bool is_updated_resource_variable =
+        is_resource_variable &&
+        absl::c_any_of(compilation_result->resource_updates,
+                       [&](const XlaCompiler::ResourceUpdate& update) {
+                         return update.input_index == i && update.modified;
+                       });
+
+    const Tensor* t = is_resource_variable
+                          ? resource_vars.at(arg_num)
                           : &(ctx->input(arg_num - missing_ctx_input_prefix));
     CHECK(t);
+    bool donate_buffer =
+        t->RefCountIsOne() && is_updated_resource_variable &&
+        input_output_alias.ParameterHasAlias(i, xla::ShapeIndex{});
+    VLOG(3) << "Processing input: " << i
+            << "; is_resource_variable=" << is_resource_variable
+            << "; is_updated_resource_variable=" << is_updated_resource_variable
+            << "; donate_buffer=" << donate_buffer;

     if (use_multiple_streams_) {
       CHECK(ctx->op_device_context() && ctx->op_device_context()->stream())
@@ -215,23 +242,28 @@ void XlaComputationLaunchContext::PopulateInputs(
           ctx->op_device_context()->stream());
     }

-    if (xla::Shape::Equal().MinorToMajorOnlyInLayout()(
-            shape, transfer_manager->HostShapeToDeviceShape(shape))) {
+    arguments.emplace_back(device_shape, shape);
+    xla::ExecutionInput& execution_input = arguments.back();
+    if (xla::Shape::Equal().MinorToMajorOnlyInLayout()(shape, device_shape)) {
       se::DeviceMemoryBase dmem = XlaTensor::DeviceMemoryFromTensor(*t);
-      arg_buffers_.emplace_back(
-          /*on_host_shape=*/shape, /*on_device_shape=*/shape,
-          client_->platform(), client_->default_device_ordinal());
-      arg_buffers_.back().set_buffer(dmem, /*index=*/{});
-      arg_ptrs_[i] = &arg_buffers_.back();
+      PopulateExecutionInputBuffer(execution_input, xla::ShapeIndex{}, dmem,
+                                   donate_buffer, device_ordinal_,
+                                   xla_allocator_);
     } else {
-      const XlaTensor* xla_tensor = XlaTensor::FromTensor(t);
+      XlaTensor* xla_tensor = XlaTensor::FromTensor(t);
       CHECK(xla_tensor && xla_tensor->has_shaped_buffer());
-      arg_ptrs_[i] = const_cast<xla::ShapedBuffer*>(&xla_tensor->shaped_buffer());
+      xla_tensor->shaped_buffer().buffers().ForEachMutableElement(
+          [&](const xla::ShapeIndex& index, se::DeviceMemoryBase* buffer) {
+            PopulateExecutionInputBuffer(execution_input, index, *buffer,
+                                         donate_buffer, device_ordinal_,
+                                         xla_allocator_);
+          });
     }
   }
+  return std::move(arguments);
 }

-// Construct the tensor for given type and buffer.
+// Construct the tensor for the given type and buffer.
 static Tensor MakeTensor(DataType dtype, const TensorShape& shape,
                          se::DeviceMemoryBase buffer, Allocator* allocator) {
   size_t expected_size = shape.num_elements() * DataTypeSize(dtype);
@@ -247,28 +279,26 @@ static Tensor GetOrCreateTensorForOutput(
     int output_num, OpKernelContext* ctx, int missing_ctx_input_prefix,
     const xla::HloInputOutputAliasConfig& input_output_alias,
     absl::Span<const int> input_mapping,
-    const ResourceVarsSnapshot& resource_var_snapshots, DataType output_dtype,
-    const TensorShape& output_shape, se::DeviceMemoryBase output_buffer,
-    Allocator* output_allocator) {
+    const std::map<int, const Tensor*>& resource_vars_snapshots,
+    DataType output_dtype, const TensorShape& output_shape,
+    se::DeviceMemoryBase output_buffer, Allocator* output_allocator) {
   xla::ShapeIndex output_index = input_output_alias.shape().IsTuple()
                                      ? xla::ShapeIndex({output_num})
                                      : xla::ShapeIndex({});

+  CHECK(input_output_alias.shape().IsTuple() || output_num == 0);
   if (absl::optional<xla::HloInputOutputAliasConfig::Alias> alias =
           input_output_alias.GetAliasedParameter(output_index)) {
+    VLOG(3) << "Found alias: " << alias->ToString();
     int tf_param =
         input_mapping[alias->parameter_number] - missing_ctx_input_prefix;
-    const Tensor* input_tensor = &ctx->input(tf_param);
-
-    // If input tensor is a resource variable, alias to the snapshot we took at
-    // entry time.
-    if (input_tensor->dtype() == DT_RESOURCE) {
-      const absl::optional<Tensor>& v =
-          resource_var_snapshots.at(missing_ctx_input_prefix + tf_param);
-      CHECK(v.has_value());
-      return *v;
+    const Tensor input_tensor =
+        ctx->input(tf_param).dtype() != DT_RESOURCE
+            ? ctx->input(tf_param)
+            : *resource_vars_snapshots.at(missing_ctx_input_prefix + tf_param);
+    if (output_buffer.opaque() == input_tensor.data()) {
+      return input_tensor;
     }
-    return *input_tensor;
   }
   return MakeTensor(output_dtype, output_shape, output_buffer,
                     output_allocator);
@@ -291,12 +321,10 @@ static Status SetOutputForConstant(
     OpKernelContext* ctx, se::Stream* stream,
     const XlaCompiler::CompilationResult* compilation_result, int output_num) {
   CHECK(compilation_result->outputs[output_num].is_constant);
-  // Output is a constant.
   const Tensor& const_tensor =
       compilation_result->outputs[output_num].constant_value;
   Tensor* output_tensor;
-  const size_t total_bytes = const_tensor.TotalBytes();
-  if (stream && total_bytes > 0) {
+  if (stream && const_tensor.TotalBytes() > 0) {
     // Copy host -> device. (Empty tensors don't have backing buffers.)
     // Manually allocate memory using an XlaTensorBuffer so we can allocate
     // as much memory as the device requires (as given by
@@ -335,52 +363,55 @@ static Status SetOutputForConstant(
   return Status::OK();
 }

-// Creates a list of updates resource variables.
-static xla::StatusOr<std::vector<VariableInfo>> GatherVariableInfo(
-    OpKernelContext* ctx,
-    const XlaCompiler::CompilationResult* compilation_result,
-    int missing_ctx_input_prefix) {
-  std::vector<VariableInfo> variable_infos;
-  variable_infos.reserve(compilation_result->resource_updates.size());
+static xla::StatusOr<Var*> GetOrCreateResourceVar(
+    OpKernelContext* ctx, const ResourceHandle& handle,
+    const XlaCompiler::ResourceUpdate& write) {
+  Var* variable = nullptr;
+  TF_RETURN_IF_ERROR(
+      LookupOrCreateResource<Var>(ctx, handle, &variable, [&write](Var** ptr) {
+        *ptr = new Var(write.type);
+        return Status::OK();
+      }));
+  return variable;
+}

-  for (int i = 0; i < compilation_result->resource_updates.size(); ++i) {
+xla::StatusOr<std::vector<VariableInfo>> GatherVariableInfo(
+    OpKernelContext* ctx,
+    const XlaCompiler::CompilationResult& compilation_result,
+    int missing_ctx_input_prefix) {
+  std::vector<VariableInfo> out;
+  out.reserve(compilation_result.resource_updates.size());
+  for (int i = 0; i < compilation_result.resource_updates.size(); ++i) {
     const XlaCompiler::ResourceUpdate& write =
-        compilation_result->resource_updates[i];
+        compilation_result.resource_updates[i];
     int actual_input_index = write.input_index - missing_ctx_input_prefix;
     if (actual_input_index < 0 || actual_input_index >= ctx->num_inputs()) {
       return errors::Internal("Invalid input index for variable write.");
     }
-
-    // TODO(b/35625933): tensorflow::Var should contain a PersistentTensor,
-    // not a Tensor.
-    Var* variable = nullptr;
     const ResourceHandle handle = HandleFromInput(ctx, actual_input_index);
-    TF_RETURN_IF_ERROR(LookupOrCreateResource<Var>(ctx, handle, &variable,
-                                                   [&write](Var** ptr) {
-                                                     *ptr = new Var(write.type);
-                                                     return Status::OK();
-                                                   }));
-    variable_infos.emplace_back(actual_input_index, handle.name(), variable);
+    TF_ASSIGN_OR_RETURN(Var * variable,
+                        GetOrCreateResourceVar(ctx, handle, write));
+    out.emplace_back(actual_input_index, handle.name(), variable);
   }
-  return variable_infos;
+  return std::move(out);
 }

 Status XlaComputationLaunchContext::PopulateOutputs(
     OpKernelContext* ctx,
     const XlaCompiler::CompilationResult* compilation_result,
     ScopedShapedBuffer output, int missing_ctx_input_prefix,
+    absl::Span<VariableInfo> variable_infos,
     const xla::HloInputOutputAliasConfig& input_output_alias,
-    const ResourceVarsSnapshot& resource_var_snapshots) {
+    const std::map<int, const Tensor*>& resource_vars) {
   se::Stream* stream =
       ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
   Allocator* allocator = ctx->device()->GetAllocator({});

   // Computation output should always be a tuple.
-  if (VLOG_IS_ON(2)) {
-    VLOG(2) << "Result tuple shape: " << output.on_host_shape().DebugString();
-    VLOG(2) << "Result tuple shape (on device): "
-            << output.on_device_shape().DebugString();
-  }
+  VLOG(2) << "Result tuple shape: " << output.on_host_shape().DebugString();
+  VLOG(2) << "Result tuple shape (on device): "
+          << output.on_device_shape().DebugString();
   CHECK_EQ(ctx->num_outputs(), compilation_result->outputs.size());

   // If the on-host-shape isn't a tuple, create a new single-element tuple
@@ -438,8 +469,8 @@ Status XlaComputationLaunchContext::PopulateOutputs(
   for (int i = 0; i < ctx->num_outputs(); ++i) {
     const TensorShape& shape = output_tensor_shapes[i];
     const DataType& type = compilation_result->outputs[i].type;
-    VLOG(2) << "Retval " << i << " shape " << shape.DebugString() << " type "
-            << DataTypeString(type);
+    VLOG(2) << "Populating output for retval " << i << " shape "
+            << shape.DebugString() << " type " << DataTypeString(type);
     if (type == DT_VARIANT) {
       return errors::Unimplemented(
           "Support for TensorList crossing the XLA/TF boundary "
@@ -467,30 +498,37 @@ Status XlaComputationLaunchContext::PopulateOutputs(
         se::DeviceMemoryBase buffer = output.buffer({output_num});
         Tensor output_tensor = GetOrCreateTensorForOutput(
             output_num, ctx, missing_ctx_input_prefix, input_output_alias,
-            compilation_result->input_mapping, resource_var_snapshots,
+            compilation_result->input_mapping, resource_vars,
             ctx->expected_output_dtype(i), shape, buffer, allocator);
-        output.set_buffer(se::OwningDeviceMemory(), {output_num});
         ctx->set_output(i, output_tensor);
+        output.set_buffer(se::OwningDeviceMemory(), {output_num});
       }
       ++output_num;
     }
-
-    if (VLOG_IS_ON(3)) {
-      VLOG(3) << ctx->mutable_output(i)->DeviceSafeDebugString();
-    }
   }

-  // Apply variable updates, if any.
-  VLOG(2) << "Applying variable updates";
-  TF_ASSIGN_OR_RETURN(
-      std::vector<VariableInfo> variable_infos,
-      GatherVariableInfo(ctx, compilation_result, missing_ctx_input_prefix));
-  TF_RETURN_IF_ERROR(LockVariables(absl::MakeSpan(variable_infos)));
+  // input_index -> index into variable_infos.
+  absl::flat_hash_map<int, int> variable_info_lookup;
+  for (int i = 0; i < variable_infos.size(); i++) {
+    variable_info_lookup.emplace(variable_infos[i].index(), i);
+  }

+  // Apply variable updates, if any.
   for (int i = 0; i < compilation_result->resource_updates.size(); ++i) {
     const XlaCompiler::ResourceUpdate& write =
         compilation_result->resource_updates[i];
-    if (variable_infos[i].var()->tensor()->dtype() != write.type) {
+    int actual_input_index = write.input_index - missing_ctx_input_prefix;
+    CHECK_GE(actual_input_index, 0);
+    CHECK_LT(actual_input_index, ctx->num_inputs());
+    Var* var = variable_infos[variable_info_lookup[actual_input_index]].var();
+    CHECK(var);
+
+    VLOG(2) << "Updating variable #" << i
+            << " at input index: " << actual_input_index << " with shape "
+            << write.shape.DebugString() << "; variable tensor has shape: "
+            << var->tensor()->shape().DebugString();
+
+    if (var->is_initialized && var->tensor()->dtype() != write.type) {
       return errors::Internal("Mismatched type in variable write");
     }

@@ -504,14 +542,14 @@ Status XlaComputationLaunchContext::PopulateOutputs(
       }
     } else {
       se::DeviceMemoryBase buffer = output.buffer({output_num});
-      output.set_buffer(se::OwningDeviceMemory(), {output_num});
       output_tensor = GetOrCreateTensorForOutput(
           output_num, ctx, missing_ctx_input_prefix, input_output_alias,
-          compilation_result->input_mapping, resource_var_snapshots, write.type,
+          compilation_result->input_mapping, resource_vars, write.type,
          write.shape, buffer, allocator);
     }
-    *variable_infos[i].var()->tensor() = output_tensor;
-    variable_infos[i].var()->is_initialized |= write.modified;
+    output.set_buffer(se::OwningDeviceMemory(), {output_num});
+    var->is_initialized |= write.modified;
+    *var->tensor() = output_tensor;
     ++output_num;
   }
   return Status::OK();
@@ -562,7 +600,7 @@ Status XlaComputationLaunchContext::BuildXlaCompilerArguments(
       arg.name = std::string(variable.name());
       arg.kind = XlaCompiler::Argument::kResource;
       arg.resource_kind = XlaResource::kVariable;
-      if (variable.var()) {
+      if (variable.var() && variable.var()->is_initialized) {
        const Tensor* value = variable.var()->tensor();
         arg.type = value->dtype();
         arg.shape = value->shape();

diff --git a/tensorflow/compiler/jit/xla_launch_util.h b/tensorflow/compiler/jit/xla_launch_util.h
index 92b6c4c8a08..b34b3059a4f 100644
--- a/tensorflow/compiler/jit/xla_launch_util.h
+++ b/tensorflow/compiler/jit/xla_launch_util.h
@@ -81,6 +81,12 @@ class VariableInfo {
   bool lock_held_ = false;
 };

+// Creates a list of updated resource variables.
+xla::StatusOr<std::vector<VariableInfo>> GatherVariableInfo(
+    OpKernelContext* ctx,
+    const XlaCompiler::CompilationResult& compilation_result,
+    int missing_ctx_input_prefix);
+
 // Takes a snapshot of the values of resource variable arguments, whose indices
 // are specified in `variable_indices` argument. We snapshot tensors that back
 // resource variables since concurrent updates may modify the shape, and it is
@@ -124,7 +130,7 @@ class XlaComputationLaunchContext {
   // objects.
   XlaComputationLaunchContext(xla::LocalClient* client,
                               se::DeviceMemoryAllocator* xla_allocator,
-                              bool allocate_xla_tensors,
+                              int device_ordinal, bool allocate_xla_tensors,
                               bool use_multiple_streams);

   // Builds a XlaCompiler::Argument vector from the arguments to an XlaLaunch
@@ -142,10 +148,12 @@ class XlaComputationLaunchContext {
   // missing and adjusts input indices accordingly. All elements in kernel's
   // input_mapping must be greater than or equal to `missing_ctx_input_prefix`
   // (in other words, no inputs actually required by the kernel can be missing).
-  void PopulateInputs(OpKernelContext* ctx,
-                      const XlaCompiler::CompilationResult* compilation_result,
-                      const ResourceVarsSnapshot& variables,
-                      int missing_ctx_input_prefix);
+  xla::StatusOr<std::vector<xla::ExecutionInput>> PopulateInputs(
+      OpKernelContext* ctx,
+      const XlaCompiler::CompilationResult* compilation_result,
+      const std::map<int, const Tensor*>& resource_vars,
+      int missing_ctx_input_prefix,
+      const xla::HloInputOutputAliasConfig& input_output_alias);

   // Given the XLA output in `output`, populate all outputs of `ctx`. Also
   // writes out the resource variable updates.
@@ -161,20 +169,16 @@ class XlaComputationLaunchContext {
       OpKernelContext* ctx,
       const XlaCompiler::CompilationResult* compilation_result,
       xla::ScopedShapedBuffer output, int missing_ctx_input_prefix,
+      absl::Span<VariableInfo> variable_infos,
       const xla::HloInputOutputAliasConfig& input_output_alias,
-      const ResourceVarsSnapshot& resource_var_snapshots);
-
-  // Return the argument list. Only valid after PopulateInputs() has been
-  // called.
-  const std::vector<xla::ShapedBuffer*>& arguments() const { return arg_ptrs_; }
+      const std::map<int, const Tensor*>& resource_vars);

  private:
   xla::LocalClient* client_;
   se::DeviceMemoryAllocator* xla_allocator_;
   bool allocate_xla_tensors_;
   bool use_multiple_streams_;
-  std::deque<xla::ShapedBuffer> arg_buffers_;
-  std::vector<xla::ShapedBuffer*> arg_ptrs_;
+  int device_ordinal_;
 };

 // A simple TensorBuffer implementation that allows us to create Tensors that

diff --git a/tensorflow/python/eager/def_function_xla_jit_test.py b/tensorflow/python/eager/def_function_xla_jit_test.py
index d55f84863e9..bd7a6ec2279 100644
--- a/tensorflow/python/eager/def_function_xla_jit_test.py
+++ b/tensorflow/python/eager/def_function_xla_jit_test.py
@@ -32,6 +32,7 @@ from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import random_ops
 from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.ops import tensor_array_ops
+from tensorflow.python.ops import variables
 from tensorflow.python.platform import test


@@ -403,6 +404,69 @@ class DefFunctionTest(test.TestCase):

     self.assertEqual(inner_retracings, 1)

+  def testUpdateVariable(self):
+    v = variables.Variable(3.1)
+
+    @def_function.function(experimental_compile=True)
+    def update_var(a, b):
+      v.assign_add(a * b)
+
+    update_var(constant_op.constant(0.7), constant_op.constant(0.6))
+    self.assertAllClose(v, 3.52)
+
+  def testUpdateVariableVector(self):
+    v = variables.Variable([3.1, 3.1])
+
+    @def_function.function(experimental_compile=True)
+    def update_var(a, b):
+      v.assign_add(a * b)
+
+    update_var(
+        constant_op.constant([0.7, 0.7]), constant_op.constant([0.6, 0.6]))
+    self.assertAllClose(v, [3.52, 3.52])
+
+  def testUpdateVariableInClass(self):
+
+    class C(object):
+
+      @def_function.function(experimental_compile=True)
+      def update_var(self, a, b):
+        if not hasattr(self, 'v'):
+          self.v = variables.Variable(3.1)
+        self.v.assign_add(a * b)
+
+    c = C()
+
+    @def_function.function
+    def outer():
+      c.update_var(constant_op.constant(0.7), constant_op.constant(0.6))
+
+    outer()
+    self.assertAllClose(c.v, 3.52)
+
+  def testUpdateVariableMultipleOutputs(self):
+    v = variables.Variable(3.1)
+
+    @def_function.function(experimental_compile=True)
+    def update_var(a, b):
+      v.assign_add(a * b)
+      return a * b + v
+
+    out = update_var(constant_op.constant(0.7), constant_op.constant(0.6))
+    self.assertAllClose(v, 3.52)
+    self.assertAllClose(out, 3.94)
+
+  def testReturnIdentity(self):
+
+    @def_function.function(experimental_compile=True)
+    def f(a, b):
+      return (a, b)
+
+    a = constant_op.constant([0.7])
+    b = constant_op.constant([0.6])
+
+    f(a, b)
+
 if __name__ == '__main__':
   ops.enable_eager_execution()
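
Reviewer note on the mechanism: this patch moves the TF/XLA launch path from snapshot-based `ShapedBuffer` arguments to `xla::ExecutionInput`/`xla::ExecutionOutput`, so that a resource variable updated by the computation can donate its device buffer to the aliased output (see `compile_options.alias_resource_update`, `donate_buffer`, and the `input_output_alias` plumbing above). Donation is only attempted when the backing tensor's refcount is one, the parameter is a modified resource variable, and the compiled module actually declares the parameter/output alias; `XlaCompile`/`XlaRun` pass `may_alias_resource_update=false` because, per the comment in the patch, locking variables in XlaCompile and unlocking them in XlaRun may lead to deadlocks. A minimal user-level sketch of what the new Python tests exercise, written against the public `tf.function` API (assuming a TF build from the era of this patch, where the flag is spelled `experimental_compile`; later releases renamed it `jit_compile`):

    import tensorflow as tf

    v = tf.Variable([3.1, 3.1])

    # With input/output aliasing, XLA may write the updated value directly
    # into the variable's existing device buffer rather than allocating a
    # fresh output buffer and copying.
    @tf.function(experimental_compile=True)
    def update_var(a, b):
      v.assign_add(a * b)

    update_var(tf.constant([0.7, 0.7]), tf.constant([0.6, 0.6]))
    print(v.numpy())  # approximately [3.52, 3.52]

The sketch mirrors `testUpdateVariableVector` above; whether the buffer is actually donated at runtime depends on the refcount and aliasing checks in `PopulateInputs`, so it illustrates the observable behavior, not a guarantee of in-place reuse.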