Make multiple changes to enable running an inference model on TFRT:
1. For fallback op, serialize all attributes to a string and under "_node_def_". 2. Allow reseting OpAttrs. 3. Chain all tfrt ops. PiperOrigin-RevId: 307928721 Change-Id: I5d62995ffe94185c642d191280a62891e0e055d5
This commit is contained in:
parent
a217053031
commit
b806942316
@ -111,6 +111,12 @@ class AttrBuilder {
|
|||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
AttrBuilder& Set(StringPiece attr_name, const AttrValue& value) {
|
||||||
|
AddAttrIfNotPresent(attr_name, value);
|
||||||
|
cached_cache_key_ = absl::nullopt;
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
// Retrieves the attribute value.
|
// Retrieves the attribute value.
|
||||||
// Note that Get() can involve a linear scan of all attributes with the same
|
// Note that Get() can involve a linear scan of all attributes with the same
|
||||||
// value type in this Node. This is not an issue, because Get is used rarely
|
// value type in this Node. This is not an issue, because Get is used rarely
|
||||||
|
@ -36,6 +36,12 @@ void EagerOperation::Clear() {
|
|||||||
ClearInferenceState();
|
ClearInferenceState();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Status EagerOperation::SetAttrValue(const char* attr_name,
|
||||||
|
const AttrValue& value) {
|
||||||
|
MutableAttrs()->Set(attr_name, value);
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
Status EagerOperation::SetAttrString(const char* attr_name, const char* data,
|
Status EagerOperation::SetAttrString(const char* attr_name, const char* data,
|
||||||
size_t length) {
|
size_t length) {
|
||||||
MutableAttrs()->Set(attr_name, StringPiece(data, length));
|
MutableAttrs()->Set(attr_name, StringPiece(data, length));
|
||||||
|
@ -74,6 +74,8 @@ class EagerOperation : public AbstractOperationInterface {
|
|||||||
last_set_device_name_ = "\177"; // DEL (an invalid value)
|
last_set_device_name_ = "\177"; // DEL (an invalid value)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Status SetAttrValue(const char* attr_name, const AttrValue& value);
|
||||||
|
|
||||||
Status AddInput(AbstractTensorHandleInterface* input) override;
|
Status AddInput(AbstractTensorHandleInterface* input) override;
|
||||||
Status AddInputList(
|
Status AddInputList(
|
||||||
absl::Span<AbstractTensorHandleInterface*> inputs) override;
|
absl::Span<AbstractTensorHandleInterface*> inputs) override;
|
||||||
|
Loading…
Reference in New Issue
Block a user