Update backwards-compatibility checks to allow adding a new default to an op attribute.
PiperOrigin-RevId: 272084965
This commit is contained in:
parent
86902a8ada
commit
714f1dbd7a
@ -1051,7 +1051,7 @@ TEST_F(OpCompatibilityTest, RenameOutputListFails) {
|
|||||||
"Output signature mismatch 'old:T' vs. 'new:T'");
|
"Output signature mismatch 'old:T' vs. 'new:T'");
|
||||||
}
|
}
|
||||||
|
|
||||||
// Should not be able to add a default to an attr.
|
// It's ok to add a default to an attr if it doesn't already have one.
|
||||||
REGISTER_OP("AddDefault").Output("ndef: string").Attr("a: int = 1234");
|
REGISTER_OP("AddDefault").Output("ndef: string").Attr("a: int = 1234");
|
||||||
REGISTER_KERNEL_BUILDER(Name("AddDefault").Device(DEVICE_CPU), TestKernel);
|
REGISTER_KERNEL_BUILDER(Name("AddDefault").Device(DEVICE_CPU), TestKernel);
|
||||||
|
|
||||||
@ -1064,9 +1064,8 @@ TEST_F(OpCompatibilityTest, AddDefault) {
|
|||||||
TF_ASSERT_OK(NodeDefBuilder("add_default", &old_op.op_def)
|
TF_ASSERT_OK(NodeDefBuilder("add_default", &old_op.op_def)
|
||||||
.Attr("a", 765)
|
.Attr("a", 765)
|
||||||
.Finalize(node_def()));
|
.Finalize(node_def()));
|
||||||
ExpectDefaultChangeFailure(
|
ExpectSuccess(old_op.op_def);
|
||||||
old_op.op_def,
|
EXPECT_EQ("{{node add_default}} = AddDefault[a=765]()", Result());
|
||||||
"Attr 'a' has added/removed it's default; from no default to 1234");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Should not be able to remove a default from an attr.
|
// Should not be able to remove a default from an attr.
|
||||||
@ -1083,7 +1082,7 @@ TEST_F(OpCompatibilityTest, RemoveDefault) {
|
|||||||
NodeDefBuilder("remove_default", &old_op.op_def).Finalize(node_def()));
|
NodeDefBuilder("remove_default", &old_op.op_def).Finalize(node_def()));
|
||||||
ExpectDefaultChangeFailure(
|
ExpectDefaultChangeFailure(
|
||||||
old_op.op_def,
|
old_op.op_def,
|
||||||
"Attr 'a' has added/removed it's default; from 91 to no default");
|
"Attr 'a' has removed it's default; from 91 to no default");
|
||||||
}
|
}
|
||||||
|
|
||||||
// Should not be able to change a default for an attr.
|
// Should not be able to change a default for an attr.
|
||||||
|
@ -734,10 +734,13 @@ Status OpDefAttrDefaultsUnchanged(const OpDef& old_op, const OpDef& new_op) {
|
|||||||
const OpDef::AttrDef* new_attr =
|
const OpDef::AttrDef* new_attr =
|
||||||
gtl::FindPtrOrNull(new_attrs, old_attr.name());
|
gtl::FindPtrOrNull(new_attrs, old_attr.name());
|
||||||
if (new_attr == nullptr) continue;
|
if (new_attr == nullptr) continue;
|
||||||
if (old_attr.has_default_value() != new_attr->has_default_value()) {
|
if (new_attr->has_default_value() && !old_attr.has_default_value()) {
|
||||||
|
continue; // Adding new default values is safe.
|
||||||
|
}
|
||||||
|
if (old_attr.has_default_value() && !new_attr->has_default_value()) {
|
||||||
return errors::InvalidArgument(
|
return errors::InvalidArgument(
|
||||||
"Attr '", old_attr.name(), "' has added/removed it's default; ",
|
"Attr '", old_attr.name(), "' has removed it's default; ", "from ",
|
||||||
"from ", DefaultAttrStr(old_attr), " to ", DefaultAttrStr(*new_attr));
|
DefaultAttrStr(old_attr), " to ", DefaultAttrStr(*new_attr));
|
||||||
}
|
}
|
||||||
if (old_attr.has_default_value() &&
|
if (old_attr.has_default_value() &&
|
||||||
!AreAttrValuesEqual(old_attr.default_value(),
|
!AreAttrValuesEqual(old_attr.default_value(),
|
||||||
|
@ -68,8 +68,9 @@ Status OpDefAddedDefaultsUnchanged(const OpDef& old_op,
|
|||||||
const OpDef& penultimate_op,
|
const OpDef& penultimate_op,
|
||||||
const OpDef& new_op);
|
const OpDef& new_op);
|
||||||
|
|
||||||
// Returns an error if the default value for any attr is added/removed/modified
|
// Returns an error if the default value for any attr is removed or modified
|
||||||
// in new_op compared to old_op.
|
// in new_op compared to old_op. Adding new default values is safe, and does
|
||||||
|
// not raise an error.
|
||||||
Status OpDefAttrDefaultsUnchanged(const OpDef& old_op, const OpDef& new_op);
|
Status OpDefAttrDefaultsUnchanged(const OpDef& old_op, const OpDef& new_op);
|
||||||
|
|
||||||
// Remove all docs from *op_def / *op_list.
|
// Remove all docs from *op_def / *op_list.
|
||||||
|
@ -53,7 +53,9 @@ class ValidateOpDefTest : public ::testing::Test {
|
|||||||
return ValidateOpDef(op_reg_data.op_def);
|
return ValidateOpDef(op_reg_data.op_def);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
namespace {
|
||||||
void ExpectFailure(const Status& status, const string& message) {
|
void ExpectFailure(const Status& status, const string& message) {
|
||||||
EXPECT_FALSE(status.ok()) << "Did not see error with: " << message;
|
EXPECT_FALSE(status.ok()) << "Did not see error with: " << message;
|
||||||
if (!status.ok()) {
|
if (!status.ok()) {
|
||||||
@ -62,7 +64,7 @@ class ValidateOpDefTest : public ::testing::Test {
|
|||||||
<< "Actual: " << status << "\nExpected to contain: " << message;
|
<< "Actual: " << status << "\nExpected to contain: " << message;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
} // namespace
|
||||||
|
|
||||||
TEST_F(ValidateOpDefTest, OpDefValid) {
|
TEST_F(ValidateOpDefTest, OpDefValid) {
|
||||||
TF_EXPECT_OK(TestBuilder(OpDefBuilder("X").Attr("a: int")));
|
TF_EXPECT_OK(TestBuilder(OpDefBuilder("X").Attr("a: int")));
|
||||||
@ -527,5 +529,26 @@ TEST(OpDefEqualityTest, EqualAndHash) {
|
|||||||
ExpectDifferent(o1, o3);
|
ExpectDifferent(o1, o3);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(OpDefAttrDefaultsUnchangedTest, Foo) {
|
||||||
|
const auto& op1 = FromText("name: 'op1' attr { name: 'n' type: 'string'}");
|
||||||
|
const auto& op2 = FromText(
|
||||||
|
"name: 'op2' attr { name: 'n' type: 'string' default_value: {s: 'x'}}");
|
||||||
|
const auto& op3 = FromText(
|
||||||
|
"name: 'op3' attr { name: 'n' type: 'string' default_value: {s: 'y'}}");
|
||||||
|
|
||||||
|
// Adding a default value: fine.
|
||||||
|
TF_EXPECT_OK(OpDefAttrDefaultsUnchanged(op1, op2));
|
||||||
|
|
||||||
|
// Changing a default value: not ok.
|
||||||
|
Status changed_attr = OpDefAttrDefaultsUnchanged(op2, op3);
|
||||||
|
ExpectFailure(changed_attr,
|
||||||
|
"Attr 'n' has changed it's default value; from \"x\" to \"y\"");
|
||||||
|
|
||||||
|
// Removing a default value: not ok.
|
||||||
|
Status removed_attr = OpDefAttrDefaultsUnchanged(op2, op1);
|
||||||
|
ExpectFailure(removed_attr,
|
||||||
|
"Attr 'n' has removed it's default; from \"x\" to no default");
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -198,7 +198,7 @@ Status OpCompatibilityLib::ValidateCompatible(Env* env, int* changed_ops,
|
|||||||
TF_RETURN_IF_ERROR(OpDefCompatible(history_op_list.op(i), cur_op));
|
TF_RETURN_IF_ERROR(OpDefCompatible(history_op_list.op(i), cur_op));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify default value of attrs has not been added/removed/modified
|
// Verify default value of attrs has not been removed or modified
|
||||||
// as compared to only the last historical version.
|
// as compared to only the last historical version.
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(
|
||||||
OpDefAttrDefaultsUnchanged(history_op_list.op(end - 1), cur_op));
|
OpDefAttrDefaultsUnchanged(history_op_list.op(end - 1), cur_op));
|
||||||
|
Loading…
Reference in New Issue
Block a user