diff --git a/tensorflow/c/eager/tape.h b/tensorflow/c/eager/tape.h
index f95288d45cf..9c6704f575f 100644
--- a/tensorflow/c/eager/tape.h
+++ b/tensorflow/c/eager/tape.h
@@ -193,12 +193,12 @@ class ForwardAccumulator {
   // Does not take ownership of `vspace`, which must outlive the
   // ForwardAccumulator.
   explicit ForwardAccumulator(
-      const VSpace<Gradient, BackwardFunction, TapeTensor>* vspace)
+      const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace)
       : vspace_(vspace), backward_tape_(nullptr), accumulating_(false) {}

   virtual ~ForwardAccumulator() {
     for (auto accumulated : accumulated_gradients_) {
-      vspace_->DeleteGradient(accumulated.second);
+      vspace_.DeleteGradient(accumulated.second);
     }
   }

@@ -276,7 +276,7 @@ class ForwardAccumulator {
   std::unordered_map<int64, Gradient*> accumulated_gradients_;
   // Not owned; provides operations on Tensors which are currently only
   // available in language bindings (e.g. Python).
-  const VSpace<Gradient, BackwardFunction, TapeTensor>* vspace_;
+  const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace_;
   // Set temporarily while in the Accumulate method; if backward_tape_ is not
   // nullptr then we forward op executions to it so Accumulate can compute a
   // backward pass on its backward function.
@@ -865,9 +865,9 @@ ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::ForwardpropFromTape(
   sources.reserve(output_tensors.size());
   for (const TapeTensor& output_tensor : output_tensors) {
     // Ownership of `aid` transferred to CallBackwardFunction below.
-    Gradient* aid = vspace_->Ones(output_tensor);
+    Gradient* aid = vspace_.Ones(output_tensor);
     forwardprop_aids.push_back(aid);
-    int64 aid_id = vspace_->TensorId(aid);
+    int64 aid_id = vspace_.TensorId(aid);
     sources.push_back(aid_id);
     sources_set.insert(aid_id);
     tape->Watch(aid_id);
@@ -875,7 +875,7 @@
   std::vector<Gradient*> grad;
   auto delete_grad = gtl::MakeCleanup([&grad, this] {
     for (Gradient* tensor : grad) {
-      this->vspace_->DeleteGradient(tensor);
+      this->vspace_.DeleteGradient(tensor);
     }
   });
   {
@@ -883,7 +883,7 @@
     std::unique_ptr<BackwardFunction, std::function<void(BackwardFunction*)>>
        backward_function(backward_function_getter(),
                          backward_function_deleter);
-    TF_RETURN_IF_ERROR(vspace_->CallBackwardFunction(
+    TF_RETURN_IF_ERROR(vspace_.CallBackwardFunction(
        backward_function.get(), unneeded_gradients, forwardprop_aids,
        &grad));
   }
@@ -894,11 +894,11 @@
   std::unordered_map<int64, TapeTensor> sources_that_are_targets;
   for (Gradient* grad_tensor : grad) {
     if (grad_tensor != nullptr) {
-      int64 tensor_id = vspace_->TensorId(grad_tensor);
+      int64 tensor_id = vspace_.TensorId(grad_tensor);
       targets.push_back(tensor_id);
       if (sources_set.find(tensor_id) != sources_set.end()) {
         sources_that_are_targets.emplace(
-            tensor_id, vspace_->TapeTensorFromGradient(grad_tensor));
+            tensor_id, vspace_.TapeTensorFromGradient(grad_tensor));
       }
     }
   }
@@ -911,11 +911,11 @@
     Gradient* grad_tensor = grad[target_index];
     if (grad_tensor != nullptr && in_grad != nullptr) {
       // ComputeGradient steals a reference
-      vspace_->MarkAsResult(in_grad);
+      vspace_.MarkAsResult(in_grad);
     }
   }
-  return tape->ComputeGradient(*vspace_, targets, sources,
+  return tape->ComputeGradient(vspace_, targets, sources,
                                sources_that_are_targets, in_grads,
                                out_grads);
 }

@@ -952,7 +952,7 @@ Status ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::Accumulate(
   std::vector<Gradient*> new_zeros;
   auto delete_new_zeros = gtl::MakeCleanup([&new_zeros, this] {
     for (Gradient* tensor : new_zeros) {
-      this->vspace_->DeleteGradient(tensor);
+      this->vspace_.DeleteGradient(tensor);
     }
   });
   std::vector<Gradient*> in_grads;
@@ -965,7 +965,7 @@
       if (IsDtypeTrainable(input_tensors[target_index].GetDType())) {
         // ForwardAccumulator defaults to zeros for unwatched Tensors, unlike
         // GradientTape which uses ones.
-        Gradient* zero = vspace_->Zeros(input_tensors[target_index]);
+        Gradient* zero = vspace_.Zeros(input_tensors[target_index]);
         new_zeros.push_back(zero);
         in_grads.push_back(zero);
       } else {
@@ -990,7 +990,7 @@
       int64 tensor_id = output_tensors[i].GetID();
       auto existing = accumulated_gradients_.find(tensor_id);
       if (existing != accumulated_gradients_.end()) {
-        vspace_->DeleteGradient(existing->second);
+        vspace_.DeleteGradient(existing->second);
       }
       accumulated_gradients_[output_tensors[i].GetID()] = forward_grads[i];
     }
@@ -1002,14 +1002,14 @@ template <typename Gradient, typename BackwardFunction, typename TapeTensor>
 void ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::Watch(
     int64 tensor_id, Gradient* tangent) {
   auto existing = accumulated_gradients_.find(tensor_id);
-  vspace_->MarkAsResult(tangent);
+  vspace_.MarkAsResult(tangent);
   if (existing == accumulated_gradients_.end()) {
     accumulated_gradients_.emplace(tensor_id, tangent);
   } else {
     std::array<Gradient*, 2> to_aggregate({tangent, existing->second});
     // AggregateGradients steals a reference to each of its arguments. We
     // MarkAsResult on `tangent` above so we don't steal a reference to it.
-    existing->second = vspace_->AggregateGradients(to_aggregate);
+    existing->second = vspace_.AggregateGradients(to_aggregate);
   }
 }

@@ -1018,7 +1018,7 @@ void ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::DeleteGradient(
     int64 tensor_id) {
   auto existing = accumulated_gradients_.find(tensor_id);
   if (existing != accumulated_gradients_.end()) {
-    vspace_->DeleteGradient(existing->second);
+    vspace_.DeleteGradient(existing->second);
     accumulated_gradients_.erase(existing);
   }
 }
diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc
index 198ee1e37bf..c80cd5a29f6 100644
--- a/tensorflow/python/eager/pywrap_tfe_src.cc
+++ b/tensorflow/python/eager/pywrap_tfe_src.cc
@@ -1947,7 +1947,13 @@ PyObject* TFE_Py_ForwardAccumulatorNew() {
   if (PyType_Ready(&TFE_Py_ForwardAccumulator_Type) < 0) return nullptr;
   TFE_Py_ForwardAccumulator* accumulator =
       PyObject_NEW(TFE_Py_ForwardAccumulator, &TFE_Py_ForwardAccumulator_Type);
-  accumulator->accumulator = new ForwardAccumulator(py_vspace);
+  if (py_vspace == nullptr) {
+    MaybeRaiseExceptionFromStatus(
+        tensorflow::errors::Internal(
+            "ForwardAccumulator requires a PyVSpace to be registered."),
+        nullptr);
+  }
+  accumulator->accumulator = new ForwardAccumulator(*py_vspace);
   Py_INCREF(accumulator);
   GetAccumulatorSet()->insert(
       reinterpret_cast<TFE_Py_ForwardAccumulator*>(accumulator));
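The core of this change is converting `ForwardAccumulator`'s non-owning `vspace_` member from a pointer to a reference: the "required, never null" contract moves into the type itself, every call site changes from `vspace_->` to `vspace_.`, and the one remaining nullable handle is checked at the Python binding boundary. Below is a minimal standalone sketch of that pattern; `VSpace`, `Accumulator`, and `MakeAccumulator` here are illustrative stand-ins, not the TensorFlow types.

```cpp
#include <cstdio>

// Illustrative stand-in for the real VSpace interface (hypothetical).
struct VSpace {
  void DeleteGradient(int* gradient) const { delete gradient; }
};

// With a pointer member (`const VSpace* vspace_;`), nullptr is representable
// and every use is spelled `vspace_->`. A reference member says "required,
// non-owning, outlives this object" in the type itself.
class Accumulator {
 public:
  explicit Accumulator(const VSpace& vspace) : vspace_(vspace) {}
  ~Accumulator() { vspace_.DeleteGradient(gradient_); }  // `.`; no null check

  void Watch() { gradient_ = new int(0); }

 private:
  const VSpace& vspace_;  // Not owned; must outlive the Accumulator.
  int* gradient_ = nullptr;
};

// The nullable handle now exists only at the construction boundary, which
// mirrors the py_vspace check added in TFE_Py_ForwardAccumulatorNew above.
Accumulator* MakeAccumulator(const VSpace* maybe_vspace) {
  if (maybe_vspace == nullptr) {
    std::fprintf(stderr, "no VSpace registered\n");
    return nullptr;
  }
  return new Accumulator(*maybe_vspace);
}

int main() {
  VSpace vspace;
  Accumulator* accumulator = MakeAccumulator(&vspace);
  accumulator->Watch();
  delete accumulator;  // destructor uses vspace_ directly, no null check
}
```

The trade-off is that a reference member makes the class non-assignable and demands the dependency at construction time, which is exactly the contract the new `py_vspace == nullptr` guard enforces before `*py_vspace` is passed to the constructor.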