Switch FunctionLibraryDefinition to share FunctionDef storage with other instances.
We frequently waste memory during graph construction and optimization by copying FunctionLibraryDefinition objects, which entails a copy of all the FunctionDefs in the library. Since all FunctionDefs (and related op registration data) in the library are immutable after it is constructed, we can share them with other library instances. PiperOrigin-RevId: 258254464
This commit is contained in:
parent
b4f842384a
commit
00d3da0add
@ -1054,9 +1054,7 @@ FunctionLibraryDefinition::FunctionLibraryDefinition(
|
||||
const FunctionLibraryDefinition& other)
|
||||
: default_registry_(other.default_registry_) {
|
||||
tf_shared_lock l(other.mu_);
|
||||
for (const auto& it : other.function_defs_) {
|
||||
TF_CHECK_OK(AddFunctionDef(it.second->fdef));
|
||||
}
|
||||
function_defs_ = other.function_defs_;
|
||||
func_grad_ = other.func_grad_;
|
||||
}
|
||||
|
||||
@ -1084,16 +1082,16 @@ bool FunctionLibraryDefinition::Contains(const string& func) const {
|
||||
|
||||
const FunctionDef* FunctionLibraryDefinition::Find(const string& func) const {
|
||||
tf_shared_lock l(mu_);
|
||||
return FindHelper(func);
|
||||
return &(FindHelper(func)->fdef);
|
||||
}
|
||||
|
||||
const FunctionDef* FunctionLibraryDefinition::FindHelper(
|
||||
const string& func) const {
|
||||
std::shared_ptr<FunctionLibraryDefinition::FunctionDefAndOpRegistration>
|
||||
FunctionLibraryDefinition::FindHelper(const string& func) const {
|
||||
auto iter = function_defs_.find(func);
|
||||
if (iter == function_defs_.end()) {
|
||||
return nullptr;
|
||||
} else {
|
||||
return &iter->second->fdef;
|
||||
return iter->second;
|
||||
}
|
||||
}
|
||||
|
||||
@ -1106,10 +1104,10 @@ Status FunctionLibraryDefinition::AddFunctionDef(const FunctionDef& fdef) {
|
||||
Status FunctionLibraryDefinition::AddFunctionDefHelper(const FunctionDef& fdef,
|
||||
bool* added) {
|
||||
*added = false;
|
||||
std::unique_ptr<FunctionDefAndOpRegistration>* entry =
|
||||
&function_defs_[fdef.signature().name()];
|
||||
if (*entry != nullptr) {
|
||||
if (!FunctionDefsEqual((*entry)->fdef, fdef)) {
|
||||
std::shared_ptr<FunctionDefAndOpRegistration>& entry =
|
||||
function_defs_[fdef.signature().name()];
|
||||
if (entry) {
|
||||
if (!FunctionDefsEqual(entry->fdef, fdef)) {
|
||||
return errors::InvalidArgument(
|
||||
"Cannot add function '", fdef.signature().name(),
|
||||
"' because a different function with the same name already "
|
||||
@ -1124,11 +1122,74 @@ Status FunctionLibraryDefinition::AddFunctionDefHelper(const FunctionDef& fdef,
|
||||
"Cannot add function '", fdef.signature().name(),
|
||||
"' because an op with the same name already exists.");
|
||||
}
|
||||
entry->reset(new FunctionDefAndOpRegistration(fdef));
|
||||
entry = std::make_shared<FunctionDefAndOpRegistration>(fdef);
|
||||
*added = true;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status FunctionLibraryDefinition::AddHelper(
|
||||
std::shared_ptr<FunctionDefAndOpRegistration> registration, bool* added) {
|
||||
*added = false;
|
||||
std::shared_ptr<FunctionDefAndOpRegistration>& entry =
|
||||
function_defs_[registration->fdef.signature().name()];
|
||||
if (entry) {
|
||||
if (!FunctionDefsEqual(entry->fdef, registration->fdef)) {
|
||||
return errors::InvalidArgument(
|
||||
"Cannot add function '", registration->fdef.signature().name(),
|
||||
"' because a different function with the same name already "
|
||||
"exists.");
|
||||
}
|
||||
// Ignore duplicate FunctionDefs.
|
||||
return Status::OK();
|
||||
}
|
||||
const OpDef* op_def;
|
||||
if (default_registry_
|
||||
->LookUpOpDef(registration->fdef.signature().name(), &op_def)
|
||||
.ok()) {
|
||||
return errors::InvalidArgument(
|
||||
"Cannot add function '", registration->fdef.signature().name(),
|
||||
"' because an op with the same name already exists.");
|
||||
}
|
||||
entry = std::move(registration);
|
||||
*added = true;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status FunctionLibraryDefinition::CopyFunctionDefFrom(
|
||||
const string& func, const FunctionLibraryDefinition& other) {
|
||||
if (default_registry_ != other.default_registry_) {
|
||||
return errors::InvalidArgument(
|
||||
"Cannot copy function '", func,
|
||||
"' because CopyFunctionDefFrom() requires that both libraries have the "
|
||||
"same default registry.");
|
||||
}
|
||||
std::shared_ptr<FunctionDefAndOpRegistration> function_def;
|
||||
{
|
||||
tf_shared_lock l(other.mu_);
|
||||
function_def = other.FindHelper(func);
|
||||
}
|
||||
if (!function_def) {
|
||||
return errors::InvalidArgument(
|
||||
"Cannot copy function '", func,
|
||||
"' because no function with that name exists in the other library.");
|
||||
}
|
||||
{
|
||||
mutex_lock l(mu_);
|
||||
std::shared_ptr<FunctionDefAndOpRegistration>& entry = function_defs_[func];
|
||||
if (entry) {
|
||||
if (!FunctionDefsEqual(entry->fdef, function_def->fdef)) {
|
||||
return errors::InvalidArgument(
|
||||
"Cannot copy function '", func,
|
||||
"' because a different function with the same name already "
|
||||
"exists.");
|
||||
}
|
||||
} else {
|
||||
entry = std::move(function_def);
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status FunctionLibraryDefinition::AddGradientDef(const GradientDef& grad) {
|
||||
mutex_lock l(mu_);
|
||||
bool added;
|
||||
@ -1167,7 +1228,7 @@ Status FunctionLibraryDefinition::AddLibrary(
|
||||
Status s;
|
||||
bool added;
|
||||
for (auto iter : clone.function_defs_) {
|
||||
s = AddFunctionDefHelper(iter.second->fdef, &added);
|
||||
s = AddHelper(iter.second, &added);
|
||||
if (!s.ok()) {
|
||||
Remove(funcs, funcs_with_grads);
|
||||
return s;
|
||||
@ -1333,9 +1394,9 @@ const FunctionDef* FunctionLibraryDefinition::GetAttrImpl(
|
||||
// function's attrs to see if noinline is specified. Otherwise,
|
||||
// uses func's attrs.
|
||||
if (!grad_name.empty()) {
|
||||
return FindHelper(grad_name);
|
||||
return &(FindHelper(grad_name)->fdef);
|
||||
}
|
||||
return FindHelper(func_name);
|
||||
return &(FindHelper(func_name)->fdef);
|
||||
}
|
||||
}
|
||||
|
||||
@ -1492,12 +1553,10 @@ FunctionLibraryDefinition ReachableFunctionLibraryDefinition(
|
||||
FunctionDefLibrary());
|
||||
|
||||
for (const string& func_name : reachable_funcs) {
|
||||
const FunctionDef* func = flib.Find(func_name);
|
||||
DCHECK_NE(func, nullptr);
|
||||
// That should never fail, because we copy functions from valid flib and use
|
||||
// the same default registry.
|
||||
const Status added = reachable_flib.AddFunctionDef(*func);
|
||||
DCHECK(added.ok());
|
||||
// This should never fail, because we copy functions from a valid flib and
|
||||
// use the same default registry.
|
||||
Status added = reachable_flib.CopyFunctionDefFrom(func_name, flib);
|
||||
TF_DCHECK_OK(added);
|
||||
|
||||
const string grad_func_name = flib.FindGradient(func_name);
|
||||
if (!grad_func_name.empty()) {
|
||||
@ -1506,7 +1565,7 @@ FunctionLibraryDefinition ReachableFunctionLibraryDefinition(
|
||||
grad.set_gradient_func(grad_func_name);
|
||||
// It can only fail if function already has a gradient function.
|
||||
const Status added_grad = reachable_flib.AddGradientDef(grad);
|
||||
DCHECK(added_grad.ok());
|
||||
TF_DCHECK_OK(added_grad);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -442,21 +442,35 @@ class FunctionLibraryDefinition : public OpRegistryInterface {
|
||||
FunctionLibraryDefinition ReachableDefinitions(const GraphDef& graph) const;
|
||||
FunctionLibraryDefinition ReachableDefinitions(const FunctionDef& func) const;
|
||||
|
||||
// Copies the function named `func` from `other` to this
|
||||
// FunctionLibraryDefinition.
|
||||
// REQUIRES: `this->default_registry() == other.default_registry()`.
|
||||
// Returns OK on success, or error otherwise. This is a no-op if a function
|
||||
// name `func` already exists in this function library, and has the same
|
||||
// implementation as in `other`. If the implementations conflict, an invalid
|
||||
// argument error is returned.
|
||||
Status CopyFunctionDefFrom(const string& func,
|
||||
const FunctionLibraryDefinition& other)
|
||||
LOCKS_EXCLUDED(mu_);
|
||||
|
||||
private:
|
||||
// Shape inference for functions is handled separately by ShapeRefiner.
|
||||
|
||||
struct FunctionDefAndOpRegistration {
|
||||
explicit FunctionDefAndOpRegistration(const FunctionDef& fdef_in);
|
||||
|
||||
FunctionDef fdef;
|
||||
OpRegistrationData op_registration_data;
|
||||
const FunctionDef fdef;
|
||||
const OpRegistrationData op_registration_data;
|
||||
};
|
||||
|
||||
const FunctionDef* FindHelper(const string& func) const
|
||||
SHARED_LOCKS_REQUIRED(mu_);
|
||||
std::shared_ptr<FunctionDefAndOpRegistration> FindHelper(
|
||||
const string& func) const SHARED_LOCKS_REQUIRED(mu_);
|
||||
string FindGradientHelper(const string& func) const
|
||||
SHARED_LOCKS_REQUIRED(mu_);
|
||||
|
||||
Status AddHelper(std::shared_ptr<FunctionDefAndOpRegistration> registration,
|
||||
bool* added) EXCLUSIVE_LOCKS_REQUIRED(mu_);
|
||||
|
||||
// Same as AddFunctionDef/AddGradientDef except these methods set
|
||||
// `added` to true if the `fdef`/`grad` were actually added to this.
|
||||
Status AddFunctionDefHelper(const FunctionDef& fdef, bool* added)
|
||||
@ -485,7 +499,7 @@ class FunctionLibraryDefinition : public OpRegistryInterface {
|
||||
|
||||
mutable mutex mu_;
|
||||
const OpRegistryInterface* const default_registry_;
|
||||
gtl::FlatMap<string, std::unique_ptr<FunctionDefAndOpRegistration>>
|
||||
gtl::FlatMap<string, std::shared_ptr<FunctionDefAndOpRegistration>>
|
||||
function_defs_ GUARDED_BY(mu_);
|
||||
gtl::FlatMap<string, string> func_grad_ GUARDED_BY(mu_);
|
||||
};
|
||||
|
Loading…
Reference in New Issue
Block a user