Move duplicate detection logic from Graph to FunctionLibraryDefinition
Turns out this is more useful, since there are many function libraries that don't belong to a graph. This will be used in a future change. Note that this maintains the current behavior of Graph. In addition, updates FunctionDefsEqual() to handle unset attr entries (I ran into this when using this in said future change). PiperOrigin-RevId: 161126628
This commit is contained in:
parent
2caec3af18
commit
7d5c74a9c8
@ -724,6 +724,23 @@ string DebugStringWhole(const GraphDef& gdef) {
|
||||
return ret;
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
// Returns the name -> attr mapping of fdef's attrs that have a value set. In
|
||||
// Python, it's possible to access unset attrs, which returns a default value
|
||||
// and adds an unset attr to the map.
|
||||
std::map<StringPiece, AttrValue> GetSetAttrs(const FunctionDef& fdef) {
|
||||
std::map<StringPiece, AttrValue> set_attrs;
|
||||
for (auto iter : fdef.attr()) {
|
||||
if (iter.second.value_case() != AttrValue::VALUE_NOT_SET) {
|
||||
set_attrs[iter.first] = iter.second;
|
||||
}
|
||||
}
|
||||
return set_attrs;
|
||||
}
|
||||
|
||||
} // end namespace
|
||||
|
||||
bool FunctionDefsEqual(const FunctionDef& f1, const FunctionDef& f2) {
|
||||
// NOTE(skyewm): Using MessageDifferencer would be better here, but that is
|
||||
// currently not included in tensorflow/core/platform/default/protobuf.h, so
|
||||
@ -736,10 +753,12 @@ bool FunctionDefsEqual(const FunctionDef& f1, const FunctionDef& f2) {
|
||||
f2.signature().SerializeToString(&sig2);
|
||||
if (sig1 != sig2) return false;
|
||||
|
||||
if (f1.attr().size() != f2.attr().size()) return false;
|
||||
for (auto iter1 : f1.attr()) {
|
||||
auto iter2 = f2.attr().find(iter1.first);
|
||||
if (iter2 == f2.attr().end()) return false;
|
||||
std::map<StringPiece, AttrValue> f1_attrs = GetSetAttrs(f1);
|
||||
std::map<StringPiece, AttrValue> f2_attrs = GetSetAttrs(f2);
|
||||
if (f1_attrs.size() != f2_attrs.size()) return false;
|
||||
for (auto iter1 : f1_attrs) {
|
||||
auto iter2 = f2_attrs.find(iter1.first);
|
||||
if (iter2 == f2_attrs.end()) return false;
|
||||
if (!AreAttrValuesEqual(iter1.second, iter2->second)) return false;
|
||||
}
|
||||
|
||||
@ -883,11 +902,17 @@ const FunctionDef* FunctionLibraryDefinition::Find(const string& name) const {
|
||||
}
|
||||
|
||||
Status FunctionLibraryDefinition::AddFunctionDef(const FunctionDef& fdef) {
|
||||
auto& ptr = function_defs_[fdef.signature().name()];
|
||||
if (ptr != nullptr) {
|
||||
return errors::InvalidArgument("Function with name: ",
|
||||
fdef.signature().name(),
|
||||
" already exists in function library.");
|
||||
std::unique_ptr<FunctionDefAndOpRegistration>* entry =
|
||||
&function_defs_[fdef.signature().name()];
|
||||
if (*entry != nullptr) {
|
||||
if (!FunctionDefsEqual((*entry)->fdef, fdef)) {
|
||||
return errors::InvalidArgument(
|
||||
"Cannot add function '", fdef.signature().name(),
|
||||
"' because a different function with the same name already "
|
||||
"exists.");
|
||||
}
|
||||
// Ignore duplicate FunctionDefs
|
||||
return Status::OK();
|
||||
}
|
||||
const OpDef* op_def;
|
||||
if (default_registry_->LookUpOpDef(fdef.signature().name(), &op_def).ok()) {
|
||||
@ -895,19 +920,27 @@ Status FunctionLibraryDefinition::AddFunctionDef(const FunctionDef& fdef) {
|
||||
"Cannot add function '", fdef.signature().name(),
|
||||
"' because an op with the same name already exists.");
|
||||
}
|
||||
ptr.reset(new FunctionDefAndOpRegistration(fdef));
|
||||
entry->reset(new FunctionDefAndOpRegistration(fdef));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status FunctionLibraryDefinition::AddGradientDef(const GradientDef& grad) {
|
||||
if (func_grad_.count(grad.function_name()) > 0) {
|
||||
return errors::InvalidArgument("Gradient for function '",
|
||||
grad.function_name(), "' already exists.");
|
||||
string* entry = &func_grad_[grad.function_name()];
|
||||
if (!entry->empty()) {
|
||||
if (*entry != grad.gradient_func()) {
|
||||
return errors::InvalidArgument(
|
||||
"Cannot assign gradient function '", grad.gradient_func(), "' to '",
|
||||
grad.function_name(), "' because it already has gradient function ",
|
||||
"'", *entry, "'");
|
||||
}
|
||||
// Ignore duplicate GradientDefs
|
||||
return Status::OK();
|
||||
}
|
||||
func_grad_[grad.function_name()] = grad.gradient_func();
|
||||
*entry = grad.gradient_func();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// TODO(skyewm): don't modify FunctionLibraryDefinition in case of error
|
||||
Status FunctionLibraryDefinition::AddLibrary(
|
||||
const FunctionLibraryDefinition& other) {
|
||||
for (auto iter : other.function_defs_) {
|
||||
|
@ -287,20 +287,24 @@ class FunctionLibraryDefinition : public OpRegistryInterface {
|
||||
const FunctionDef* Find(const string& func) const;
|
||||
|
||||
// Adds function definition 'fdef' to this function library.
|
||||
// Returns status 'ok' on success, or error otherwise.
|
||||
// Returns status 'ok' on success, or error otherwise. This is a no-op if
|
||||
// 'fdef' already exists in this function library.
|
||||
// If 'fdef' is successfully added to the library, it will be accessible
|
||||
// from 'LookUp' and included in the proto returned by 'ToProto'.
|
||||
Status AddFunctionDef(const FunctionDef& fdef);
|
||||
|
||||
// Adds gradient definition 'grad' to this function library.
|
||||
// This is a no-op if 'grad' already exists in this function library.
|
||||
// If 'grad' is successfully added, it will be accessible via 'FindGradient'
|
||||
// and included in the proto returned by 'ToProto'.
|
||||
Status AddGradientDef(const GradientDef& grad);
|
||||
|
||||
// Adds the functions and gradients in 'other' to this function library.
|
||||
// Duplicate functions and gradients are ignored.
|
||||
Status AddLibrary(const FunctionLibraryDefinition& other);
|
||||
|
||||
// Adds the functions and gradients in 'lib_def' to this function library.
|
||||
// Duplicate functions and gradients are ignored.
|
||||
Status AddLibrary(const FunctionDefLibrary& lib_def);
|
||||
|
||||
// If the gradient function for 'func' is specified explicitly in
|
||||
|
@ -971,6 +971,10 @@ TEST(FunctionLibraryDefinitionTest, AddFunctionDef) {
|
||||
EXPECT_EQ(s.error_message(),
|
||||
"Cannot add function 'Add' because an op with the same name "
|
||||
"already exists.");
|
||||
|
||||
// Already-added functions don't produce error
|
||||
TF_EXPECT_OK(lib_def.AddFunctionDef(test::function::XTimesTwo()));
|
||||
TF_EXPECT_OK(lib_def.AddFunctionDef(test::function::WXPlusB()));
|
||||
}
|
||||
|
||||
TEST(FunctionLibraryDefinitionTest, AddGradientDef) {
|
||||
@ -984,12 +988,16 @@ TEST(FunctionLibraryDefinitionTest, AddGradientDef) {
|
||||
grad.set_gradient_func(test::function::XTimesFour().signature().name());
|
||||
TF_EXPECT_OK(lib_def.AddGradientDef(grad));
|
||||
|
||||
// Already-added gradients don't produce error
|
||||
TF_EXPECT_OK(lib_def.AddGradientDef(grad));
|
||||
|
||||
// Test that adding a duplicate gradient fails
|
||||
grad.set_gradient_func(test::function::XTimes16().signature().name());
|
||||
Status s = lib_def.AddGradientDef(grad);
|
||||
EXPECT_EQ(s.code(), error::Code::INVALID_ARGUMENT);
|
||||
EXPECT_EQ(s.error_message(),
|
||||
"Gradient for function 'XTimesTwo' already exists.");
|
||||
"Cannot assign gradient function 'XTimes16' to 'XTimesTwo' because "
|
||||
"it already has gradient function 'XTimesFour'");
|
||||
}
|
||||
|
||||
TEST(FunctionLibraryDefinitionTest, AddLibrary) {
|
||||
@ -998,35 +1006,46 @@ TEST(FunctionLibraryDefinitionTest, AddLibrary) {
|
||||
*proto.add_function() = test::function::XTimesTwo();
|
||||
FunctionLibraryDefinition lib_def(OpRegistry::Global(), proto);
|
||||
|
||||
// Error if you try to add the same function twice
|
||||
Status s = lib_def.AddLibrary(lib_def);
|
||||
EXPECT_EQ(s.code(), error::Code::INVALID_ARGUMENT);
|
||||
EXPECT_EQ(s.error_message(),
|
||||
"Function with name: XTimesTwo already exists in function "
|
||||
"library.");
|
||||
|
||||
// Add gradient
|
||||
GradientDef grad;
|
||||
grad.set_function_name(test::function::XTimesTwo().signature().name());
|
||||
grad.set_gradient_func(test::function::XTimesFour().signature().name());
|
||||
TF_EXPECT_OK(lib_def.AddGradientDef(grad));
|
||||
|
||||
// Error if you try to add the same library function twice
|
||||
// Error if you try to add conflicting function
|
||||
proto.Clear();
|
||||
*proto.add_gradient() = grad;
|
||||
FunctionDef fdef = test::function::XTimesFour();
|
||||
fdef.mutable_signature()->set_name(
|
||||
test::function::XTimesTwo().signature().name());
|
||||
*proto.add_function() = fdef;
|
||||
FunctionLibraryDefinition lib_def2(OpRegistry::Global(), proto);
|
||||
s = lib_def.AddLibrary(lib_def2);
|
||||
Status s = lib_def.AddLibrary(lib_def2);
|
||||
EXPECT_EQ(s.code(), error::Code::INVALID_ARGUMENT);
|
||||
EXPECT_EQ(s.error_message(),
|
||||
"Gradient for function 'XTimesTwo' already exists.");
|
||||
"Cannot add function 'XTimesTwo' because a different function with "
|
||||
"the same name already exists.");
|
||||
|
||||
// Error if you try to add conflicting gradient
|
||||
proto.Clear();
|
||||
grad.set_gradient_func(test::function::XTimes16().signature().name());
|
||||
*proto.add_gradient() = grad;
|
||||
FunctionLibraryDefinition lib_def3(OpRegistry::Global(), proto);
|
||||
s = lib_def.AddLibrary(lib_def3);
|
||||
EXPECT_EQ(s.code(), error::Code::INVALID_ARGUMENT);
|
||||
EXPECT_EQ(s.error_message(),
|
||||
"Cannot assign gradient function 'XTimes16' to 'XTimesTwo' because "
|
||||
"it already has gradient function 'XTimesFour'");
|
||||
|
||||
// No conflicting functions or gradients OK
|
||||
proto.Clear();
|
||||
*proto.add_function() = test::function::XTimesFour();
|
||||
grad.set_function_name(test::function::XTimes16().signature().name());
|
||||
*proto.add_gradient() = grad;
|
||||
FunctionLibraryDefinition lib_def3(OpRegistry::Global(), proto);
|
||||
TF_EXPECT_OK(lib_def.AddLibrary(lib_def3));
|
||||
FunctionLibraryDefinition lib_def4(OpRegistry::Global(), proto);
|
||||
TF_EXPECT_OK(lib_def.AddLibrary(lib_def4));
|
||||
|
||||
// OK to add the same functions and gradients twice
|
||||
TF_EXPECT_OK(lib_def.AddLibrary(lib_def));
|
||||
}
|
||||
|
||||
TEST(FunctionLibraryDefinitionTest, ToProto) {
|
||||
|
@ -413,35 +413,7 @@ void Graph::RemoveEdge(const Edge* e) {
|
||||
}
|
||||
|
||||
Status Graph::AddFunctionLibrary(const FunctionDefLibrary& fdef_lib) {
|
||||
for (const FunctionDef& fdef : fdef_lib.function()) {
|
||||
const FunctionDef* preexisting_fdef = ops_.Find(fdef.signature().name());
|
||||
if (preexisting_fdef != nullptr) {
|
||||
if (!FunctionDefsEqual(*preexisting_fdef, fdef)) {
|
||||
return errors::InvalidArgument(
|
||||
"Cannot add function '", fdef.signature().name(),
|
||||
"' because a different function with the same name already "
|
||||
"exists.");
|
||||
}
|
||||
// Ignore duplicate FunctionDefs
|
||||
continue;
|
||||
}
|
||||
TF_RETURN_IF_ERROR(ops_.AddFunctionDef(fdef));
|
||||
}
|
||||
for (const GradientDef& grad : fdef_lib.gradient()) {
|
||||
string preexisting_grad_func = ops_.FindGradient(grad.function_name());
|
||||
if (!preexisting_grad_func.empty()) {
|
||||
if (preexisting_grad_func != grad.gradient_func()) {
|
||||
return errors::InvalidArgument(
|
||||
"Cannot assign gradient function '", grad.gradient_func(), "' to '",
|
||||
grad.function_name(), "' because it already has gradient function ",
|
||||
"'", preexisting_grad_func, "'");
|
||||
}
|
||||
// Ignore duplicate GradientDefs
|
||||
continue;
|
||||
}
|
||||
TF_RETURN_IF_ERROR(ops_.AddGradientDef(grad));
|
||||
}
|
||||
return Status::OK();
|
||||
return ops_.AddLibrary(fdef_lib);
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
Loading…
Reference in New Issue
Block a user