Update OpDefBuilder with an option to allow attr type to be 'any'.
PiperOrigin-RevId: 336742667 Change-Id: Ie88dfa0960315225a33972e5d4227c44f35f25e9
This commit is contained in:
parent
b082980a84
commit
ade220bef9
@ -145,7 +145,7 @@ bool ProcessCompoundType(const StringPiece type_string, AttrValue* allowed) {
|
||||
return true;
|
||||
}
|
||||
|
||||
void FinalizeAttr(StringPiece spec, OpDef* op_def,
|
||||
void FinalizeAttr(StringPiece spec, bool allow_attr_type_any, OpDef* op_def,
|
||||
std::vector<string>* errors) {
|
||||
OpDef::AttrDef* attr = op_def->add_attr();
|
||||
StringPiece orig(spec);
|
||||
@ -175,6 +175,8 @@ void FinalizeAttr(StringPiece spec, OpDef* op_def,
|
||||
type = "tensor";
|
||||
} else if (absl::ConsumePrefix(&spec, "func")) {
|
||||
type = "func";
|
||||
} else if (absl::ConsumePrefix(&spec, "any") && allow_attr_type_any) {
|
||||
type = "any";
|
||||
} else if (ConsumeCompoundAttrType(&spec, &type_string)) {
|
||||
type = "type";
|
||||
AttrValue* allowed = attr->mutable_allowed_values();
|
||||
@ -633,13 +635,18 @@ OpDefBuilder& OpDefBuilder::SetShapeFn(OpShapeInferenceFn fn) {
|
||||
return *this;
|
||||
}
|
||||
|
||||
OpDefBuilder& OpDefBuilder::AllowAttrTypeAny() {
|
||||
allow_attr_type_any_ = true;
|
||||
return *this;
|
||||
}
|
||||
|
||||
Status OpDefBuilder::Finalize(OpRegistrationData* op_reg_data) const {
|
||||
std::vector<string> errors = errors_;
|
||||
*op_reg_data = op_reg_data_;
|
||||
|
||||
OpDef* op_def = &op_reg_data->op_def;
|
||||
for (StringPiece attr : attrs_) {
|
||||
FinalizeAttr(attr, op_def, &errors);
|
||||
FinalizeAttr(attr, allow_attr_type_any_, op_def, &errors);
|
||||
}
|
||||
for (StringPiece input : inputs_) {
|
||||
FinalizeInputOrOutput(input, false, op_def, &errors);
|
||||
|
@ -142,6 +142,10 @@ class OpDefBuilder {
|
||||
// python/framework/common_shapes.py
|
||||
OpDefBuilder& SetShapeFn(OpShapeInferenceFn fn);
|
||||
|
||||
// Allows the `<type>` in calls to `Attr()` to be "any".
|
||||
// This is used by PythonAPIWrapper for pass-through parameters.
|
||||
OpDefBuilder& AllowAttrTypeAny();
|
||||
|
||||
// Sets op_reg_data->op_def to the requested OpDef and
|
||||
// op_reg_data->shape_inference_fn to the requested shape inference function,
|
||||
// or returns an error.
|
||||
@ -168,6 +172,7 @@ class OpDefBuilder {
|
||||
std::vector<string> control_outputs_;
|
||||
std::string doc_;
|
||||
std::vector<string> errors_;
|
||||
bool allow_attr_type_any_ = false;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
Loading…
Reference in New Issue
Block a user