Update OpDefBuilder with an option to allow attr type to be 'any'.

PiperOrigin-RevId: 336742667
Change-Id: Ie88dfa0960315225a33972e5d4227c44f35f25e9
This commit is contained in:
Edward Loper 2020-10-12 14:22:34 -07:00 committed by TensorFlower Gardener
parent b082980a84
commit ade220bef9
2 changed files with 14 additions and 2 deletions

View File

@ -145,7 +145,7 @@ bool ProcessCompoundType(const StringPiece type_string, AttrValue* allowed) {
return true;
}
void FinalizeAttr(StringPiece spec, OpDef* op_def,
void FinalizeAttr(StringPiece spec, bool allow_attr_type_any, OpDef* op_def,
std::vector<string>* errors) {
OpDef::AttrDef* attr = op_def->add_attr();
StringPiece orig(spec);
@ -175,6 +175,8 @@ void FinalizeAttr(StringPiece spec, OpDef* op_def,
type = "tensor";
} else if (absl::ConsumePrefix(&spec, "func")) {
type = "func";
} else if (absl::ConsumePrefix(&spec, "any") && allow_attr_type_any) {
type = "any";
} else if (ConsumeCompoundAttrType(&spec, &type_string)) {
type = "type";
AttrValue* allowed = attr->mutable_allowed_values();
@ -633,13 +635,18 @@ OpDefBuilder& OpDefBuilder::SetShapeFn(OpShapeInferenceFn fn) {
return *this;
}
OpDefBuilder& OpDefBuilder::AllowAttrTypeAny() {
allow_attr_type_any_ = true;
return *this;
}
Status OpDefBuilder::Finalize(OpRegistrationData* op_reg_data) const {
std::vector<string> errors = errors_;
*op_reg_data = op_reg_data_;
OpDef* op_def = &op_reg_data->op_def;
for (StringPiece attr : attrs_) {
FinalizeAttr(attr, op_def, &errors);
FinalizeAttr(attr, allow_attr_type_any_, op_def, &errors);
}
for (StringPiece input : inputs_) {
FinalizeInputOrOutput(input, false, op_def, &errors);

View File

@ -142,6 +142,10 @@ class OpDefBuilder {
// python/framework/common_shapes.py
OpDefBuilder& SetShapeFn(OpShapeInferenceFn fn);
// Allows the `<type>` in calls to `Attr()` to be "any".
// This is used by PythonAPIWrapper for pass-through parameters.
OpDefBuilder& AllowAttrTypeAny();
// Sets op_reg_data->op_def to the requested OpDef and
// op_reg_data->shape_inference_fn to the requested shape inference function,
// or returns an error.
@ -168,6 +172,7 @@ class OpDefBuilder {
std::vector<string> control_outputs_;
std::string doc_;
std::vector<string> errors_;
bool allow_attr_type_any_ = false;
};
} // namespace tensorflow