Move duplicate detection logic from Graph to FunctionLibraryDefinition

Turns out this is more useful, since there are many function libraries
that don't belong to a graph. This will be used in a future
change. Note that this maintains the current behavior of Graph.

In addition, updates FunctionDefsEqual() to handle unset attr entries
(I ran into this when using this in said future change).

PiperOrigin-RevId: 161126628
This commit is contained in:
Skye Wanderman-Milne 2017-07-06 13:59:28 -07:00 committed by TensorFlower Gardener
parent 2caec3af18
commit 7d5c74a9c8
4 changed files with 86 additions and 58 deletions

View File

@ -724,6 +724,23 @@ string DebugStringWhole(const GraphDef& gdef) {
return ret;
}
namespace {
// Returns the name -> attr mapping of fdef's attrs that have a value set. In
// Python, it's possible to access unset attrs, which returns a default value
// and adds an unset attr to the map.
std::map<StringPiece, AttrValue> GetSetAttrs(const FunctionDef& fdef) {
std::map<StringPiece, AttrValue> set_attrs;
for (auto iter : fdef.attr()) {
if (iter.second.value_case() != AttrValue::VALUE_NOT_SET) {
set_attrs[iter.first] = iter.second;
}
}
return set_attrs;
}
} // end namespace
bool FunctionDefsEqual(const FunctionDef& f1, const FunctionDef& f2) {
// NOTE(skyewm): Using MessageDifferencer would be better here, but that is
// currently not included in tensorflow/core/platform/default/protobuf.h, so
@ -736,10 +753,12 @@ bool FunctionDefsEqual(const FunctionDef& f1, const FunctionDef& f2) {
f2.signature().SerializeToString(&sig2);
if (sig1 != sig2) return false;
if (f1.attr().size() != f2.attr().size()) return false;
for (auto iter1 : f1.attr()) {
auto iter2 = f2.attr().find(iter1.first);
if (iter2 == f2.attr().end()) return false;
std::map<StringPiece, AttrValue> f1_attrs = GetSetAttrs(f1);
std::map<StringPiece, AttrValue> f2_attrs = GetSetAttrs(f2);
if (f1_attrs.size() != f2_attrs.size()) return false;
for (auto iter1 : f1_attrs) {
auto iter2 = f2_attrs.find(iter1.first);
if (iter2 == f2_attrs.end()) return false;
if (!AreAttrValuesEqual(iter1.second, iter2->second)) return false;
}
@ -883,11 +902,17 @@ const FunctionDef* FunctionLibraryDefinition::Find(const string& name) const {
}
Status FunctionLibraryDefinition::AddFunctionDef(const FunctionDef& fdef) {
auto& ptr = function_defs_[fdef.signature().name()];
if (ptr != nullptr) {
return errors::InvalidArgument("Function with name: ",
fdef.signature().name(),
" already exists in function library.");
std::unique_ptr<FunctionDefAndOpRegistration>* entry =
&function_defs_[fdef.signature().name()];
if (*entry != nullptr) {
if (!FunctionDefsEqual((*entry)->fdef, fdef)) {
return errors::InvalidArgument(
"Cannot add function '", fdef.signature().name(),
"' because a different function with the same name already "
"exists.");
}
// Ignore duplicate FunctionDefs
return Status::OK();
}
const OpDef* op_def;
if (default_registry_->LookUpOpDef(fdef.signature().name(), &op_def).ok()) {
@ -895,19 +920,27 @@ Status FunctionLibraryDefinition::AddFunctionDef(const FunctionDef& fdef) {
"Cannot add function '", fdef.signature().name(),
"' because an op with the same name already exists.");
}
ptr.reset(new FunctionDefAndOpRegistration(fdef));
entry->reset(new FunctionDefAndOpRegistration(fdef));
return Status::OK();
}
Status FunctionLibraryDefinition::AddGradientDef(const GradientDef& grad) {
if (func_grad_.count(grad.function_name()) > 0) {
return errors::InvalidArgument("Gradient for function '",
grad.function_name(), "' already exists.");
string* entry = &func_grad_[grad.function_name()];
if (!entry->empty()) {
if (*entry != grad.gradient_func()) {
return errors::InvalidArgument(
"Cannot assign gradient function '", grad.gradient_func(), "' to '",
grad.function_name(), "' because it already has gradient function ",
"'", *entry, "'");
}
// Ignore duplicate GradientDefs
return Status::OK();
}
func_grad_[grad.function_name()] = grad.gradient_func();
*entry = grad.gradient_func();
return Status::OK();
}
// TODO(skyewm): don't modify FunctionLibraryDefinition in case of error
Status FunctionLibraryDefinition::AddLibrary(
const FunctionLibraryDefinition& other) {
for (auto iter : other.function_defs_) {

View File

@ -287,20 +287,24 @@ class FunctionLibraryDefinition : public OpRegistryInterface {
const FunctionDef* Find(const string& func) const;
// Adds function definition 'fdef' to this function library.
// Returns status 'ok' on success, or error otherwise.
// Returns status 'ok' on success, or error otherwise. This is a no-op if
// 'fdef' already exists in this function library.
// If 'fdef' is successfully added to the library, it will be accessible
// from 'LookUp' and included in the proto returned by 'ToProto'.
Status AddFunctionDef(const FunctionDef& fdef);
// Adds gradient definition 'grad' to this function library.
// This is a no-op if 'grad' already exists in this function library.
// If 'grad' is successfully added, it will be accessible via 'FindGradient'
// and included in the proto returned by 'ToProto'.
Status AddGradientDef(const GradientDef& grad);
// Adds the functions and gradients in 'other' to this function library.
// Duplicate functions and gradients are ignored.
Status AddLibrary(const FunctionLibraryDefinition& other);
// Adds the functions and gradients in 'lib_def' to this function library.
// Duplicate functions and gradients are ignored.
Status AddLibrary(const FunctionDefLibrary& lib_def);
// If the gradient function for 'func' is specified explicitly in

View File

@ -971,6 +971,10 @@ TEST(FunctionLibraryDefinitionTest, AddFunctionDef) {
EXPECT_EQ(s.error_message(),
"Cannot add function 'Add' because an op with the same name "
"already exists.");
// Already-added functions don't produce error
TF_EXPECT_OK(lib_def.AddFunctionDef(test::function::XTimesTwo()));
TF_EXPECT_OK(lib_def.AddFunctionDef(test::function::WXPlusB()));
}
TEST(FunctionLibraryDefinitionTest, AddGradientDef) {
@ -984,12 +988,16 @@ TEST(FunctionLibraryDefinitionTest, AddGradientDef) {
grad.set_gradient_func(test::function::XTimesFour().signature().name());
TF_EXPECT_OK(lib_def.AddGradientDef(grad));
// Already-added gradients don't produce error
TF_EXPECT_OK(lib_def.AddGradientDef(grad));
// Test that adding a duplicate gradient fails
grad.set_gradient_func(test::function::XTimes16().signature().name());
Status s = lib_def.AddGradientDef(grad);
EXPECT_EQ(s.code(), error::Code::INVALID_ARGUMENT);
EXPECT_EQ(s.error_message(),
"Gradient for function 'XTimesTwo' already exists.");
"Cannot assign gradient function 'XTimes16' to 'XTimesTwo' because "
"it already has gradient function 'XTimesFour'");
}
TEST(FunctionLibraryDefinitionTest, AddLibrary) {
@ -998,35 +1006,46 @@ TEST(FunctionLibraryDefinitionTest, AddLibrary) {
*proto.add_function() = test::function::XTimesTwo();
FunctionLibraryDefinition lib_def(OpRegistry::Global(), proto);
// Error if you try to add the same function twice
Status s = lib_def.AddLibrary(lib_def);
EXPECT_EQ(s.code(), error::Code::INVALID_ARGUMENT);
EXPECT_EQ(s.error_message(),
"Function with name: XTimesTwo already exists in function "
"library.");
// Add gradient
GradientDef grad;
grad.set_function_name(test::function::XTimesTwo().signature().name());
grad.set_gradient_func(test::function::XTimesFour().signature().name());
TF_EXPECT_OK(lib_def.AddGradientDef(grad));
// Error if you try to add the same library function twice
// Error if you try to add conflicting function
proto.Clear();
*proto.add_gradient() = grad;
FunctionDef fdef = test::function::XTimesFour();
fdef.mutable_signature()->set_name(
test::function::XTimesTwo().signature().name());
*proto.add_function() = fdef;
FunctionLibraryDefinition lib_def2(OpRegistry::Global(), proto);
s = lib_def.AddLibrary(lib_def2);
Status s = lib_def.AddLibrary(lib_def2);
EXPECT_EQ(s.code(), error::Code::INVALID_ARGUMENT);
EXPECT_EQ(s.error_message(),
"Gradient for function 'XTimesTwo' already exists.");
"Cannot add function 'XTimesTwo' because a different function with "
"the same name already exists.");
// Error if you try to add conflicting gradient
proto.Clear();
grad.set_gradient_func(test::function::XTimes16().signature().name());
*proto.add_gradient() = grad;
FunctionLibraryDefinition lib_def3(OpRegistry::Global(), proto);
s = lib_def.AddLibrary(lib_def3);
EXPECT_EQ(s.code(), error::Code::INVALID_ARGUMENT);
EXPECT_EQ(s.error_message(),
"Cannot assign gradient function 'XTimes16' to 'XTimesTwo' because "
"it already has gradient function 'XTimesFour'");
// No conflicting functions or gradients OK
proto.Clear();
*proto.add_function() = test::function::XTimesFour();
grad.set_function_name(test::function::XTimes16().signature().name());
*proto.add_gradient() = grad;
FunctionLibraryDefinition lib_def3(OpRegistry::Global(), proto);
TF_EXPECT_OK(lib_def.AddLibrary(lib_def3));
FunctionLibraryDefinition lib_def4(OpRegistry::Global(), proto);
TF_EXPECT_OK(lib_def.AddLibrary(lib_def4));
// OK to add the same functions and gradients twice
TF_EXPECT_OK(lib_def.AddLibrary(lib_def));
}
TEST(FunctionLibraryDefinitionTest, ToProto) {

View File

@ -413,35 +413,7 @@ void Graph::RemoveEdge(const Edge* e) {
}
Status Graph::AddFunctionLibrary(const FunctionDefLibrary& fdef_lib) {
for (const FunctionDef& fdef : fdef_lib.function()) {
const FunctionDef* preexisting_fdef = ops_.Find(fdef.signature().name());
if (preexisting_fdef != nullptr) {
if (!FunctionDefsEqual(*preexisting_fdef, fdef)) {
return errors::InvalidArgument(
"Cannot add function '", fdef.signature().name(),
"' because a different function with the same name already "
"exists.");
}
// Ignore duplicate FunctionDefs
continue;
}
TF_RETURN_IF_ERROR(ops_.AddFunctionDef(fdef));
}
for (const GradientDef& grad : fdef_lib.gradient()) {
string preexisting_grad_func = ops_.FindGradient(grad.function_name());
if (!preexisting_grad_func.empty()) {
if (preexisting_grad_func != grad.gradient_func()) {
return errors::InvalidArgument(
"Cannot assign gradient function '", grad.gradient_func(), "' to '",
grad.function_name(), "' because it already has gradient function ",
"'", preexisting_grad_func, "'");
}
// Ignore duplicate GradientDefs
continue;
}
TF_RETURN_IF_ERROR(ops_.AddGradientDef(grad));
}
return Status::OK();
return ops_.AddLibrary(fdef_lib);
}
namespace {