Update backwards-compatibility checks to allow adding a new default to an op attribute.

PiperOrigin-RevId: 272084965
This commit is contained in:
Edward Loper 2019-09-30 15:56:04 -07:00 committed by TensorFlower Gardener
parent 86902a8ada
commit 714f1dbd7a
5 changed files with 46 additions and 20 deletions

View File

@ -1051,7 +1051,7 @@ TEST_F(OpCompatibilityTest, RenameOutputListFails) {
"Output signature mismatch 'old:T' vs. 'new:T'"); "Output signature mismatch 'old:T' vs. 'new:T'");
} }
// Should not be able to add a default to an attr. // It's ok to add a default to an attr if it doesn't already have one.
REGISTER_OP("AddDefault").Output("ndef: string").Attr("a: int = 1234"); REGISTER_OP("AddDefault").Output("ndef: string").Attr("a: int = 1234");
REGISTER_KERNEL_BUILDER(Name("AddDefault").Device(DEVICE_CPU), TestKernel); REGISTER_KERNEL_BUILDER(Name("AddDefault").Device(DEVICE_CPU), TestKernel);
@ -1064,9 +1064,8 @@ TEST_F(OpCompatibilityTest, AddDefault) {
TF_ASSERT_OK(NodeDefBuilder("add_default", &old_op.op_def) TF_ASSERT_OK(NodeDefBuilder("add_default", &old_op.op_def)
.Attr("a", 765) .Attr("a", 765)
.Finalize(node_def())); .Finalize(node_def()));
ExpectDefaultChangeFailure( ExpectSuccess(old_op.op_def);
old_op.op_def, EXPECT_EQ("{{node add_default}} = AddDefault[a=765]()", Result());
"Attr 'a' has added/removed it's default; from no default to 1234");
} }
// Should not be able to remove a default from an attr. // Should not be able to remove a default from an attr.
@ -1083,7 +1082,7 @@ TEST_F(OpCompatibilityTest, RemoveDefault) {
NodeDefBuilder("remove_default", &old_op.op_def).Finalize(node_def())); NodeDefBuilder("remove_default", &old_op.op_def).Finalize(node_def()));
ExpectDefaultChangeFailure( ExpectDefaultChangeFailure(
old_op.op_def, old_op.op_def,
"Attr 'a' has added/removed it's default; from 91 to no default"); "Attr 'a' has removed it's default; from 91 to no default");
} }
// Should not be able to change a default for an attr. // Should not be able to change a default for an attr.

View File

@ -734,10 +734,13 @@ Status OpDefAttrDefaultsUnchanged(const OpDef& old_op, const OpDef& new_op) {
const OpDef::AttrDef* new_attr = const OpDef::AttrDef* new_attr =
gtl::FindPtrOrNull(new_attrs, old_attr.name()); gtl::FindPtrOrNull(new_attrs, old_attr.name());
if (new_attr == nullptr) continue; if (new_attr == nullptr) continue;
if (old_attr.has_default_value() != new_attr->has_default_value()) { if (new_attr->has_default_value() && !old_attr.has_default_value()) {
continue; // Adding new default values is safe.
}
if (old_attr.has_default_value() && !new_attr->has_default_value()) {
return errors::InvalidArgument( return errors::InvalidArgument(
"Attr '", old_attr.name(), "' has added/removed it's default; ", "Attr '", old_attr.name(), "' has removed it's default; ", "from ",
"from ", DefaultAttrStr(old_attr), " to ", DefaultAttrStr(*new_attr)); DefaultAttrStr(old_attr), " to ", DefaultAttrStr(*new_attr));
} }
if (old_attr.has_default_value() && if (old_attr.has_default_value() &&
!AreAttrValuesEqual(old_attr.default_value(), !AreAttrValuesEqual(old_attr.default_value(),

View File

@ -68,8 +68,9 @@ Status OpDefAddedDefaultsUnchanged(const OpDef& old_op,
const OpDef& penultimate_op, const OpDef& penultimate_op,
const OpDef& new_op); const OpDef& new_op);
// Returns an error if the default value for any attr is added/removed/modified // Returns an error if the default value for any attr is removed or modified
// in new_op compared to old_op. // in new_op compared to old_op. Adding new default values is safe, and does
// not raise an error.
Status OpDefAttrDefaultsUnchanged(const OpDef& old_op, const OpDef& new_op); Status OpDefAttrDefaultsUnchanged(const OpDef& old_op, const OpDef& new_op);
// Remove all docs from *op_def / *op_list. // Remove all docs from *op_def / *op_list.

View File

@ -53,7 +53,9 @@ class ValidateOpDefTest : public ::testing::Test {
return ValidateOpDef(op_reg_data.op_def); return ValidateOpDef(op_reg_data.op_def);
} }
} }
};
namespace {
void ExpectFailure(const Status& status, const string& message) { void ExpectFailure(const Status& status, const string& message) {
EXPECT_FALSE(status.ok()) << "Did not see error with: " << message; EXPECT_FALSE(status.ok()) << "Did not see error with: " << message;
if (!status.ok()) { if (!status.ok()) {
@ -62,7 +64,7 @@ class ValidateOpDefTest : public ::testing::Test {
<< "Actual: " << status << "\nExpected to contain: " << message; << "Actual: " << status << "\nExpected to contain: " << message;
} }
} }
}; } // namespace
TEST_F(ValidateOpDefTest, OpDefValid) { TEST_F(ValidateOpDefTest, OpDefValid) {
TF_EXPECT_OK(TestBuilder(OpDefBuilder("X").Attr("a: int"))); TF_EXPECT_OK(TestBuilder(OpDefBuilder("X").Attr("a: int")));
@ -527,5 +529,26 @@ TEST(OpDefEqualityTest, EqualAndHash) {
ExpectDifferent(o1, o3); ExpectDifferent(o1, o3);
} }
TEST(OpDefAttrDefaultsUnchangedTest, Foo) {
const auto& op1 = FromText("name: 'op1' attr { name: 'n' type: 'string'}");
const auto& op2 = FromText(
"name: 'op2' attr { name: 'n' type: 'string' default_value: {s: 'x'}}");
const auto& op3 = FromText(
"name: 'op3' attr { name: 'n' type: 'string' default_value: {s: 'y'}}");
// Adding a default value: fine.
TF_EXPECT_OK(OpDefAttrDefaultsUnchanged(op1, op2));
// Changing a default value: not ok.
Status changed_attr = OpDefAttrDefaultsUnchanged(op2, op3);
ExpectFailure(changed_attr,
"Attr 'n' has changed it's default value; from \"x\" to \"y\"");
// Removing a default value: not ok.
Status removed_attr = OpDefAttrDefaultsUnchanged(op2, op1);
ExpectFailure(removed_attr,
"Attr 'n' has removed it's default; from \"x\" to no default");
}
} // namespace } // namespace
} // namespace tensorflow } // namespace tensorflow

View File

@ -198,7 +198,7 @@ Status OpCompatibilityLib::ValidateCompatible(Env* env, int* changed_ops,
TF_RETURN_IF_ERROR(OpDefCompatible(history_op_list.op(i), cur_op)); TF_RETURN_IF_ERROR(OpDefCompatible(history_op_list.op(i), cur_op));
} }
// Verify default value of attrs has not been added/removed/modified // Verify default value of attrs has not been removed or modified
// as compared to only the last historical version. // as compared to only the last historical version.
TF_RETURN_IF_ERROR( TF_RETURN_IF_ERROR(
OpDefAttrDefaultsUnchanged(history_op_list.op(end - 1), cur_op)); OpDefAttrDefaultsUnchanged(history_op_list.op(end - 1), cur_op));