From c5e13d84ca6da4bfc6a190c649f35e7c15154121 Mon Sep 17 00:00:00 2001 From: Jing Pu Date: Wed, 26 Aug 2020 23:25:02 -0700 Subject: [PATCH] Add a "Clear" method in FunctionLibraryDefinition. PiperOrigin-RevId: 328682257 Change-Id: Ib1db2ed042f30218f7d6ced101561d07774300be --- tensorflow/core/framework/function.cc | 6 ++++++ tensorflow/core/framework/function.h | 3 +++ tensorflow/core/framework/function_test.cc | 10 ++++++++++ 3 files changed, 19 insertions(+) diff --git a/tensorflow/core/framework/function.cc b/tensorflow/core/framework/function.cc index ebf06c7d0cd..564290bcb21 100644 --- a/tensorflow/core/framework/function.cc +++ b/tensorflow/core/framework/function.cc @@ -1454,6 +1454,12 @@ Status FunctionLibraryDefinition::RemoveFunctionHelper(const string& func) { return Status::OK(); } +void FunctionLibraryDefinition::Clear() { + mutex_lock l(mu_); + function_defs_.clear(); + func_grad_.clear(); +} + Status FunctionLibraryDefinition::RemoveGradient(const string& func) { const auto& i = func_grad_.find(func); if (i == func_grad_.end()) { diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h index 3c7c09eee37..3c048161b7d 100644 --- a/tensorflow/core/framework/function.h +++ b/tensorflow/core/framework/function.h @@ -403,6 +403,9 @@ class FunctionLibraryDefinition : public OpRegistryInterface { // are no longer in use. Status RemoveFunction(const std::string& func) TF_LOCKS_EXCLUDED(mu_); + // Removes all the functions and gradient functions. + void Clear() TF_LOCKS_EXCLUDED(mu_); + // Adds the functions and gradients in 'other' to this function library. // Duplicate functions and gradients are ignored. // This operation is atomic. diff --git a/tensorflow/core/framework/function_test.cc b/tensorflow/core/framework/function_test.cc index a62acfe571e..38ab8be291d 100644 --- a/tensorflow/core/framework/function_test.cc +++ b/tensorflow/core/framework/function_test.cc @@ -1068,6 +1068,16 @@ TEST(FunctionLibraryDefinitionTest, RemoveFunction) { EXPECT_FALSE(lib_def.Contains("XTimesTwo")); } +TEST(FunctionLibraryDefinitionTest, Clear) { + FunctionLibraryDefinition lib_def(OpRegistry::Global(), {}); + TF_CHECK_OK(lib_def.AddFunctionDef(test::function::XTimesTwo())); + TF_CHECK_OK(lib_def.AddFunctionDef(test::function::XAddX())); + + lib_def.Clear(); + EXPECT_FALSE(lib_def.Contains("XTimesTwo")); + EXPECT_FALSE(lib_def.Contains("XAddX")); +} + TEST(FunctionLibraryDefinitionTest, AddLibrary) { // Create lib def with single function FunctionDefLibrary proto;