diff --git a/tensorflow/core/framework/function.cc b/tensorflow/core/framework/function.cc index 9d43aab5a5d..c6a20bf3ce6 100644 --- a/tensorflow/core/framework/function.cc +++ b/tensorflow/core/framework/function.cc @@ -724,6 +724,23 @@ string DebugStringWhole(const GraphDef& gdef) { return ret; } +namespace { + +// Returns the name -> attr mapping of fdef's attrs that have a value set. In +// Python, it's possible to access unset attrs, which returns a default value +// and adds an unset attr to the map. +std::map GetSetAttrs(const FunctionDef& fdef) { + std::map set_attrs; + for (auto iter : fdef.attr()) { + if (iter.second.value_case() != AttrValue::VALUE_NOT_SET) { + set_attrs[iter.first] = iter.second; + } + } + return set_attrs; +} + +} // end namespace + bool FunctionDefsEqual(const FunctionDef& f1, const FunctionDef& f2) { // NOTE(skyewm): Using MessageDifferencer would be better here, but that is // currently not included in tensorflow/core/platform/default/protobuf.h, so @@ -736,10 +753,12 @@ bool FunctionDefsEqual(const FunctionDef& f1, const FunctionDef& f2) { f2.signature().SerializeToString(&sig2); if (sig1 != sig2) return false; - if (f1.attr().size() != f2.attr().size()) return false; - for (auto iter1 : f1.attr()) { - auto iter2 = f2.attr().find(iter1.first); - if (iter2 == f2.attr().end()) return false; + std::map f1_attrs = GetSetAttrs(f1); + std::map f2_attrs = GetSetAttrs(f2); + if (f1_attrs.size() != f2_attrs.size()) return false; + for (auto iter1 : f1_attrs) { + auto iter2 = f2_attrs.find(iter1.first); + if (iter2 == f2_attrs.end()) return false; if (!AreAttrValuesEqual(iter1.second, iter2->second)) return false; } @@ -883,11 +902,17 @@ const FunctionDef* FunctionLibraryDefinition::Find(const string& name) const { } Status FunctionLibraryDefinition::AddFunctionDef(const FunctionDef& fdef) { - auto& ptr = function_defs_[fdef.signature().name()]; - if (ptr != nullptr) { - return errors::InvalidArgument("Function with name: ", - fdef.signature().name(), - " already exists in function library."); + std::unique_ptr* entry = + &function_defs_[fdef.signature().name()]; + if (*entry != nullptr) { + if (!FunctionDefsEqual((*entry)->fdef, fdef)) { + return errors::InvalidArgument( + "Cannot add function '", fdef.signature().name(), + "' because a different function with the same name already " + "exists."); + } + // Ignore duplicate FunctionDefs + return Status::OK(); } const OpDef* op_def; if (default_registry_->LookUpOpDef(fdef.signature().name(), &op_def).ok()) { @@ -895,19 +920,27 @@ Status FunctionLibraryDefinition::AddFunctionDef(const FunctionDef& fdef) { "Cannot add function '", fdef.signature().name(), "' because an op with the same name already exists."); } - ptr.reset(new FunctionDefAndOpRegistration(fdef)); + entry->reset(new FunctionDefAndOpRegistration(fdef)); return Status::OK(); } Status FunctionLibraryDefinition::AddGradientDef(const GradientDef& grad) { - if (func_grad_.count(grad.function_name()) > 0) { - return errors::InvalidArgument("Gradient for function '", - grad.function_name(), "' already exists."); + string* entry = &func_grad_[grad.function_name()]; + if (!entry->empty()) { + if (*entry != grad.gradient_func()) { + return errors::InvalidArgument( + "Cannot assign gradient function '", grad.gradient_func(), "' to '", + grad.function_name(), "' because it already has gradient function ", + "'", *entry, "'"); + } + // Ignore duplicate GradientDefs + return Status::OK(); } - func_grad_[grad.function_name()] = grad.gradient_func(); + *entry = grad.gradient_func(); return Status::OK(); } +// TODO(skyewm): don't modify FunctionLibraryDefinition in case of error Status FunctionLibraryDefinition::AddLibrary( const FunctionLibraryDefinition& other) { for (auto iter : other.function_defs_) { diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h index b8d5b8797af..2342e08b383 100644 --- a/tensorflow/core/framework/function.h +++ b/tensorflow/core/framework/function.h @@ -287,20 +287,24 @@ class FunctionLibraryDefinition : public OpRegistryInterface { const FunctionDef* Find(const string& func) const; // Adds function definition 'fdef' to this function library. - // Returns status 'ok' on success, or error otherwise. + // Returns status 'ok' on success, or error otherwise. This is a no-op if + // 'fdef' already exists in this function library. // If 'fdef' is successfully added to the library, it will be accessible // from 'LookUp' and included in the proto returned by 'ToProto'. Status AddFunctionDef(const FunctionDef& fdef); // Adds gradient definition 'grad' to this function library. + // This is a no-op if 'grad' already exists in this function library. // If 'grad' is successfully added, it will be accessible via 'FindGradient' // and included in the proto returned by 'ToProto'. Status AddGradientDef(const GradientDef& grad); // Adds the functions and gradients in 'other' to this function library. + // Duplicate functions and gradients are ignored. Status AddLibrary(const FunctionLibraryDefinition& other); // Adds the functions and gradients in 'lib_def' to this function library. + // Duplicate functions and gradients are ignored. Status AddLibrary(const FunctionDefLibrary& lib_def); // If the gradient function for 'func' is specified explicitly in diff --git a/tensorflow/core/framework/function_test.cc b/tensorflow/core/framework/function_test.cc index 2ecdc36c111..1173384a1ec 100644 --- a/tensorflow/core/framework/function_test.cc +++ b/tensorflow/core/framework/function_test.cc @@ -971,6 +971,10 @@ TEST(FunctionLibraryDefinitionTest, AddFunctionDef) { EXPECT_EQ(s.error_message(), "Cannot add function 'Add' because an op with the same name " "already exists."); + + // Already-added functions don't produce error + TF_EXPECT_OK(lib_def.AddFunctionDef(test::function::XTimesTwo())); + TF_EXPECT_OK(lib_def.AddFunctionDef(test::function::WXPlusB())); } TEST(FunctionLibraryDefinitionTest, AddGradientDef) { @@ -984,12 +988,16 @@ TEST(FunctionLibraryDefinitionTest, AddGradientDef) { grad.set_gradient_func(test::function::XTimesFour().signature().name()); TF_EXPECT_OK(lib_def.AddGradientDef(grad)); + // Already-added gradients don't produce error + TF_EXPECT_OK(lib_def.AddGradientDef(grad)); + // Test that adding a duplicate gradient fails grad.set_gradient_func(test::function::XTimes16().signature().name()); Status s = lib_def.AddGradientDef(grad); EXPECT_EQ(s.code(), error::Code::INVALID_ARGUMENT); EXPECT_EQ(s.error_message(), - "Gradient for function 'XTimesTwo' already exists."); + "Cannot assign gradient function 'XTimes16' to 'XTimesTwo' because " + "it already has gradient function 'XTimesFour'"); } TEST(FunctionLibraryDefinitionTest, AddLibrary) { @@ -998,35 +1006,46 @@ TEST(FunctionLibraryDefinitionTest, AddLibrary) { *proto.add_function() = test::function::XTimesTwo(); FunctionLibraryDefinition lib_def(OpRegistry::Global(), proto); - // Error if you try to add the same function twice - Status s = lib_def.AddLibrary(lib_def); - EXPECT_EQ(s.code(), error::Code::INVALID_ARGUMENT); - EXPECT_EQ(s.error_message(), - "Function with name: XTimesTwo already exists in function " - "library."); - // Add gradient GradientDef grad; grad.set_function_name(test::function::XTimesTwo().signature().name()); grad.set_gradient_func(test::function::XTimesFour().signature().name()); TF_EXPECT_OK(lib_def.AddGradientDef(grad)); - // Error if you try to add the same library function twice + // Error if you try to add conflicting function proto.Clear(); - *proto.add_gradient() = grad; + FunctionDef fdef = test::function::XTimesFour(); + fdef.mutable_signature()->set_name( + test::function::XTimesTwo().signature().name()); + *proto.add_function() = fdef; FunctionLibraryDefinition lib_def2(OpRegistry::Global(), proto); - s = lib_def.AddLibrary(lib_def2); + Status s = lib_def.AddLibrary(lib_def2); EXPECT_EQ(s.code(), error::Code::INVALID_ARGUMENT); EXPECT_EQ(s.error_message(), - "Gradient for function 'XTimesTwo' already exists."); + "Cannot add function 'XTimesTwo' because a different function with " + "the same name already exists."); + + // Error if you try to add conflicting gradient + proto.Clear(); + grad.set_gradient_func(test::function::XTimes16().signature().name()); + *proto.add_gradient() = grad; + FunctionLibraryDefinition lib_def3(OpRegistry::Global(), proto); + s = lib_def.AddLibrary(lib_def3); + EXPECT_EQ(s.code(), error::Code::INVALID_ARGUMENT); + EXPECT_EQ(s.error_message(), + "Cannot assign gradient function 'XTimes16' to 'XTimesTwo' because " + "it already has gradient function 'XTimesFour'"); // No conflicting functions or gradients OK proto.Clear(); *proto.add_function() = test::function::XTimesFour(); grad.set_function_name(test::function::XTimes16().signature().name()); *proto.add_gradient() = grad; - FunctionLibraryDefinition lib_def3(OpRegistry::Global(), proto); - TF_EXPECT_OK(lib_def.AddLibrary(lib_def3)); + FunctionLibraryDefinition lib_def4(OpRegistry::Global(), proto); + TF_EXPECT_OK(lib_def.AddLibrary(lib_def4)); + + // OK to add the same functions and gradients twice + TF_EXPECT_OK(lib_def.AddLibrary(lib_def)); } TEST(FunctionLibraryDefinitionTest, ToProto) { diff --git a/tensorflow/core/graph/graph.cc b/tensorflow/core/graph/graph.cc index 7a5e76fdd01..9691326c998 100644 --- a/tensorflow/core/graph/graph.cc +++ b/tensorflow/core/graph/graph.cc @@ -413,35 +413,7 @@ void Graph::RemoveEdge(const Edge* e) { } Status Graph::AddFunctionLibrary(const FunctionDefLibrary& fdef_lib) { - for (const FunctionDef& fdef : fdef_lib.function()) { - const FunctionDef* preexisting_fdef = ops_.Find(fdef.signature().name()); - if (preexisting_fdef != nullptr) { - if (!FunctionDefsEqual(*preexisting_fdef, fdef)) { - return errors::InvalidArgument( - "Cannot add function '", fdef.signature().name(), - "' because a different function with the same name already " - "exists."); - } - // Ignore duplicate FunctionDefs - continue; - } - TF_RETURN_IF_ERROR(ops_.AddFunctionDef(fdef)); - } - for (const GradientDef& grad : fdef_lib.gradient()) { - string preexisting_grad_func = ops_.FindGradient(grad.function_name()); - if (!preexisting_grad_func.empty()) { - if (preexisting_grad_func != grad.gradient_func()) { - return errors::InvalidArgument( - "Cannot assign gradient function '", grad.gradient_func(), "' to '", - grad.function_name(), "' because it already has gradient function ", - "'", preexisting_grad_func, "'"); - } - // Ignore duplicate GradientDefs - continue; - } - TF_RETURN_IF_ERROR(ops_.AddGradientDef(grad)); - } - return Status::OK(); + return ops_.AddLibrary(fdef_lib); } namespace {