Make multiple changes to enable running an inference model on TFRT:

1. For fallback op, serialize all attributes to a string and under "_node_def_".
2. Allow reseting OpAttrs.
3. Chain all tfrt ops.

PiperOrigin-RevId: 307928721
Change-Id: I5d62995ffe94185c642d191280a62891e0e055d5
This commit is contained in:
Xiao Yu 2020-04-22 16:35:39 -07:00 committed by TensorFlower Gardener
parent a217053031
commit b806942316
3 changed files with 14 additions and 0 deletions

View File

@ -111,6 +111,12 @@ class AttrBuilder {
return *this; return *this;
} }
AttrBuilder& Set(StringPiece attr_name, const AttrValue& value) {
AddAttrIfNotPresent(attr_name, value);
cached_cache_key_ = absl::nullopt;
return *this;
}
// Retrieves the attribute value. // Retrieves the attribute value.
// Note that Get() can involve a linear scan of all attributes with the same // Note that Get() can involve a linear scan of all attributes with the same
// value type in this Node. This is not an issue, because Get is used rarely // value type in this Node. This is not an issue, because Get is used rarely

View File

@ -36,6 +36,12 @@ void EagerOperation::Clear() {
ClearInferenceState(); ClearInferenceState();
} }
Status EagerOperation::SetAttrValue(const char* attr_name,
const AttrValue& value) {
MutableAttrs()->Set(attr_name, value);
return Status::OK();
}
Status EagerOperation::SetAttrString(const char* attr_name, const char* data, Status EagerOperation::SetAttrString(const char* attr_name, const char* data,
size_t length) { size_t length) {
MutableAttrs()->Set(attr_name, StringPiece(data, length)); MutableAttrs()->Set(attr_name, StringPiece(data, length));

View File

@ -74,6 +74,8 @@ class EagerOperation : public AbstractOperationInterface {
last_set_device_name_ = "\177"; // DEL (an invalid value) last_set_device_name_ = "\177"; // DEL (an invalid value)
} }
Status SetAttrValue(const char* attr_name, const AttrValue& value);
Status AddInput(AbstractTensorHandleInterface* input) override; Status AddInput(AbstractTensorHandleInterface* input) override;
Status AddInputList( Status AddInputList(
absl::Span<AbstractTensorHandleInterface*> inputs) override; absl::Span<AbstractTensorHandleInterface*> inputs) override;