Add a "Clear" method in FunctionLibraryDefinition.

PiperOrigin-RevId: 328682257
Change-Id: Ib1db2ed042f30218f7d6ced101561d07774300be
This commit is contained in:
Jing Pu 2020-08-26 23:25:02 -07:00 committed by TensorFlower Gardener
parent c64796cbba
commit c5e13d84ca
3 changed files with 19 additions and 0 deletions

View File

@ -1454,6 +1454,12 @@ Status FunctionLibraryDefinition::RemoveFunctionHelper(const string& func) {
return Status::OK();
}
void FunctionLibraryDefinition::Clear() {
mutex_lock l(mu_);
function_defs_.clear();
func_grad_.clear();
}
Status FunctionLibraryDefinition::RemoveGradient(const string& func) {
const auto& i = func_grad_.find(func);
if (i == func_grad_.end()) {

View File

@ -403,6 +403,9 @@ class FunctionLibraryDefinition : public OpRegistryInterface {
// are no longer in use.
Status RemoveFunction(const std::string& func) TF_LOCKS_EXCLUDED(mu_);
// Removes all the functions and gradient functions.
void Clear() TF_LOCKS_EXCLUDED(mu_);
// Adds the functions and gradients in 'other' to this function library.
// Duplicate functions and gradients are ignored.
// This operation is atomic.

View File

@ -1068,6 +1068,16 @@ TEST(FunctionLibraryDefinitionTest, RemoveFunction) {
EXPECT_FALSE(lib_def.Contains("XTimesTwo"));
}
TEST(FunctionLibraryDefinitionTest, Clear) {
FunctionLibraryDefinition lib_def(OpRegistry::Global(), {});
TF_CHECK_OK(lib_def.AddFunctionDef(test::function::XTimesTwo()));
TF_CHECK_OK(lib_def.AddFunctionDef(test::function::XAddX()));
lib_def.Clear();
EXPECT_FALSE(lib_def.Contains("XTimesTwo"));
EXPECT_FALSE(lib_def.Contains("XAddX"));
}
TEST(FunctionLibraryDefinitionTest, AddLibrary) {
// Create lib def with single function
FunctionDefLibrary proto;