Update backwards-compatibility checks to allow adding a new default to an op attribute.
PiperOrigin-RevId: 272084965
This commit is contained in:
parent
86902a8ada
commit
714f1dbd7a
@ -1051,7 +1051,7 @@ TEST_F(OpCompatibilityTest, RenameOutputListFails) {
|
||||
"Output signature mismatch 'old:T' vs. 'new:T'");
|
||||
}
|
||||
|
||||
// Should not be able to add a default to an attr.
|
||||
// It's ok to add a default to an attr if it doesn't already have one.
|
||||
REGISTER_OP("AddDefault").Output("ndef: string").Attr("a: int = 1234");
|
||||
REGISTER_KERNEL_BUILDER(Name("AddDefault").Device(DEVICE_CPU), TestKernel);
|
||||
|
||||
@ -1064,9 +1064,8 @@ TEST_F(OpCompatibilityTest, AddDefault) {
|
||||
TF_ASSERT_OK(NodeDefBuilder("add_default", &old_op.op_def)
|
||||
.Attr("a", 765)
|
||||
.Finalize(node_def()));
|
||||
ExpectDefaultChangeFailure(
|
||||
old_op.op_def,
|
||||
"Attr 'a' has added/removed it's default; from no default to 1234");
|
||||
ExpectSuccess(old_op.op_def);
|
||||
EXPECT_EQ("{{node add_default}} = AddDefault[a=765]()", Result());
|
||||
}
|
||||
|
||||
// Should not be able to remove a default from an attr.
|
||||
@ -1083,7 +1082,7 @@ TEST_F(OpCompatibilityTest, RemoveDefault) {
|
||||
NodeDefBuilder("remove_default", &old_op.op_def).Finalize(node_def()));
|
||||
ExpectDefaultChangeFailure(
|
||||
old_op.op_def,
|
||||
"Attr 'a' has added/removed it's default; from 91 to no default");
|
||||
"Attr 'a' has removed it's default; from 91 to no default");
|
||||
}
|
||||
|
||||
// Should not be able to change a default for an attr.
|
||||
|
@ -734,10 +734,13 @@ Status OpDefAttrDefaultsUnchanged(const OpDef& old_op, const OpDef& new_op) {
|
||||
const OpDef::AttrDef* new_attr =
|
||||
gtl::FindPtrOrNull(new_attrs, old_attr.name());
|
||||
if (new_attr == nullptr) continue;
|
||||
if (old_attr.has_default_value() != new_attr->has_default_value()) {
|
||||
if (new_attr->has_default_value() && !old_attr.has_default_value()) {
|
||||
continue; // Adding new default values is safe.
|
||||
}
|
||||
if (old_attr.has_default_value() && !new_attr->has_default_value()) {
|
||||
return errors::InvalidArgument(
|
||||
"Attr '", old_attr.name(), "' has added/removed it's default; ",
|
||||
"from ", DefaultAttrStr(old_attr), " to ", DefaultAttrStr(*new_attr));
|
||||
"Attr '", old_attr.name(), "' has removed it's default; ", "from ",
|
||||
DefaultAttrStr(old_attr), " to ", DefaultAttrStr(*new_attr));
|
||||
}
|
||||
if (old_attr.has_default_value() &&
|
||||
!AreAttrValuesEqual(old_attr.default_value(),
|
||||
|
@ -68,8 +68,9 @@ Status OpDefAddedDefaultsUnchanged(const OpDef& old_op,
|
||||
const OpDef& penultimate_op,
|
||||
const OpDef& new_op);
|
||||
|
||||
// Returns an error if the default value for any attr is added/removed/modified
|
||||
// in new_op compared to old_op.
|
||||
// Returns an error if the default value for any attr is removed or modified
|
||||
// in new_op compared to old_op. Adding new default values is safe, and does
|
||||
// not raise an error.
|
||||
Status OpDefAttrDefaultsUnchanged(const OpDef& old_op, const OpDef& new_op);
|
||||
|
||||
// Remove all docs from *op_def / *op_list.
|
||||
|
@ -53,17 +53,19 @@ class ValidateOpDefTest : public ::testing::Test {
|
||||
return ValidateOpDef(op_reg_data.op_def);
|
||||
}
|
||||
}
|
||||
|
||||
void ExpectFailure(const Status& status, const string& message) {
|
||||
EXPECT_FALSE(status.ok()) << "Did not see error with: " << message;
|
||||
if (!status.ok()) {
|
||||
LOG(INFO) << "message: " << status;
|
||||
EXPECT_TRUE(absl::StrContains(status.ToString(), message))
|
||||
<< "Actual: " << status << "\nExpected to contain: " << message;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
namespace {
|
||||
void ExpectFailure(const Status& status, const string& message) {
|
||||
EXPECT_FALSE(status.ok()) << "Did not see error with: " << message;
|
||||
if (!status.ok()) {
|
||||
LOG(INFO) << "message: " << status;
|
||||
EXPECT_TRUE(absl::StrContains(status.ToString(), message))
|
||||
<< "Actual: " << status << "\nExpected to contain: " << message;
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
TEST_F(ValidateOpDefTest, OpDefValid) {
|
||||
TF_EXPECT_OK(TestBuilder(OpDefBuilder("X").Attr("a: int")));
|
||||
TF_EXPECT_OK(TestBuilder(OpDefBuilder("X").Input("a: int32")));
|
||||
@ -527,5 +529,26 @@ TEST(OpDefEqualityTest, EqualAndHash) {
|
||||
ExpectDifferent(o1, o3);
|
||||
}
|
||||
|
||||
TEST(OpDefAttrDefaultsUnchangedTest, Foo) {
|
||||
const auto& op1 = FromText("name: 'op1' attr { name: 'n' type: 'string'}");
|
||||
const auto& op2 = FromText(
|
||||
"name: 'op2' attr { name: 'n' type: 'string' default_value: {s: 'x'}}");
|
||||
const auto& op3 = FromText(
|
||||
"name: 'op3' attr { name: 'n' type: 'string' default_value: {s: 'y'}}");
|
||||
|
||||
// Adding a default value: fine.
|
||||
TF_EXPECT_OK(OpDefAttrDefaultsUnchanged(op1, op2));
|
||||
|
||||
// Changing a default value: not ok.
|
||||
Status changed_attr = OpDefAttrDefaultsUnchanged(op2, op3);
|
||||
ExpectFailure(changed_attr,
|
||||
"Attr 'n' has changed it's default value; from \"x\" to \"y\"");
|
||||
|
||||
// Removing a default value: not ok.
|
||||
Status removed_attr = OpDefAttrDefaultsUnchanged(op2, op1);
|
||||
ExpectFailure(removed_attr,
|
||||
"Attr 'n' has removed it's default; from \"x\" to no default");
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
@ -198,7 +198,7 @@ Status OpCompatibilityLib::ValidateCompatible(Env* env, int* changed_ops,
|
||||
TF_RETURN_IF_ERROR(OpDefCompatible(history_op_list.op(i), cur_op));
|
||||
}
|
||||
|
||||
// Verify default value of attrs has not been added/removed/modified
|
||||
// Verify default value of attrs has not been removed or modified
|
||||
// as compared to only the last historical version.
|
||||
TF_RETURN_IF_ERROR(
|
||||
OpDefAttrDefaultsUnchanged(history_op_list.op(end - 1), cur_op));
|
||||
|
Loading…
Reference in New Issue
Block a user