Add a .Deprecated method to REGISTER_OP
This replaces the OP_DEPRECATED macro with something declarative, which in particular lets us throw exceptions at graph construction time based on deprecation. I've left the OP_DEPRECATED macro around in case uses elsewhere can't be expressed in a purely declarative manner. Change: 120386133
parent c3af083a64
commit 30334d28c0
Changed directories: tensorflow/core/framework, tensorflow/core/kernels, tensorflow/core/ops, tensorflow/python
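For orientation, here is a minimal usage sketch of the new declarative form (the op name, signature, and doc text are invented for illustration; only the .Deprecated(version, explanation) call and its behavior come from this change):

#include "tensorflow/core/framework/op.h"

namespace tensorflow {

// Hypothetical op deprecated in favor of an equally hypothetical "MyNewOp".
// From GraphDef version 12 on, building a graph with this op is disallowed.
REGISTER_OP("MyOldOp")
    .Input("x: float")
    .Output("y: float")
    .Deprecated(12, "Use MyNewOp instead")
    .Doc(R"doc(
Deprecated. Disallowed in GraphDef version >= 12.
)doc");

}  // namespace tensorflow

Once the deprecation is recorded in the OpDef, the Python front end below raises NotImplementedError when the graph's producer version is at or past the deprecation version, and the kernel constructor still routes through OP_DEPRECATED as a runtime backstop.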
@@ -209,6 +209,10 @@ class OpDefBuilderWrapper<true> {
     builder_.SetAllowsUninitializedInput();
     return *this;
   }
+  OpDefBuilderWrapper<true>& Deprecated(int version, StringPiece explanation) {
+    builder_.Deprecated(version, explanation);
+    return *this;
+  }
   OpDefBuilderWrapper<true>& Doc(StringPiece text) {
     builder_.Doc(text);
     return *this;
@@ -231,6 +235,7 @@ class OpDefBuilderWrapper<false> {
   OpDefBuilderWrapper<false>& SetIsAggregate() { return *this; }
   OpDefBuilderWrapper<false>& SetIsStateful() { return *this; }
   OpDefBuilderWrapper<false>& SetAllowsUninitializedInput() { return *this; }
+  OpDefBuilderWrapper<false>& Deprecated(int, StringPiece) { return *this; }
   OpDefBuilderWrapper<false>& Doc(StringPiece text) { return *this; }
 };
 
@@ -94,6 +94,9 @@ message OpDef {
   }
   repeated AttrDef attr = 4;
 
+  // Optional deprecation based on GraphDef versions.
+  OpDeprecation deprecation = 8;
+
   // One-line human-readable description of what the Op does.
   string summary = 5;
 
@@ -139,6 +142,15 @@ message OpDef {
   bool allows_uninitialized_input = 19;  // for Assign, etc.
 };
 
+// Information about version-dependent deprecation of an op
+message OpDeprecation {
+  // First GraphDef version at which the op is disallowed.
+  int32 version = 1;
+
+  // Explanation of why it was deprecated and what to use instead.
+  string explanation = 2;
+};
+
 // A collection of OpDefs
 message OpList {
   repeated OpDef op = 1;
@@ -541,6 +541,18 @@ OpDefBuilder& OpDefBuilder::SetAllowsUninitializedInput() {
   return *this;
 }
 
+OpDefBuilder& OpDefBuilder::Deprecated(int version, StringPiece explanation) {
+  if (op_def_.has_deprecation()) {
+    errors_.push_back(
+        strings::StrCat("Deprecated called twice for Op ", op_def_.name()));
+  } else {
+    OpDeprecation* deprecation = op_def_.mutable_deprecation();
+    deprecation->set_version(version);
+    deprecation->set_explanation(explanation.ToString());
+  }
+  return *this;
+}
+
 Status OpDefBuilder::Finalize(OpDef* op_def) const {
   std::vector<string> errors = errors_;
   *op_def = op_def_;
@@ -89,6 +89,9 @@ class OpDefBuilder {
   OpDefBuilder& SetIsStateful();
   OpDefBuilder& SetAllowsUninitializedInput();
 
+  // Deprecate the op at a certain GraphDef version.
+  OpDefBuilder& Deprecated(int version, StringPiece explanation);
+
   // Adds docs to this OpDefBuilder (and returns *this).
   // Docs have the format:
   //   <1-line summary>
@@ -561,7 +561,7 @@ Status OpDefAddedDefaultsUnchanged(const OpDef& old_op,
   return Status::OK();
 }
 
-void RemoveDescriptionsFromOpDef(OpDef* op_def) {
+void RemoveNonDeprecationDescriptionsFromOpDef(OpDef* op_def) {
   for (int i = 0; i < op_def->input_arg_size(); ++i) {
     op_def->mutable_input_arg(i)->clear_description();
   }
@@ -575,6 +575,13 @@ void RemoveDescriptionsFromOpDef(OpDef* op_def) {
   op_def->clear_description();
 }
 
+void RemoveDescriptionsFromOpDef(OpDef* op_def) {
+  RemoveNonDeprecationDescriptionsFromOpDef(op_def);
+  if (op_def->has_deprecation()) {
+    op_def->mutable_deprecation()->clear_explanation();
+  }
+}
+
 void RemoveDescriptionsFromOpList(OpList* op_list) {
   for (int i = 0; i < op_list->op_size(); ++i) {
     OpDef* op_def = op_list->mutable_op(i);
@@ -58,6 +58,9 @@ Status OpDefAddedDefaultsUnchanged(const OpDef& old_op,
 void RemoveDescriptionsFromOpDef(OpDef* op_def);
 void RemoveDescriptionsFromOpList(OpList* op_list);
 
+// Remove docs from *op_def but leave explanations of deprecations.
+void RemoveNonDeprecationDescriptionsFromOpDef(OpDef* op_def);
+
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_FRAMEWORK_OP_DEF_UTIL_H_
@@ -89,6 +89,10 @@ OpKernel::OpKernel(OpKernelConstruction* context)
   OP_REQUIRES_OK(context,
                  NameRangesForNode(def_, context->op_def(), &input_name_map_,
                                    &output_name_map_));
+  if (context->op_def().has_deprecation()) {
+    const OpDeprecation& deprecation = context->op_def().deprecation();
+    OP_DEPRECATED(context, deprecation.version(), deprecation.explanation());
+  }
 }
 
 OpKernel::~OpKernel() {}
@@ -1252,6 +1252,11 @@ inline void OpOutputList::set_ref(int i, mutex* mu, Tensor* tensor_for_ref) {
 // }
 
 // Declares an op deprecated, and illegal starting at GraphDef version VERSION
+// Cleverly, OP_DEPRECATED is itself deprecated for most users; instead, use
+//   REGISTER_OP(...)
+//       ...
+//       .Deprecated(version, note)
+//       ...
 #define OP_DEPRECATED(CTX, VERSION, NOTE)                                   \
   if ((CTX)->graph_def_version() >= (VERSION)) {                            \
     ::tensorflow::Status _s(::tensorflow::errors::Unimplemented(            \
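The commit message keeps OP_DEPRECATED around for uses that a single declarative call cannot express; the conditional deprecation removed from TopK below (deprecated only when k comes from an attr rather than an input) is the kind of case in mind. A rough sketch of such a use, with a hypothetical op and attr name:

#include "tensorflow/core/framework/op_kernel.h"

namespace tensorflow {

// Hypothetical kernel that is only deprecated when a legacy attr is set, so a
// blanket .Deprecated() on the op registration would be too coarse.
class MaybeOldOp : public OpKernel {
 public:
  explicit MaybeOldOp(OpKernelConstruction* context) : OpKernel(context) {
    bool use_legacy_path;
    OP_REQUIRES_OK(context,
                   context->GetAttr("use_legacy_path", &use_legacy_path));
    if (use_legacy_path) {
      // Fails kernel construction once the graph's version reaches 9.
      OP_DEPRECATED(context, 9, "Set use_legacy_path=false instead");
    }
  }

  void Compute(OpKernelContext* context) override {}
};

}  // namespace tensorflow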
@@ -38,7 +38,6 @@ template <typename Device, typename T>
 class AdjustContrastOp : public OpKernel {
  public:
   explicit AdjustContrastOp(OpKernelConstruction* context) : OpKernel(context) {
-    OP_DEPRECATED(context, 2, "Use AdjustContrastv2 instead");
   }
 
   void Compute(OpKernelContext* context) override {
@@ -33,7 +33,6 @@ template <typename Device, typename T>
 class BatchNormOp : public OpKernel {
  public:
   explicit BatchNormOp(OpKernelConstruction* context) : OpKernel(context) {
-    OP_DEPRECATED(context, 9, "Use tf.nn.batch_normalization()");
     OP_REQUIRES_OK(context,
                    context->GetAttr("variance_epsilon", &variance_epsilon_));
     OP_REQUIRES_OK(context, context->GetAttr("scale_after_normalization",
@@ -82,7 +81,6 @@ template <typename Device, typename T>
 class BatchNormGradOp : public OpKernel {
  public:
   explicit BatchNormGradOp(OpKernelConstruction* context) : OpKernel(context) {
-    OP_DEPRECATED(context, 9, "Use tf.nn.batch_normalization()");
     OP_REQUIRES_OK(context,
                    context->GetAttr("variance_epsilon", &variance_epsilon_));
     OP_REQUIRES_OK(context, context->GetAttr("scale_after_normalization",
@@ -28,7 +28,6 @@ template <typename T>
 class RandomCropOp : public OpKernel {
  public:
   explicit RandomCropOp(OpKernelConstruction* context) : OpKernel(context) {
-    OP_DEPRECATED(context, 8, "Random crop is now pure Python");
     OP_REQUIRES_OK(context, generator_.Init(context));
   }
 
@@ -186,9 +186,7 @@ HANDLE_CASE_DIM(GPUDevice, DT_INT64);
 template <typename Device>
 class TileGradientOp : public OpKernel {
  public:
-  explicit TileGradientOp(OpKernelConstruction* context) : OpKernel(context) {
-    OP_DEPRECATED(context, 3, "TileGrad has been replaced with reduce_sum");
-  }
+  explicit TileGradientOp(OpKernelConstruction* context) : OpKernel(context) {}
 
   void Compute(OpKernelContext* context) override {
     const Tensor& input = context->input(0);
@@ -33,7 +33,6 @@ class TopK : public OpKernel {
   explicit TopK(OpKernelConstruction* context) : OpKernel(context) {
     OP_REQUIRES_OK(context, context->GetAttr("sorted", &sorted_));
     if (num_inputs() < 2) {  // k is an attr (TopK).
-      OP_DEPRECATED(context, 7, "Use TopKV2 instead");
       OP_REQUIRES_OK(context, context->GetAttr("k", &k_));
     } else {  // k is an input (TopKV2), so we won't know it until Compute.
       k_ = -1;
@@ -993,6 +993,7 @@ REGISTER_OP("TileGrad")
     .Input("multiples: int32")
     .Output("output: T")
     .Attr("T: type")
+    .Deprecated(3, "TileGrad has been replaced with reduce_sum")
     .Doc(R"doc(
 Returns the gradient of `Tile`.
 
@@ -153,6 +153,7 @@ REGISTER_OP("RandomCrop")
     .Attr("seed: int = 0")
     .Attr("seed2: int = 0")
     .SetIsStateful()
+    .Deprecated(8, "Random crop is now pure Python")
     .Doc(R"doc(
 Randomly crop `image`.
 
@@ -267,6 +268,7 @@ REGISTER_OP("AdjustContrast")
     .Input("max_value: float")
     .Output("output: float")
     .Attr("T: {uint8, int8, int16, int32, int64, float, double}")
+    .Deprecated(2, "Use AdjustContrastv2 instead")
     .Doc(R"Doc(
 Deprecated. Disallowed in GraphDef version >= 2.
 )Doc");
@@ -85,6 +85,7 @@ REGISTER_OP("BatchNormWithGlobalNormalization")
     .Attr("T: numbertype")
     .Attr("variance_epsilon: float")
     .Attr("scale_after_normalization: bool")
+    .Deprecated(9, "Use tf.nn.batch_normalization()")
     .Doc(R"doc(
 Batch normalization.
 
@@ -121,6 +122,7 @@ REGISTER_OP("BatchNormWithGlobalNormalizationGrad")
     .Attr("T: numbertype")
     .Attr("variance_epsilon: float")
     .Attr("scale_after_normalization: bool")
+    .Deprecated(9, "Use tf.nn.batch_normalization()")
     .Doc(R"doc(
 Gradients for batch normalization.
 
@@ -815,6 +817,7 @@ REGISTER_OP("TopK")
     .Attr("k: int >= 0")
     .Attr("sorted: bool = true")
     .Attr("T: realnumbertype")
+    .Deprecated(7, "Use TopKV2 instead")
     .Doc(R"doc(
 Finds values and indices of the `k` largest elements for the last dimension.
 
@@ -1353,5 +1353,34 @@ class ColocationGroupTest(test_util.TensorFlowTestCase):
     self.assertEqual("/device:CPU:0", b.device)
 
 
+class DeprecatedTest(test_util.TensorFlowTestCase):
+
+  def testSuccess(self):
+    with ops.Graph().as_default() as g:
+      g.graph_def_versions.producer = 7
+      old = test_ops.old()
+      with self.test_session(graph=g):
+        old.run()
+
+  def _error(self):
+    return ((r"Op Old is not available in GraphDef version %d\. "
+             r"It has been removed in version 8\. For reasons\.") %
+            versions.GRAPH_DEF_VERSION)
+
+  def testGraphConstructionFail(self):
+    with ops.Graph().as_default():
+      with self.assertRaisesRegexp(NotImplementedError, self._error()):
+        test_ops.old()
+
+  def testGraphExecutionFail(self):
+    with ops.Graph().as_default() as g:
+      g.graph_def_versions.producer = 7
+      old = test_ops.old()
+      g.graph_def_versions.producer = versions.GRAPH_DEF_VERSION
+      with self.test_session(graph=g):
+        with self.assertRaisesOpError(self._error()):
+          old.run()
+
+
 if __name__ == "__main__":
   googletest.main()
@@ -661,7 +661,7 @@ from tensorflow.python.ops import op_def_library
 
     auto added = out->Add();
     *added = op_def;
-    RemoveDescriptionsFromOpDef(added);
+    RemoveNonDeprecationDescriptionsFromOpDef(added);
   }
 
   strings::Appendf(&result, R"(def _InitOpDefLibrary():
@@ -23,6 +23,8 @@ REGISTER_OP("KernelLabel").Output("result: string");
 
 REGISTER_OP("GraphDefVersion").Output("version: int32").SetIsStateful();
 
+REGISTER_OP("Old").Deprecated(8, "For reasons");
+
 namespace {
 enum KernelLabel { DEFAULT_LABEL, OVERLOAD_1_LABEL, OVERLOAD_2_LABEL };
 }  // namespace
@@ -79,4 +81,13 @@ class GraphDefVersionOp : public OpKernel {
 REGISTER_KERNEL_BUILDER(Name("GraphDefVersion").Device(DEVICE_CPU),
                         GraphDefVersionOp);
 
+class OldOp : public OpKernel {
+ public:
+  OldOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+  void Compute(OpKernelContext* ctx) override {}
+};
+
+REGISTER_KERNEL_BUILDER(Name("Old").Device(DEVICE_CPU), OldOp);
+
 }  // end namespace tensorflow
@@ -340,6 +340,17 @@ class OpDefLibrary(object):
     if name is None:
       name = op_type_name
 
+    # Check for deprecation
+    deprecation_version = op_def.deprecation.version
+    if deprecation_version:
+      producer = g.graph_def_versions.producer
+      if producer >= deprecation_version:
+        raise NotImplementedError(
+            ("Op %s is not available in GraphDef version %d. "
+             "It has been removed in version %d. %s.") %
+            (op_type_name, producer, deprecation_version,
+             op_def.deprecation.explanation))
+
     # Requires that op_def has passed validation (using the C++
     # ValidateOpDef() from ../framework/op_def_util.h).
     attrs = {}