diff --git a/tensorflow/core/framework/function.cc b/tensorflow/core/framework/function.cc index ebf06c7d0cd..564290bcb21 100644 --- a/tensorflow/core/framework/function.cc +++ b/tensorflow/core/framework/function.cc @@ -1454,6 +1454,12 @@ Status FunctionLibraryDefinition::RemoveFunctionHelper(const string& func) { return Status::OK(); } +void FunctionLibraryDefinition::Clear() { + mutex_lock l(mu_); + function_defs_.clear(); + func_grad_.clear(); +} + Status FunctionLibraryDefinition::RemoveGradient(const string& func) { const auto& i = func_grad_.find(func); if (i == func_grad_.end()) { diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h index 3c7c09eee37..3c048161b7d 100644 --- a/tensorflow/core/framework/function.h +++ b/tensorflow/core/framework/function.h @@ -403,6 +403,9 @@ class FunctionLibraryDefinition : public OpRegistryInterface { // are no longer in use. Status RemoveFunction(const std::string& func) TF_LOCKS_EXCLUDED(mu_); + // Removes all the functions and gradient functions. + void Clear() TF_LOCKS_EXCLUDED(mu_); + // Adds the functions and gradients in 'other' to this function library. // Duplicate functions and gradients are ignored. // This operation is atomic. diff --git a/tensorflow/core/framework/function_test.cc b/tensorflow/core/framework/function_test.cc index a62acfe571e..38ab8be291d 100644 --- a/tensorflow/core/framework/function_test.cc +++ b/tensorflow/core/framework/function_test.cc @@ -1068,6 +1068,16 @@ TEST(FunctionLibraryDefinitionTest, RemoveFunction) { EXPECT_FALSE(lib_def.Contains("XTimesTwo")); } +TEST(FunctionLibraryDefinitionTest, Clear) { + FunctionLibraryDefinition lib_def(OpRegistry::Global(), {}); + TF_CHECK_OK(lib_def.AddFunctionDef(test::function::XTimesTwo())); + TF_CHECK_OK(lib_def.AddFunctionDef(test::function::XAddX())); + + lib_def.Clear(); + EXPECT_FALSE(lib_def.Contains("XTimesTwo")); + EXPECT_FALSE(lib_def.Contains("XAddX")); +} + TEST(FunctionLibraryDefinitionTest, AddLibrary) { // Create lib def with single function FunctionDefLibrary proto;