Add a .Deprecated method to REGISTER_OP

This replaces the OP_DEPRECATED macro with something declarative, which in
particular lets us throw exceptions at graph construction time based on
deprecation.  I've left the OP_DEPRECATED macro around in case uses elsewhere
can't be expressed in a purely declarative manner.
Change: 120386133
This commit is contained in:
Geoffrey Irving 2016-04-20 14:58:57 -08:00 committed by TensorFlower Gardener
parent c3af083a64
commit 30334d28c0
20 changed files with 111 additions and 10 deletions

View File

@ -209,6 +209,10 @@ class OpDefBuilderWrapper<true> {
builder_.SetAllowsUninitializedInput();
return *this;
}
OpDefBuilderWrapper<true>& Deprecated(int version, StringPiece explanation) {
builder_.Deprecated(version, explanation);
return *this;
}
OpDefBuilderWrapper<true>& Doc(StringPiece text) {
builder_.Doc(text);
return *this;
@ -231,6 +235,7 @@ class OpDefBuilderWrapper<false> {
OpDefBuilderWrapper<false>& SetIsAggregate() { return *this; }
OpDefBuilderWrapper<false>& SetIsStateful() { return *this; }
OpDefBuilderWrapper<false>& SetAllowsUninitializedInput() { return *this; }
OpDefBuilderWrapper<false>& Deprecated(int, StringPiece) { return *this; }
OpDefBuilderWrapper<false>& Doc(StringPiece text) { return *this; }
};

View File

@ -94,6 +94,9 @@ message OpDef {
}
repeated AttrDef attr = 4;
// Optional deprecation based on GraphDef versions.
OpDeprecation deprecation = 8;
// One-line human-readable description of what the Op does.
string summary = 5;
@ -139,6 +142,15 @@ message OpDef {
bool allows_uninitialized_input = 19; // for Assign, etc.
};
// Information about version-dependent deprecation of an op
message OpDeprecation {
// First GraphDef version at which the op is disallowed.
int32 version = 1;
// Explanation of why it was deprecated and what to use instead.
string explanation = 2;
};
// A collection of OpDefs
message OpList {
repeated OpDef op = 1;

View File

@ -541,6 +541,18 @@ OpDefBuilder& OpDefBuilder::SetAllowsUninitializedInput() {
return *this;
}
OpDefBuilder& OpDefBuilder::Deprecated(int version, StringPiece explanation) {
if (op_def_.has_deprecation()) {
errors_.push_back(
strings::StrCat("Deprecated called twice for Op ", op_def_.name()));
} else {
OpDeprecation* deprecation = op_def_.mutable_deprecation();
deprecation->set_version(version);
deprecation->set_explanation(explanation.ToString());
}
return *this;
}
Status OpDefBuilder::Finalize(OpDef* op_def) const {
std::vector<string> errors = errors_;
*op_def = op_def_;

View File

@ -89,6 +89,9 @@ class OpDefBuilder {
OpDefBuilder& SetIsStateful();
OpDefBuilder& SetAllowsUninitializedInput();
// Deprecate the op at a certain GraphDef version.
OpDefBuilder& Deprecated(int version, StringPiece explanation);
// Adds docs to this OpDefBuilder (and returns *this).
// Docs have the format:
// <1-line summary>

View File

@ -561,7 +561,7 @@ Status OpDefAddedDefaultsUnchanged(const OpDef& old_op,
return Status::OK();
}
void RemoveDescriptionsFromOpDef(OpDef* op_def) {
void RemoveNonDeprecationDescriptionsFromOpDef(OpDef* op_def) {
for (int i = 0; i < op_def->input_arg_size(); ++i) {
op_def->mutable_input_arg(i)->clear_description();
}
@ -575,6 +575,13 @@ void RemoveDescriptionsFromOpDef(OpDef* op_def) {
op_def->clear_description();
}
void RemoveDescriptionsFromOpDef(OpDef* op_def) {
RemoveNonDeprecationDescriptionsFromOpDef(op_def);
if (op_def->has_deprecation()) {
op_def->mutable_deprecation()->clear_explanation();
}
}
void RemoveDescriptionsFromOpList(OpList* op_list) {
for (int i = 0; i < op_list->op_size(); ++i) {
OpDef* op_def = op_list->mutable_op(i);

View File

@ -58,6 +58,9 @@ Status OpDefAddedDefaultsUnchanged(const OpDef& old_op,
void RemoveDescriptionsFromOpDef(OpDef* op_def);
void RemoveDescriptionsFromOpList(OpList* op_list);
// Remove docs from *op_def but leave explanations of deprecations.
void RemoveNonDeprecationDescriptionsFromOpDef(OpDef* op_def);
} // namespace tensorflow
#endif // TENSORFLOW_FRAMEWORK_OP_DEF_UTIL_H_

View File

@ -89,6 +89,10 @@ OpKernel::OpKernel(OpKernelConstruction* context)
OP_REQUIRES_OK(context,
NameRangesForNode(def_, context->op_def(), &input_name_map_,
&output_name_map_));
if (context->op_def().has_deprecation()) {
const OpDeprecation& deprecation = context->op_def().deprecation();
OP_DEPRECATED(context, deprecation.version(), deprecation.explanation());
}
}
OpKernel::~OpKernel() {}

View File

@ -1252,6 +1252,11 @@ inline void OpOutputList::set_ref(int i, mutex* mu, Tensor* tensor_for_ref) {
// }
// Declares an op deprecated, and illegal starting at GraphDef version VERSION
// Cleverly, OP_DEPRECATED is itself deprecated for most users; instead, use
// REGISTER_OP(...)
// ...
// .Deprecated(version, note)
// ...
#define OP_DEPRECATED(CTX, VERSION, NOTE) \
if ((CTX)->graph_def_version() >= (VERSION)) { \
::tensorflow::Status _s(::tensorflow::errors::Unimplemented( \

View File

@ -38,7 +38,6 @@ template <typename Device, typename T>
class AdjustContrastOp : public OpKernel {
public:
explicit AdjustContrastOp(OpKernelConstruction* context) : OpKernel(context) {
OP_DEPRECATED(context, 2, "Use AdjustContrastv2 instead");
}
void Compute(OpKernelContext* context) override {

View File

@ -33,7 +33,6 @@ template <typename Device, typename T>
class BatchNormOp : public OpKernel {
public:
explicit BatchNormOp(OpKernelConstruction* context) : OpKernel(context) {
OP_DEPRECATED(context, 9, "Use tf.nn.batch_normalization()");
OP_REQUIRES_OK(context,
context->GetAttr("variance_epsilon", &variance_epsilon_));
OP_REQUIRES_OK(context, context->GetAttr("scale_after_normalization",
@ -82,7 +81,6 @@ template <typename Device, typename T>
class BatchNormGradOp : public OpKernel {
public:
explicit BatchNormGradOp(OpKernelConstruction* context) : OpKernel(context) {
OP_DEPRECATED(context, 9, "Use tf.nn.batch_normalization()");
OP_REQUIRES_OK(context,
context->GetAttr("variance_epsilon", &variance_epsilon_));
OP_REQUIRES_OK(context, context->GetAttr("scale_after_normalization",

View File

@ -28,7 +28,6 @@ template <typename T>
class RandomCropOp : public OpKernel {
public:
explicit RandomCropOp(OpKernelConstruction* context) : OpKernel(context) {
OP_DEPRECATED(context, 8, "Random crop is now pure Python");
OP_REQUIRES_OK(context, generator_.Init(context));
}

View File

@ -186,9 +186,7 @@ HANDLE_CASE_DIM(GPUDevice, DT_INT64);
template <typename Device>
class TileGradientOp : public OpKernel {
public:
explicit TileGradientOp(OpKernelConstruction* context) : OpKernel(context) {
OP_DEPRECATED(context, 3, "TileGrad has been replaced with reduce_sum");
}
explicit TileGradientOp(OpKernelConstruction* context) : OpKernel(context) {}
void Compute(OpKernelContext* context) override {
const Tensor& input = context->input(0);

View File

@ -33,7 +33,6 @@ class TopK : public OpKernel {
explicit TopK(OpKernelConstruction* context) : OpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("sorted", &sorted_));
if (num_inputs() < 2) { // k is an attr (TopK).
OP_DEPRECATED(context, 7, "Use TopKV2 instead");
OP_REQUIRES_OK(context, context->GetAttr("k", &k_));
} else { // k is an input (TopKV2), so we won't know it until Compute.
k_ = -1;

View File

@ -993,6 +993,7 @@ REGISTER_OP("TileGrad")
.Input("multiples: int32")
.Output("output: T")
.Attr("T: type")
.Deprecated(3, "TileGrad has been replaced with reduce_sum")
.Doc(R"doc(
Returns the gradient of `Tile`.

View File

@ -153,6 +153,7 @@ REGISTER_OP("RandomCrop")
.Attr("seed: int = 0")
.Attr("seed2: int = 0")
.SetIsStateful()
.Deprecated(8, "Random crop is now pure Python")
.Doc(R"doc(
Randomly crop `image`.
@ -267,6 +268,7 @@ REGISTER_OP("AdjustContrast")
.Input("max_value: float")
.Output("output: float")
.Attr("T: {uint8, int8, int16, int32, int64, float, double}")
.Deprecated(2, "Use AdjustContrastv2 instead")
.Doc(R"Doc(
Deprecated. Disallowed in GraphDef version >= 2.
)Doc");

View File

@ -85,6 +85,7 @@ REGISTER_OP("BatchNormWithGlobalNormalization")
.Attr("T: numbertype")
.Attr("variance_epsilon: float")
.Attr("scale_after_normalization: bool")
.Deprecated(9, "Use tf.nn.batch_normalization()")
.Doc(R"doc(
Batch normalization.
@ -121,6 +122,7 @@ REGISTER_OP("BatchNormWithGlobalNormalizationGrad")
.Attr("T: numbertype")
.Attr("variance_epsilon: float")
.Attr("scale_after_normalization: bool")
.Deprecated(9, "Use tf.nn.batch_normalization()")
.Doc(R"doc(
Gradients for batch normalization.
@ -815,6 +817,7 @@ REGISTER_OP("TopK")
.Attr("k: int >= 0")
.Attr("sorted: bool = true")
.Attr("T: realnumbertype")
.Deprecated(7, "Use TopKV2 instead")
.Doc(R"doc(
Finds values and indices of the `k` largest elements for the last dimension.

View File

@ -1353,5 +1353,34 @@ class ColocationGroupTest(test_util.TensorFlowTestCase):
self.assertEqual("/device:CPU:0", b.device)
class DeprecatedTest(test_util.TensorFlowTestCase):
def testSuccess(self):
with ops.Graph().as_default() as g:
g.graph_def_versions.producer = 7
old = test_ops.old()
with self.test_session(graph=g):
old.run()
def _error(self):
return ((r"Op Old is not available in GraphDef version %d\. "
r"It has been removed in version 8\. For reasons\.") %
versions.GRAPH_DEF_VERSION)
def testGraphConstructionFail(self):
with ops.Graph().as_default():
with self.assertRaisesRegexp(NotImplementedError, self._error()):
test_ops.old()
def testGraphExecutionFail(self):
with ops.Graph().as_default() as g:
g.graph_def_versions.producer = 7
old = test_ops.old()
g.graph_def_versions.producer = versions.GRAPH_DEF_VERSION
with self.test_session(graph=g):
with self.assertRaisesOpError(self._error()):
old.run()
if __name__ == "__main__":
googletest.main()

View File

@ -661,7 +661,7 @@ from tensorflow.python.ops import op_def_library
auto added = out->Add();
*added = op_def;
RemoveDescriptionsFromOpDef(added);
RemoveNonDeprecationDescriptionsFromOpDef(added);
}
strings::Appendf(&result, R"(def _InitOpDefLibrary():

View File

@ -23,6 +23,8 @@ REGISTER_OP("KernelLabel").Output("result: string");
REGISTER_OP("GraphDefVersion").Output("version: int32").SetIsStateful();
REGISTER_OP("Old").Deprecated(8, "For reasons");
namespace {
enum KernelLabel { DEFAULT_LABEL, OVERLOAD_1_LABEL, OVERLOAD_2_LABEL };
} // namespace
@ -79,4 +81,13 @@ class GraphDefVersionOp : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("GraphDefVersion").Device(DEVICE_CPU),
GraphDefVersionOp);
class OldOp : public OpKernel {
public:
OldOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
void Compute(OpKernelContext* ctx) override {}
};
REGISTER_KERNEL_BUILDER(Name("Old").Device(DEVICE_CPU), OldOp);
} // end namespace tensorflow

View File

@ -340,6 +340,17 @@ class OpDefLibrary(object):
if name is None:
name = op_type_name
# Check for deprecation
deprecation_version = op_def.deprecation.version
if deprecation_version:
producer = g.graph_def_versions.producer
if producer >= deprecation_version:
raise NotImplementedError(
("Op %s is not available in GraphDef version %d. "
"It has been removed in version %d. %s.") %
(op_type_name, producer, deprecation_version,
op_def.deprecation.explanation))
# Requires that op_def has passed validation (using the C++
# ValidateOpDef() from ../framework/op_def_util.h).
attrs = {}