Add a .Deprecated method to REGISTER_OP
This replaces the OP_DEPRECATED macro with something declarative, which in particular lets us throw exceptions at graph construction time based on deprecation. I've left the OP_DEPRECATED macro around in case uses elsewhere can't be expressed in a purely declarative manner. Change: 120386133
This commit is contained in:
parent
c3af083a64
commit
30334d28c0
@ -209,6 +209,10 @@ class OpDefBuilderWrapper<true> {
|
||||
builder_.SetAllowsUninitializedInput();
|
||||
return *this;
|
||||
}
|
||||
OpDefBuilderWrapper<true>& Deprecated(int version, StringPiece explanation) {
|
||||
builder_.Deprecated(version, explanation);
|
||||
return *this;
|
||||
}
|
||||
OpDefBuilderWrapper<true>& Doc(StringPiece text) {
|
||||
builder_.Doc(text);
|
||||
return *this;
|
||||
@ -231,6 +235,7 @@ class OpDefBuilderWrapper<false> {
|
||||
OpDefBuilderWrapper<false>& SetIsAggregate() { return *this; }
|
||||
OpDefBuilderWrapper<false>& SetIsStateful() { return *this; }
|
||||
OpDefBuilderWrapper<false>& SetAllowsUninitializedInput() { return *this; }
|
||||
OpDefBuilderWrapper<false>& Deprecated(int, StringPiece) { return *this; }
|
||||
OpDefBuilderWrapper<false>& Doc(StringPiece text) { return *this; }
|
||||
};
|
||||
|
||||
|
@ -94,6 +94,9 @@ message OpDef {
|
||||
}
|
||||
repeated AttrDef attr = 4;
|
||||
|
||||
// Optional deprecation based on GraphDef versions.
|
||||
OpDeprecation deprecation = 8;
|
||||
|
||||
// One-line human-readable description of what the Op does.
|
||||
string summary = 5;
|
||||
|
||||
@ -139,6 +142,15 @@ message OpDef {
|
||||
bool allows_uninitialized_input = 19; // for Assign, etc.
|
||||
};
|
||||
|
||||
// Information about version-dependent deprecation of an op
|
||||
message OpDeprecation {
|
||||
// First GraphDef version at which the op is disallowed.
|
||||
int32 version = 1;
|
||||
|
||||
// Explanation of why it was deprecated and what to use instead.
|
||||
string explanation = 2;
|
||||
};
|
||||
|
||||
// A collection of OpDefs
|
||||
message OpList {
|
||||
repeated OpDef op = 1;
|
||||
|
@ -541,6 +541,18 @@ OpDefBuilder& OpDefBuilder::SetAllowsUninitializedInput() {
|
||||
return *this;
|
||||
}
|
||||
|
||||
OpDefBuilder& OpDefBuilder::Deprecated(int version, StringPiece explanation) {
|
||||
if (op_def_.has_deprecation()) {
|
||||
errors_.push_back(
|
||||
strings::StrCat("Deprecated called twice for Op ", op_def_.name()));
|
||||
} else {
|
||||
OpDeprecation* deprecation = op_def_.mutable_deprecation();
|
||||
deprecation->set_version(version);
|
||||
deprecation->set_explanation(explanation.ToString());
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
Status OpDefBuilder::Finalize(OpDef* op_def) const {
|
||||
std::vector<string> errors = errors_;
|
||||
*op_def = op_def_;
|
||||
|
@ -89,6 +89,9 @@ class OpDefBuilder {
|
||||
OpDefBuilder& SetIsStateful();
|
||||
OpDefBuilder& SetAllowsUninitializedInput();
|
||||
|
||||
// Deprecate the op at a certain GraphDef version.
|
||||
OpDefBuilder& Deprecated(int version, StringPiece explanation);
|
||||
|
||||
// Adds docs to this OpDefBuilder (and returns *this).
|
||||
// Docs have the format:
|
||||
// <1-line summary>
|
||||
|
@ -561,7 +561,7 @@ Status OpDefAddedDefaultsUnchanged(const OpDef& old_op,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void RemoveDescriptionsFromOpDef(OpDef* op_def) {
|
||||
void RemoveNonDeprecationDescriptionsFromOpDef(OpDef* op_def) {
|
||||
for (int i = 0; i < op_def->input_arg_size(); ++i) {
|
||||
op_def->mutable_input_arg(i)->clear_description();
|
||||
}
|
||||
@ -575,6 +575,13 @@ void RemoveDescriptionsFromOpDef(OpDef* op_def) {
|
||||
op_def->clear_description();
|
||||
}
|
||||
|
||||
void RemoveDescriptionsFromOpDef(OpDef* op_def) {
|
||||
RemoveNonDeprecationDescriptionsFromOpDef(op_def);
|
||||
if (op_def->has_deprecation()) {
|
||||
op_def->mutable_deprecation()->clear_explanation();
|
||||
}
|
||||
}
|
||||
|
||||
void RemoveDescriptionsFromOpList(OpList* op_list) {
|
||||
for (int i = 0; i < op_list->op_size(); ++i) {
|
||||
OpDef* op_def = op_list->mutable_op(i);
|
||||
|
@ -58,6 +58,9 @@ Status OpDefAddedDefaultsUnchanged(const OpDef& old_op,
|
||||
void RemoveDescriptionsFromOpDef(OpDef* op_def);
|
||||
void RemoveDescriptionsFromOpList(OpList* op_list);
|
||||
|
||||
// Remove docs from *op_def but leave explanations of deprecations.
|
||||
void RemoveNonDeprecationDescriptionsFromOpDef(OpDef* op_def);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_FRAMEWORK_OP_DEF_UTIL_H_
|
||||
|
@ -89,6 +89,10 @@ OpKernel::OpKernel(OpKernelConstruction* context)
|
||||
OP_REQUIRES_OK(context,
|
||||
NameRangesForNode(def_, context->op_def(), &input_name_map_,
|
||||
&output_name_map_));
|
||||
if (context->op_def().has_deprecation()) {
|
||||
const OpDeprecation& deprecation = context->op_def().deprecation();
|
||||
OP_DEPRECATED(context, deprecation.version(), deprecation.explanation());
|
||||
}
|
||||
}
|
||||
|
||||
OpKernel::~OpKernel() {}
|
||||
|
@ -1252,6 +1252,11 @@ inline void OpOutputList::set_ref(int i, mutex* mu, Tensor* tensor_for_ref) {
|
||||
// }
|
||||
|
||||
// Declares an op deprecated, and illegal starting at GraphDef version VERSION
|
||||
// Cleverly, OP_DEPRECATED is itself deprecated for most users; instead, use
|
||||
// REGISTER_OP(...)
|
||||
// ...
|
||||
// .Deprecated(version, note)
|
||||
// ...
|
||||
#define OP_DEPRECATED(CTX, VERSION, NOTE) \
|
||||
if ((CTX)->graph_def_version() >= (VERSION)) { \
|
||||
::tensorflow::Status _s(::tensorflow::errors::Unimplemented( \
|
||||
|
@ -38,7 +38,6 @@ template <typename Device, typename T>
|
||||
class AdjustContrastOp : public OpKernel {
|
||||
public:
|
||||
explicit AdjustContrastOp(OpKernelConstruction* context) : OpKernel(context) {
|
||||
OP_DEPRECATED(context, 2, "Use AdjustContrastv2 instead");
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* context) override {
|
||||
|
@ -33,7 +33,6 @@ template <typename Device, typename T>
|
||||
class BatchNormOp : public OpKernel {
|
||||
public:
|
||||
explicit BatchNormOp(OpKernelConstruction* context) : OpKernel(context) {
|
||||
OP_DEPRECATED(context, 9, "Use tf.nn.batch_normalization()");
|
||||
OP_REQUIRES_OK(context,
|
||||
context->GetAttr("variance_epsilon", &variance_epsilon_));
|
||||
OP_REQUIRES_OK(context, context->GetAttr("scale_after_normalization",
|
||||
@ -82,7 +81,6 @@ template <typename Device, typename T>
|
||||
class BatchNormGradOp : public OpKernel {
|
||||
public:
|
||||
explicit BatchNormGradOp(OpKernelConstruction* context) : OpKernel(context) {
|
||||
OP_DEPRECATED(context, 9, "Use tf.nn.batch_normalization()");
|
||||
OP_REQUIRES_OK(context,
|
||||
context->GetAttr("variance_epsilon", &variance_epsilon_));
|
||||
OP_REQUIRES_OK(context, context->GetAttr("scale_after_normalization",
|
||||
|
@ -28,7 +28,6 @@ template <typename T>
|
||||
class RandomCropOp : public OpKernel {
|
||||
public:
|
||||
explicit RandomCropOp(OpKernelConstruction* context) : OpKernel(context) {
|
||||
OP_DEPRECATED(context, 8, "Random crop is now pure Python");
|
||||
OP_REQUIRES_OK(context, generator_.Init(context));
|
||||
}
|
||||
|
||||
|
@ -186,9 +186,7 @@ HANDLE_CASE_DIM(GPUDevice, DT_INT64);
|
||||
template <typename Device>
|
||||
class TileGradientOp : public OpKernel {
|
||||
public:
|
||||
explicit TileGradientOp(OpKernelConstruction* context) : OpKernel(context) {
|
||||
OP_DEPRECATED(context, 3, "TileGrad has been replaced with reduce_sum");
|
||||
}
|
||||
explicit TileGradientOp(OpKernelConstruction* context) : OpKernel(context) {}
|
||||
|
||||
void Compute(OpKernelContext* context) override {
|
||||
const Tensor& input = context->input(0);
|
||||
|
@ -33,7 +33,6 @@ class TopK : public OpKernel {
|
||||
explicit TopK(OpKernelConstruction* context) : OpKernel(context) {
|
||||
OP_REQUIRES_OK(context, context->GetAttr("sorted", &sorted_));
|
||||
if (num_inputs() < 2) { // k is an attr (TopK).
|
||||
OP_DEPRECATED(context, 7, "Use TopKV2 instead");
|
||||
OP_REQUIRES_OK(context, context->GetAttr("k", &k_));
|
||||
} else { // k is an input (TopKV2), so we won't know it until Compute.
|
||||
k_ = -1;
|
||||
|
@ -993,6 +993,7 @@ REGISTER_OP("TileGrad")
|
||||
.Input("multiples: int32")
|
||||
.Output("output: T")
|
||||
.Attr("T: type")
|
||||
.Deprecated(3, "TileGrad has been replaced with reduce_sum")
|
||||
.Doc(R"doc(
|
||||
Returns the gradient of `Tile`.
|
||||
|
||||
|
@ -153,6 +153,7 @@ REGISTER_OP("RandomCrop")
|
||||
.Attr("seed: int = 0")
|
||||
.Attr("seed2: int = 0")
|
||||
.SetIsStateful()
|
||||
.Deprecated(8, "Random crop is now pure Python")
|
||||
.Doc(R"doc(
|
||||
Randomly crop `image`.
|
||||
|
||||
@ -267,6 +268,7 @@ REGISTER_OP("AdjustContrast")
|
||||
.Input("max_value: float")
|
||||
.Output("output: float")
|
||||
.Attr("T: {uint8, int8, int16, int32, int64, float, double}")
|
||||
.Deprecated(2, "Use AdjustContrastv2 instead")
|
||||
.Doc(R"Doc(
|
||||
Deprecated. Disallowed in GraphDef version >= 2.
|
||||
)Doc");
|
||||
|
@ -85,6 +85,7 @@ REGISTER_OP("BatchNormWithGlobalNormalization")
|
||||
.Attr("T: numbertype")
|
||||
.Attr("variance_epsilon: float")
|
||||
.Attr("scale_after_normalization: bool")
|
||||
.Deprecated(9, "Use tf.nn.batch_normalization()")
|
||||
.Doc(R"doc(
|
||||
Batch normalization.
|
||||
|
||||
@ -121,6 +122,7 @@ REGISTER_OP("BatchNormWithGlobalNormalizationGrad")
|
||||
.Attr("T: numbertype")
|
||||
.Attr("variance_epsilon: float")
|
||||
.Attr("scale_after_normalization: bool")
|
||||
.Deprecated(9, "Use tf.nn.batch_normalization()")
|
||||
.Doc(R"doc(
|
||||
Gradients for batch normalization.
|
||||
|
||||
@ -815,6 +817,7 @@ REGISTER_OP("TopK")
|
||||
.Attr("k: int >= 0")
|
||||
.Attr("sorted: bool = true")
|
||||
.Attr("T: realnumbertype")
|
||||
.Deprecated(7, "Use TopKV2 instead")
|
||||
.Doc(R"doc(
|
||||
Finds values and indices of the `k` largest elements for the last dimension.
|
||||
|
||||
|
@ -1353,5 +1353,34 @@ class ColocationGroupTest(test_util.TensorFlowTestCase):
|
||||
self.assertEqual("/device:CPU:0", b.device)
|
||||
|
||||
|
||||
class DeprecatedTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def testSuccess(self):
|
||||
with ops.Graph().as_default() as g:
|
||||
g.graph_def_versions.producer = 7
|
||||
old = test_ops.old()
|
||||
with self.test_session(graph=g):
|
||||
old.run()
|
||||
|
||||
def _error(self):
|
||||
return ((r"Op Old is not available in GraphDef version %d\. "
|
||||
r"It has been removed in version 8\. For reasons\.") %
|
||||
versions.GRAPH_DEF_VERSION)
|
||||
|
||||
def testGraphConstructionFail(self):
|
||||
with ops.Graph().as_default():
|
||||
with self.assertRaisesRegexp(NotImplementedError, self._error()):
|
||||
test_ops.old()
|
||||
|
||||
def testGraphExecutionFail(self):
|
||||
with ops.Graph().as_default() as g:
|
||||
g.graph_def_versions.producer = 7
|
||||
old = test_ops.old()
|
||||
g.graph_def_versions.producer = versions.GRAPH_DEF_VERSION
|
||||
with self.test_session(graph=g):
|
||||
with self.assertRaisesOpError(self._error()):
|
||||
old.run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
googletest.main()
|
||||
|
@ -661,7 +661,7 @@ from tensorflow.python.ops import op_def_library
|
||||
|
||||
auto added = out->Add();
|
||||
*added = op_def;
|
||||
RemoveDescriptionsFromOpDef(added);
|
||||
RemoveNonDeprecationDescriptionsFromOpDef(added);
|
||||
}
|
||||
|
||||
strings::Appendf(&result, R"(def _InitOpDefLibrary():
|
||||
|
@ -23,6 +23,8 @@ REGISTER_OP("KernelLabel").Output("result: string");
|
||||
|
||||
REGISTER_OP("GraphDefVersion").Output("version: int32").SetIsStateful();
|
||||
|
||||
REGISTER_OP("Old").Deprecated(8, "For reasons");
|
||||
|
||||
namespace {
|
||||
enum KernelLabel { DEFAULT_LABEL, OVERLOAD_1_LABEL, OVERLOAD_2_LABEL };
|
||||
} // namespace
|
||||
@ -79,4 +81,13 @@ class GraphDefVersionOp : public OpKernel {
|
||||
REGISTER_KERNEL_BUILDER(Name("GraphDefVersion").Device(DEVICE_CPU),
|
||||
GraphDefVersionOp);
|
||||
|
||||
class OldOp : public OpKernel {
|
||||
public:
|
||||
OldOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("Old").Device(DEVICE_CPU), OldOp);
|
||||
|
||||
} // end namespace tensorflow
|
||||
|
@ -340,6 +340,17 @@ class OpDefLibrary(object):
|
||||
if name is None:
|
||||
name = op_type_name
|
||||
|
||||
# Check for deprecation
|
||||
deprecation_version = op_def.deprecation.version
|
||||
if deprecation_version:
|
||||
producer = g.graph_def_versions.producer
|
||||
if producer >= deprecation_version:
|
||||
raise NotImplementedError(
|
||||
("Op %s is not available in GraphDef version %d. "
|
||||
"It has been removed in version %d. %s.") %
|
||||
(op_type_name, producer, deprecation_version,
|
||||
op_def.deprecation.explanation))
|
||||
|
||||
# Requires that op_def has passed validation (using the C++
|
||||
# ValidateOpDef() from ../framework/op_def_util.h).
|
||||
attrs = {}
|
||||
|
Loading…
Reference in New Issue
Block a user