Update backwards-compatibility checks to allow adding a new default to an op attribute.

PiperOrigin-RevId: 272084965
This commit is contained in:
Edward Loper 2019-09-30 15:56:04 -07:00 committed by TensorFlower Gardener
parent 86902a8ada
commit 714f1dbd7a
5 changed files with 46 additions and 20 deletions

View File

@ -1051,7 +1051,7 @@ TEST_F(OpCompatibilityTest, RenameOutputListFails) {
"Output signature mismatch 'old:T' vs. 'new:T'");
}
// Should not be able to add a default to an attr.
// It's ok to add a default to an attr if it doesn't already have one.
REGISTER_OP("AddDefault").Output("ndef: string").Attr("a: int = 1234");
REGISTER_KERNEL_BUILDER(Name("AddDefault").Device(DEVICE_CPU), TestKernel);
@ -1064,9 +1064,8 @@ TEST_F(OpCompatibilityTest, AddDefault) {
TF_ASSERT_OK(NodeDefBuilder("add_default", &old_op.op_def)
.Attr("a", 765)
.Finalize(node_def()));
ExpectDefaultChangeFailure(
old_op.op_def,
"Attr 'a' has added/removed it's default; from no default to 1234");
ExpectSuccess(old_op.op_def);
EXPECT_EQ("{{node add_default}} = AddDefault[a=765]()", Result());
}
// Should not be able to remove a default from an attr.
@ -1083,7 +1082,7 @@ TEST_F(OpCompatibilityTest, RemoveDefault) {
NodeDefBuilder("remove_default", &old_op.op_def).Finalize(node_def()));
ExpectDefaultChangeFailure(
old_op.op_def,
"Attr 'a' has added/removed it's default; from 91 to no default");
"Attr 'a' has removed it's default; from 91 to no default");
}
// Should not be able to change a default for an attr.

View File

@ -734,10 +734,13 @@ Status OpDefAttrDefaultsUnchanged(const OpDef& old_op, const OpDef& new_op) {
const OpDef::AttrDef* new_attr =
gtl::FindPtrOrNull(new_attrs, old_attr.name());
if (new_attr == nullptr) continue;
if (old_attr.has_default_value() != new_attr->has_default_value()) {
if (new_attr->has_default_value() && !old_attr.has_default_value()) {
continue; // Adding new default values is safe.
}
if (old_attr.has_default_value() && !new_attr->has_default_value()) {
return errors::InvalidArgument(
"Attr '", old_attr.name(), "' has added/removed it's default; ",
"from ", DefaultAttrStr(old_attr), " to ", DefaultAttrStr(*new_attr));
"Attr '", old_attr.name(), "' has removed it's default; ", "from ",
DefaultAttrStr(old_attr), " to ", DefaultAttrStr(*new_attr));
}
if (old_attr.has_default_value() &&
!AreAttrValuesEqual(old_attr.default_value(),

View File

@ -68,8 +68,9 @@ Status OpDefAddedDefaultsUnchanged(const OpDef& old_op,
const OpDef& penultimate_op,
const OpDef& new_op);
// Returns an error if the default value for any attr is added/removed/modified
// in new_op compared to old_op.
// Returns an error if the default value for any attr is removed or modified
// in new_op compared to old_op. Adding new default values is safe, and does
// not raise an error.
Status OpDefAttrDefaultsUnchanged(const OpDef& old_op, const OpDef& new_op);
// Remove all docs from *op_def / *op_list.

View File

@ -53,17 +53,19 @@ class ValidateOpDefTest : public ::testing::Test {
return ValidateOpDef(op_reg_data.op_def);
}
}
void ExpectFailure(const Status& status, const string& message) {
EXPECT_FALSE(status.ok()) << "Did not see error with: " << message;
if (!status.ok()) {
LOG(INFO) << "message: " << status;
EXPECT_TRUE(absl::StrContains(status.ToString(), message))
<< "Actual: " << status << "\nExpected to contain: " << message;
}
}
};
namespace {
void ExpectFailure(const Status& status, const string& message) {
EXPECT_FALSE(status.ok()) << "Did not see error with: " << message;
if (!status.ok()) {
LOG(INFO) << "message: " << status;
EXPECT_TRUE(absl::StrContains(status.ToString(), message))
<< "Actual: " << status << "\nExpected to contain: " << message;
}
}
} // namespace
TEST_F(ValidateOpDefTest, OpDefValid) {
TF_EXPECT_OK(TestBuilder(OpDefBuilder("X").Attr("a: int")));
TF_EXPECT_OK(TestBuilder(OpDefBuilder("X").Input("a: int32")));
@ -527,5 +529,26 @@ TEST(OpDefEqualityTest, EqualAndHash) {
ExpectDifferent(o1, o3);
}
TEST(OpDefAttrDefaultsUnchangedTest, Foo) {
const auto& op1 = FromText("name: 'op1' attr { name: 'n' type: 'string'}");
const auto& op2 = FromText(
"name: 'op2' attr { name: 'n' type: 'string' default_value: {s: 'x'}}");
const auto& op3 = FromText(
"name: 'op3' attr { name: 'n' type: 'string' default_value: {s: 'y'}}");
// Adding a default value: fine.
TF_EXPECT_OK(OpDefAttrDefaultsUnchanged(op1, op2));
// Changing a default value: not ok.
Status changed_attr = OpDefAttrDefaultsUnchanged(op2, op3);
ExpectFailure(changed_attr,
"Attr 'n' has changed it's default value; from \"x\" to \"y\"");
// Removing a default value: not ok.
Status removed_attr = OpDefAttrDefaultsUnchanged(op2, op1);
ExpectFailure(removed_attr,
"Attr 'n' has removed it's default; from \"x\" to no default");
}
} // namespace
} // namespace tensorflow

View File

@ -198,7 +198,7 @@ Status OpCompatibilityLib::ValidateCompatible(Env* env, int* changed_ops,
TF_RETURN_IF_ERROR(OpDefCompatible(history_op_list.op(i), cur_op));
}
// Verify default value of attrs has not been added/removed/modified
// Verify default value of attrs has not been removed or modified
// as compared to only the last historical version.
TF_RETURN_IF_ERROR(
OpDefAttrDefaultsUnchanged(history_op_list.op(end - 1), cur_op));