Switch FunctionLibraryDefinition to share FunctionDef storage with other instances.
We frequently waste memory during graph construction and optimization by copying FunctionLibraryDefinition objects, which entails a copy of all the FunctionDefs in the library. Since all FunctionDefs (and related op registration data) in the library are immutable after it is constructed, we can share them with other library instances. PiperOrigin-RevId: 258254464
This commit is contained in:
parent
b4f842384a
commit
00d3da0add
@ -1054,9 +1054,7 @@ FunctionLibraryDefinition::FunctionLibraryDefinition(
|
|||||||
const FunctionLibraryDefinition& other)
|
const FunctionLibraryDefinition& other)
|
||||||
: default_registry_(other.default_registry_) {
|
: default_registry_(other.default_registry_) {
|
||||||
tf_shared_lock l(other.mu_);
|
tf_shared_lock l(other.mu_);
|
||||||
for (const auto& it : other.function_defs_) {
|
function_defs_ = other.function_defs_;
|
||||||
TF_CHECK_OK(AddFunctionDef(it.second->fdef));
|
|
||||||
}
|
|
||||||
func_grad_ = other.func_grad_;
|
func_grad_ = other.func_grad_;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1084,16 +1082,16 @@ bool FunctionLibraryDefinition::Contains(const string& func) const {
|
|||||||
|
|
||||||
const FunctionDef* FunctionLibraryDefinition::Find(const string& func) const {
|
const FunctionDef* FunctionLibraryDefinition::Find(const string& func) const {
|
||||||
tf_shared_lock l(mu_);
|
tf_shared_lock l(mu_);
|
||||||
return FindHelper(func);
|
return &(FindHelper(func)->fdef);
|
||||||
}
|
}
|
||||||
|
|
||||||
const FunctionDef* FunctionLibraryDefinition::FindHelper(
|
std::shared_ptr<FunctionLibraryDefinition::FunctionDefAndOpRegistration>
|
||||||
const string& func) const {
|
FunctionLibraryDefinition::FindHelper(const string& func) const {
|
||||||
auto iter = function_defs_.find(func);
|
auto iter = function_defs_.find(func);
|
||||||
if (iter == function_defs_.end()) {
|
if (iter == function_defs_.end()) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
} else {
|
} else {
|
||||||
return &iter->second->fdef;
|
return iter->second;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1106,10 +1104,10 @@ Status FunctionLibraryDefinition::AddFunctionDef(const FunctionDef& fdef) {
|
|||||||
Status FunctionLibraryDefinition::AddFunctionDefHelper(const FunctionDef& fdef,
|
Status FunctionLibraryDefinition::AddFunctionDefHelper(const FunctionDef& fdef,
|
||||||
bool* added) {
|
bool* added) {
|
||||||
*added = false;
|
*added = false;
|
||||||
std::unique_ptr<FunctionDefAndOpRegistration>* entry =
|
std::shared_ptr<FunctionDefAndOpRegistration>& entry =
|
||||||
&function_defs_[fdef.signature().name()];
|
function_defs_[fdef.signature().name()];
|
||||||
if (*entry != nullptr) {
|
if (entry) {
|
||||||
if (!FunctionDefsEqual((*entry)->fdef, fdef)) {
|
if (!FunctionDefsEqual(entry->fdef, fdef)) {
|
||||||
return errors::InvalidArgument(
|
return errors::InvalidArgument(
|
||||||
"Cannot add function '", fdef.signature().name(),
|
"Cannot add function '", fdef.signature().name(),
|
||||||
"' because a different function with the same name already "
|
"' because a different function with the same name already "
|
||||||
@ -1124,11 +1122,74 @@ Status FunctionLibraryDefinition::AddFunctionDefHelper(const FunctionDef& fdef,
|
|||||||
"Cannot add function '", fdef.signature().name(),
|
"Cannot add function '", fdef.signature().name(),
|
||||||
"' because an op with the same name already exists.");
|
"' because an op with the same name already exists.");
|
||||||
}
|
}
|
||||||
entry->reset(new FunctionDefAndOpRegistration(fdef));
|
entry = std::make_shared<FunctionDefAndOpRegistration>(fdef);
|
||||||
*added = true;
|
*added = true;
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Status FunctionLibraryDefinition::AddHelper(
|
||||||
|
std::shared_ptr<FunctionDefAndOpRegistration> registration, bool* added) {
|
||||||
|
*added = false;
|
||||||
|
std::shared_ptr<FunctionDefAndOpRegistration>& entry =
|
||||||
|
function_defs_[registration->fdef.signature().name()];
|
||||||
|
if (entry) {
|
||||||
|
if (!FunctionDefsEqual(entry->fdef, registration->fdef)) {
|
||||||
|
return errors::InvalidArgument(
|
||||||
|
"Cannot add function '", registration->fdef.signature().name(),
|
||||||
|
"' because a different function with the same name already "
|
||||||
|
"exists.");
|
||||||
|
}
|
||||||
|
// Ignore duplicate FunctionDefs.
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
const OpDef* op_def;
|
||||||
|
if (default_registry_
|
||||||
|
->LookUpOpDef(registration->fdef.signature().name(), &op_def)
|
||||||
|
.ok()) {
|
||||||
|
return errors::InvalidArgument(
|
||||||
|
"Cannot add function '", registration->fdef.signature().name(),
|
||||||
|
"' because an op with the same name already exists.");
|
||||||
|
}
|
||||||
|
entry = std::move(registration);
|
||||||
|
*added = true;
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status FunctionLibraryDefinition::CopyFunctionDefFrom(
|
||||||
|
const string& func, const FunctionLibraryDefinition& other) {
|
||||||
|
if (default_registry_ != other.default_registry_) {
|
||||||
|
return errors::InvalidArgument(
|
||||||
|
"Cannot copy function '", func,
|
||||||
|
"' because CopyFunctionDefFrom() requires that both libraries have the "
|
||||||
|
"same default registry.");
|
||||||
|
}
|
||||||
|
std::shared_ptr<FunctionDefAndOpRegistration> function_def;
|
||||||
|
{
|
||||||
|
tf_shared_lock l(other.mu_);
|
||||||
|
function_def = other.FindHelper(func);
|
||||||
|
}
|
||||||
|
if (!function_def) {
|
||||||
|
return errors::InvalidArgument(
|
||||||
|
"Cannot copy function '", func,
|
||||||
|
"' because no function with that name exists in the other library.");
|
||||||
|
}
|
||||||
|
{
|
||||||
|
mutex_lock l(mu_);
|
||||||
|
std::shared_ptr<FunctionDefAndOpRegistration>& entry = function_defs_[func];
|
||||||
|
if (entry) {
|
||||||
|
if (!FunctionDefsEqual(entry->fdef, function_def->fdef)) {
|
||||||
|
return errors::InvalidArgument(
|
||||||
|
"Cannot copy function '", func,
|
||||||
|
"' because a different function with the same name already "
|
||||||
|
"exists.");
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
entry = std::move(function_def);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
Status FunctionLibraryDefinition::AddGradientDef(const GradientDef& grad) {
|
Status FunctionLibraryDefinition::AddGradientDef(const GradientDef& grad) {
|
||||||
mutex_lock l(mu_);
|
mutex_lock l(mu_);
|
||||||
bool added;
|
bool added;
|
||||||
@ -1167,7 +1228,7 @@ Status FunctionLibraryDefinition::AddLibrary(
|
|||||||
Status s;
|
Status s;
|
||||||
bool added;
|
bool added;
|
||||||
for (auto iter : clone.function_defs_) {
|
for (auto iter : clone.function_defs_) {
|
||||||
s = AddFunctionDefHelper(iter.second->fdef, &added);
|
s = AddHelper(iter.second, &added);
|
||||||
if (!s.ok()) {
|
if (!s.ok()) {
|
||||||
Remove(funcs, funcs_with_grads);
|
Remove(funcs, funcs_with_grads);
|
||||||
return s;
|
return s;
|
||||||
@ -1333,9 +1394,9 @@ const FunctionDef* FunctionLibraryDefinition::GetAttrImpl(
|
|||||||
// function's attrs to see if noinline is specified. Otherwise,
|
// function's attrs to see if noinline is specified. Otherwise,
|
||||||
// uses func's attrs.
|
// uses func's attrs.
|
||||||
if (!grad_name.empty()) {
|
if (!grad_name.empty()) {
|
||||||
return FindHelper(grad_name);
|
return &(FindHelper(grad_name)->fdef);
|
||||||
}
|
}
|
||||||
return FindHelper(func_name);
|
return &(FindHelper(func_name)->fdef);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1492,12 +1553,10 @@ FunctionLibraryDefinition ReachableFunctionLibraryDefinition(
|
|||||||
FunctionDefLibrary());
|
FunctionDefLibrary());
|
||||||
|
|
||||||
for (const string& func_name : reachable_funcs) {
|
for (const string& func_name : reachable_funcs) {
|
||||||
const FunctionDef* func = flib.Find(func_name);
|
// This should never fail, because we copy functions from a valid flib and
|
||||||
DCHECK_NE(func, nullptr);
|
// use the same default registry.
|
||||||
// That should never fail, because we copy functions from valid flib and use
|
Status added = reachable_flib.CopyFunctionDefFrom(func_name, flib);
|
||||||
// the same default registry.
|
TF_DCHECK_OK(added);
|
||||||
const Status added = reachable_flib.AddFunctionDef(*func);
|
|
||||||
DCHECK(added.ok());
|
|
||||||
|
|
||||||
const string grad_func_name = flib.FindGradient(func_name);
|
const string grad_func_name = flib.FindGradient(func_name);
|
||||||
if (!grad_func_name.empty()) {
|
if (!grad_func_name.empty()) {
|
||||||
@ -1506,7 +1565,7 @@ FunctionLibraryDefinition ReachableFunctionLibraryDefinition(
|
|||||||
grad.set_gradient_func(grad_func_name);
|
grad.set_gradient_func(grad_func_name);
|
||||||
// It can only fail if function already has a gradient function.
|
// It can only fail if function already has a gradient function.
|
||||||
const Status added_grad = reachable_flib.AddGradientDef(grad);
|
const Status added_grad = reachable_flib.AddGradientDef(grad);
|
||||||
DCHECK(added_grad.ok());
|
TF_DCHECK_OK(added_grad);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -442,21 +442,35 @@ class FunctionLibraryDefinition : public OpRegistryInterface {
|
|||||||
FunctionLibraryDefinition ReachableDefinitions(const GraphDef& graph) const;
|
FunctionLibraryDefinition ReachableDefinitions(const GraphDef& graph) const;
|
||||||
FunctionLibraryDefinition ReachableDefinitions(const FunctionDef& func) const;
|
FunctionLibraryDefinition ReachableDefinitions(const FunctionDef& func) const;
|
||||||
|
|
||||||
|
// Copies the function named `func` from `other` to this
|
||||||
|
// FunctionLibraryDefinition.
|
||||||
|
// REQUIRES: `this->default_registry() == other.default_registry()`.
|
||||||
|
// Returns OK on success, or error otherwise. This is a no-op if a function
|
||||||
|
// name `func` already exists in this function library, and has the same
|
||||||
|
// implementation as in `other`. If the implementations conflict, an invalid
|
||||||
|
// argument error is returned.
|
||||||
|
Status CopyFunctionDefFrom(const string& func,
|
||||||
|
const FunctionLibraryDefinition& other)
|
||||||
|
LOCKS_EXCLUDED(mu_);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// Shape inference for functions is handled separately by ShapeRefiner.
|
// Shape inference for functions is handled separately by ShapeRefiner.
|
||||||
|
|
||||||
struct FunctionDefAndOpRegistration {
|
struct FunctionDefAndOpRegistration {
|
||||||
explicit FunctionDefAndOpRegistration(const FunctionDef& fdef_in);
|
explicit FunctionDefAndOpRegistration(const FunctionDef& fdef_in);
|
||||||
|
|
||||||
FunctionDef fdef;
|
const FunctionDef fdef;
|
||||||
OpRegistrationData op_registration_data;
|
const OpRegistrationData op_registration_data;
|
||||||
};
|
};
|
||||||
|
|
||||||
const FunctionDef* FindHelper(const string& func) const
|
std::shared_ptr<FunctionDefAndOpRegistration> FindHelper(
|
||||||
SHARED_LOCKS_REQUIRED(mu_);
|
const string& func) const SHARED_LOCKS_REQUIRED(mu_);
|
||||||
string FindGradientHelper(const string& func) const
|
string FindGradientHelper(const string& func) const
|
||||||
SHARED_LOCKS_REQUIRED(mu_);
|
SHARED_LOCKS_REQUIRED(mu_);
|
||||||
|
|
||||||
|
Status AddHelper(std::shared_ptr<FunctionDefAndOpRegistration> registration,
|
||||||
|
bool* added) EXCLUSIVE_LOCKS_REQUIRED(mu_);
|
||||||
|
|
||||||
// Same as AddFunctionDef/AddGradientDef except these methods set
|
// Same as AddFunctionDef/AddGradientDef except these methods set
|
||||||
// `added` to true if the `fdef`/`grad` were actually added to this.
|
// `added` to true if the `fdef`/`grad` were actually added to this.
|
||||||
Status AddFunctionDefHelper(const FunctionDef& fdef, bool* added)
|
Status AddFunctionDefHelper(const FunctionDef& fdef, bool* added)
|
||||||
@ -485,7 +499,7 @@ class FunctionLibraryDefinition : public OpRegistryInterface {
|
|||||||
|
|
||||||
mutable mutex mu_;
|
mutable mutex mu_;
|
||||||
const OpRegistryInterface* const default_registry_;
|
const OpRegistryInterface* const default_registry_;
|
||||||
gtl::FlatMap<string, std::unique_ptr<FunctionDefAndOpRegistration>>
|
gtl::FlatMap<string, std::shared_ptr<FunctionDefAndOpRegistration>>
|
||||||
function_defs_ GUARDED_BY(mu_);
|
function_defs_ GUARDED_BY(mu_);
|
||||||
gtl::FlatMap<string, string> func_grad_ GUARDED_BY(mu_);
|
gtl::FlatMap<string, string> func_grad_ GUARDED_BY(mu_);
|
||||||
};
|
};
|
||||||
|
Loading…
Reference in New Issue
Block a user