From 00d3da0add1b525084750130c93b209b1d8bf051 Mon Sep 17 00:00:00 2001 From: Derek Murray Date: Mon, 15 Jul 2019 15:55:09 -0700 Subject: [PATCH] Switch FunctionLibraryDefinition to share FunctionDef storage with other instances. We frequently waste memory during graph construction and optimization by copying FunctionLibraryDefinition objects, which entails a copy of all the FunctionDefs in the library. Since all FunctionDefs (and related op registration data) in the library are immutable after it is constructed, we can share them with other library instances. PiperOrigin-RevId: 258254464 --- tensorflow/core/framework/function.cc | 103 ++++++++++++++++++++------ tensorflow/core/framework/function.h | 24 ++++-- 2 files changed, 100 insertions(+), 27 deletions(-) diff --git a/tensorflow/core/framework/function.cc b/tensorflow/core/framework/function.cc index 1cb2b190d05..1643fc0d215 100644 --- a/tensorflow/core/framework/function.cc +++ b/tensorflow/core/framework/function.cc @@ -1054,9 +1054,7 @@ FunctionLibraryDefinition::FunctionLibraryDefinition( const FunctionLibraryDefinition& other) : default_registry_(other.default_registry_) { tf_shared_lock l(other.mu_); - for (const auto& it : other.function_defs_) { - TF_CHECK_OK(AddFunctionDef(it.second->fdef)); - } + function_defs_ = other.function_defs_; func_grad_ = other.func_grad_; } @@ -1084,16 +1082,16 @@ bool FunctionLibraryDefinition::Contains(const string& func) const { const FunctionDef* FunctionLibraryDefinition::Find(const string& func) const { tf_shared_lock l(mu_); - return FindHelper(func); + return &(FindHelper(func)->fdef); } -const FunctionDef* FunctionLibraryDefinition::FindHelper( - const string& func) const { +std::shared_ptr +FunctionLibraryDefinition::FindHelper(const string& func) const { auto iter = function_defs_.find(func); if (iter == function_defs_.end()) { return nullptr; } else { - return &iter->second->fdef; + return iter->second; } } @@ -1106,10 +1104,10 @@ Status 
FunctionLibraryDefinition::AddFunctionDef(const FunctionDef& fdef) { Status FunctionLibraryDefinition::AddFunctionDefHelper(const FunctionDef& fdef, bool* added) { *added = false; - std::unique_ptr* entry = - &function_defs_[fdef.signature().name()]; - if (*entry != nullptr) { - if (!FunctionDefsEqual((*entry)->fdef, fdef)) { + std::shared_ptr& entry = + function_defs_[fdef.signature().name()]; + if (entry) { + if (!FunctionDefsEqual(entry->fdef, fdef)) { return errors::InvalidArgument( "Cannot add function '", fdef.signature().name(), "' because a different function with the same name already " @@ -1124,11 +1122,74 @@ Status FunctionLibraryDefinition::AddFunctionDefHelper(const FunctionDef& fdef, "Cannot add function '", fdef.signature().name(), "' because an op with the same name already exists."); } - entry->reset(new FunctionDefAndOpRegistration(fdef)); + entry = std::make_shared(fdef); *added = true; return Status::OK(); } +Status FunctionLibraryDefinition::AddHelper( + std::shared_ptr registration, bool* added) { + *added = false; + std::shared_ptr& entry = + function_defs_[registration->fdef.signature().name()]; + if (entry) { + if (!FunctionDefsEqual(entry->fdef, registration->fdef)) { + return errors::InvalidArgument( + "Cannot add function '", registration->fdef.signature().name(), + "' because a different function with the same name already " + "exists."); + } + // Ignore duplicate FunctionDefs. 
+ return Status::OK(); + } + const OpDef* op_def; + if (default_registry_ + ->LookUpOpDef(registration->fdef.signature().name(), &op_def) + .ok()) { + return errors::InvalidArgument( + "Cannot add function '", registration->fdef.signature().name(), + "' because an op with the same name already exists."); + } + entry = std::move(registration); + *added = true; + return Status::OK(); +} + +Status FunctionLibraryDefinition::CopyFunctionDefFrom( + const string& func, const FunctionLibraryDefinition& other) { + if (default_registry_ != other.default_registry_) { + return errors::InvalidArgument( + "Cannot copy function '", func, + "' because CopyFunctionDefFrom() requires that both libraries have the " + "same default registry."); + } + std::shared_ptr function_def; + { + tf_shared_lock l(other.mu_); + function_def = other.FindHelper(func); + } + if (!function_def) { + return errors::InvalidArgument( + "Cannot copy function '", func, + "' because no function with that name exists in the other library."); + } + { + mutex_lock l(mu_); + std::shared_ptr& entry = function_defs_[func]; + if (entry) { + if (!FunctionDefsEqual(entry->fdef, function_def->fdef)) { + return errors::InvalidArgument( + "Cannot copy function '", func, + "' because a different function with the same name already " + "exists."); + } + } else { + entry = std::move(function_def); + } + } + return Status::OK(); +} + Status FunctionLibraryDefinition::AddGradientDef(const GradientDef& grad) { mutex_lock l(mu_); bool added; @@ -1167,7 +1228,7 @@ Status FunctionLibraryDefinition::AddLibrary( Status s; bool added; for (auto iter : clone.function_defs_) { - s = AddFunctionDefHelper(iter.second->fdef, &added); + s = AddHelper(iter.second, &added); if (!s.ok()) { Remove(funcs, funcs_with_grads); return s; @@ -1333,9 +1394,9 @@ const FunctionDef* FunctionLibraryDefinition::GetAttrImpl( // function's attrs to see if noinline is specified. Otherwise, // uses func's attrs. 
if (!grad_name.empty()) { - return FindHelper(grad_name); + return &(FindHelper(grad_name)->fdef); } - return FindHelper(func_name); + return &(FindHelper(func_name)->fdef); } } @@ -1492,12 +1553,10 @@ FunctionLibraryDefinition ReachableFunctionLibraryDefinition( FunctionDefLibrary()); for (const string& func_name : reachable_funcs) { - const FunctionDef* func = flib.Find(func_name); - DCHECK_NE(func, nullptr); - // That should never fail, because we copy functions from valid flib and use - // the same default registry. - const Status added = reachable_flib.AddFunctionDef(*func); - DCHECK(added.ok()); + // This should never fail, because we copy functions from a valid flib and + // use the same default registry. + Status added = reachable_flib.CopyFunctionDefFrom(func_name, flib); + TF_DCHECK_OK(added); const string grad_func_name = flib.FindGradient(func_name); if (!grad_func_name.empty()) { @@ -1506,7 +1565,7 @@ FunctionLibraryDefinition ReachableFunctionLibraryDefinition( grad.set_gradient_func(grad_func_name); // It can only fail if function already has a gradient function. const Status added_grad = reachable_flib.AddGradientDef(grad); - DCHECK(added_grad.ok()); + TF_DCHECK_OK(added_grad); } } diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h index ddcd990d9bd..a28d85e33f8 100644 --- a/tensorflow/core/framework/function.h +++ b/tensorflow/core/framework/function.h @@ -442,21 +442,35 @@ class FunctionLibraryDefinition : public OpRegistryInterface { FunctionLibraryDefinition ReachableDefinitions(const GraphDef& graph) const; FunctionLibraryDefinition ReachableDefinitions(const FunctionDef& func) const; + // Copies the function named `func` from `other` to this + // FunctionLibraryDefinition. + // REQUIRES: `this->default_registry() == other.default_registry()`. + // Returns OK on success, or error otherwise. 
This is a no-op if a function + // named `func` already exists in this function library, and has the same + // implementation as in `other`. If the implementations conflict, an invalid + // argument error is returned. + Status CopyFunctionDefFrom(const string& func, + const FunctionLibraryDefinition& other) + LOCKS_EXCLUDED(mu_); + private: // Shape inference for functions is handled separately by ShapeRefiner. struct FunctionDefAndOpRegistration { explicit FunctionDefAndOpRegistration(const FunctionDef& fdef_in); - FunctionDef fdef; - OpRegistrationData op_registration_data; + const FunctionDef fdef; + const OpRegistrationData op_registration_data; }; - const FunctionDef* FindHelper(const string& func) const - SHARED_LOCKS_REQUIRED(mu_); + std::shared_ptr FindHelper( + const string& func) const SHARED_LOCKS_REQUIRED(mu_); string FindGradientHelper(const string& func) const SHARED_LOCKS_REQUIRED(mu_); + Status AddHelper(std::shared_ptr registration, + bool* added) EXCLUSIVE_LOCKS_REQUIRED(mu_); + // Same as AddFunctionDef/AddGradientDef except these methods set // `added` to true if the `fdef`/`grad` were actually added to this. Status AddFunctionDefHelper(const FunctionDef& fdef, bool* added) @@ -485,7 +499,7 @@ class FunctionLibraryDefinition : public OpRegistryInterface { mutable mutex mu_; const OpRegistryInterface* const default_registry_; - gtl::FlatMap> + gtl::FlatMap> function_defs_ GUARDED_BY(mu_); gtl::FlatMap func_grad_ GUARDED_BY(mu_); };