diff --git a/tensorflow/core/framework/op_compatibility_test.cc b/tensorflow/core/framework/op_compatibility_test.cc index 4edb60786d7..16fa9dbe820 100644 --- a/tensorflow/core/framework/op_compatibility_test.cc +++ b/tensorflow/core/framework/op_compatibility_test.cc @@ -1051,7 +1051,7 @@ TEST_F(OpCompatibilityTest, RenameOutputListFails) { "Output signature mismatch 'old:T' vs. 'new:T'"); } -// Should not be able to add a default to an attr. +// It's ok to add a default to an attr if it doesn't already have one. REGISTER_OP("AddDefault").Output("ndef: string").Attr("a: int = 1234"); REGISTER_KERNEL_BUILDER(Name("AddDefault").Device(DEVICE_CPU), TestKernel); @@ -1064,9 +1064,8 @@ TEST_F(OpCompatibilityTest, AddDefault) { TF_ASSERT_OK(NodeDefBuilder("add_default", &old_op.op_def) .Attr("a", 765) .Finalize(node_def())); - ExpectDefaultChangeFailure( - old_op.op_def, - "Attr 'a' has added/removed it's default; from no default to 1234"); + ExpectSuccess(old_op.op_def); + EXPECT_EQ("{{node add_default}} = AddDefault[a=765]()", Result()); } // Should not be able to remove a default from an attr. @@ -1083,7 +1082,7 @@ TEST_F(OpCompatibilityTest, RemoveDefault) { NodeDefBuilder("remove_default", &old_op.op_def).Finalize(node_def())); ExpectDefaultChangeFailure( old_op.op_def, - "Attr 'a' has added/removed it's default; from 91 to no default"); + "Attr 'a' has removed it's default; from 91 to no default"); } // Should not be able to change a default for an attr. diff --git a/tensorflow/core/framework/op_def_util.cc b/tensorflow/core/framework/op_def_util.cc index 94401d4a6c5..0ebc4bf2483 100644 --- a/tensorflow/core/framework/op_def_util.cc +++ b/tensorflow/core/framework/op_def_util.cc @@ -734,10 +734,13 @@ Status OpDefAttrDefaultsUnchanged(const OpDef& old_op, const OpDef& new_op) { const OpDef::AttrDef* new_attr = gtl::FindPtrOrNull(new_attrs, old_attr.name()); if (new_attr == nullptr) continue; - if (old_attr.has_default_value() != new_attr->has_default_value()) { + if (new_attr->has_default_value() && !old_attr.has_default_value()) { + continue; // Adding new default values is safe. + } + if (old_attr.has_default_value() && !new_attr->has_default_value()) { return errors::InvalidArgument( - "Attr '", old_attr.name(), "' has added/removed it's default; ", - "from ", DefaultAttrStr(old_attr), " to ", DefaultAttrStr(*new_attr)); + "Attr '", old_attr.name(), "' has removed it's default; ", "from ", + DefaultAttrStr(old_attr), " to ", DefaultAttrStr(*new_attr)); } if (old_attr.has_default_value() && !AreAttrValuesEqual(old_attr.default_value(), diff --git a/tensorflow/core/framework/op_def_util.h b/tensorflow/core/framework/op_def_util.h index 85afe2bdea0..311e40afeea 100644 --- a/tensorflow/core/framework/op_def_util.h +++ b/tensorflow/core/framework/op_def_util.h @@ -68,8 +68,9 @@ Status OpDefAddedDefaultsUnchanged(const OpDef& old_op, const OpDef& penultimate_op, const OpDef& new_op); -// Returns an error if the default value for any attr is added/removed/modified -// in new_op compared to old_op. +// Returns an error if the default value for any attr is removed or modified +// in new_op compared to old_op. Adding new default values is safe, and does +// not raise an error. Status OpDefAttrDefaultsUnchanged(const OpDef& old_op, const OpDef& new_op); // Remove all docs from *op_def / *op_list. diff --git a/tensorflow/core/framework/op_def_util_test.cc b/tensorflow/core/framework/op_def_util_test.cc index e6afbc81c44..90a3f6720e1 100644 --- a/tensorflow/core/framework/op_def_util_test.cc +++ b/tensorflow/core/framework/op_def_util_test.cc @@ -53,17 +53,19 @@ class ValidateOpDefTest : public ::testing::Test { return ValidateOpDef(op_reg_data.op_def); } } - - void ExpectFailure(const Status& status, const string& message) { - EXPECT_FALSE(status.ok()) << "Did not see error with: " << message; - if (!status.ok()) { - LOG(INFO) << "message: " << status; - EXPECT_TRUE(absl::StrContains(status.ToString(), message)) - << "Actual: " << status << "\nExpected to contain: " << message; - } - } }; +namespace { +void ExpectFailure(const Status& status, const string& message) { + EXPECT_FALSE(status.ok()) << "Did not see error with: " << message; + if (!status.ok()) { + LOG(INFO) << "message: " << status; + EXPECT_TRUE(absl::StrContains(status.ToString(), message)) + << "Actual: " << status << "\nExpected to contain: " << message; + } +} +} // namespace + TEST_F(ValidateOpDefTest, OpDefValid) { TF_EXPECT_OK(TestBuilder(OpDefBuilder("X").Attr("a: int"))); TF_EXPECT_OK(TestBuilder(OpDefBuilder("X").Input("a: int32"))); @@ -527,5 +529,26 @@ TEST(OpDefEqualityTest, EqualAndHash) { ExpectDifferent(o1, o3); } +TEST(OpDefAttrDefaultsUnchangedTest, Foo) { + const auto& op1 = FromText("name: 'op1' attr { name: 'n' type: 'string'}"); + const auto& op2 = FromText( + "name: 'op2' attr { name: 'n' type: 'string' default_value: {s: 'x'}}"); + const auto& op3 = FromText( + "name: 'op3' attr { name: 'n' type: 'string' default_value: {s: 'y'}}"); + + // Adding a default value: fine. + TF_EXPECT_OK(OpDefAttrDefaultsUnchanged(op1, op2)); + + // Changing a default value: not ok. + Status changed_attr = OpDefAttrDefaultsUnchanged(op2, op3); + ExpectFailure(changed_attr, + "Attr 'n' has changed it's default value; from \"x\" to \"y\""); + + // Removing a default value: not ok. + Status removed_attr = OpDefAttrDefaultsUnchanged(op2, op1); + ExpectFailure(removed_attr, + "Attr 'n' has removed it's default; from \"x\" to no default"); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/core/ops/compat/op_compatibility_lib.cc b/tensorflow/core/ops/compat/op_compatibility_lib.cc index 9005e743c03..ccd4ba2dc12 100644 --- a/tensorflow/core/ops/compat/op_compatibility_lib.cc +++ b/tensorflow/core/ops/compat/op_compatibility_lib.cc @@ -198,7 +198,7 @@ Status OpCompatibilityLib::ValidateCompatible(Env* env, int* changed_ops, TF_RETURN_IF_ERROR(OpDefCompatible(history_op_list.op(i), cur_op)); } - // Verify default value of attrs has not been added/removed/modified + // Verify default value of attrs has not been removed or modified // as compared to only the last historical version. TF_RETURN_IF_ERROR( OpDefAttrDefaultsUnchanged(history_op_list.op(end - 1), cur_op));