diff --git a/tensorflow/core/common_runtime/eager/attr_builder.h b/tensorflow/core/common_runtime/eager/attr_builder.h index 1a871b01a4d..c53286dd999 100644 --- a/tensorflow/core/common_runtime/eager/attr_builder.h +++ b/tensorflow/core/common_runtime/eager/attr_builder.h @@ -111,6 +111,12 @@ class AttrBuilder { return *this; } + AttrBuilder& Set(StringPiece attr_name, const AttrValue& value) { + AddAttrIfNotPresent(attr_name, value); + cached_cache_key_ = absl::nullopt; + return *this; + } + // Retrieves the attribute value. // Note that Get() can involve a linear scan of all attributes with the same // value type in this Node. This is not an issue, because Get is used rarely diff --git a/tensorflow/core/common_runtime/eager/eager_operation.cc b/tensorflow/core/common_runtime/eager/eager_operation.cc index 44d2fe4f744..090bfef46bd 100644 --- a/tensorflow/core/common_runtime/eager/eager_operation.cc +++ b/tensorflow/core/common_runtime/eager/eager_operation.cc @@ -36,6 +36,12 @@ void EagerOperation::Clear() { ClearInferenceState(); } +Status EagerOperation::SetAttrValue(const char* attr_name, + const AttrValue& value) { + MutableAttrs()->Set(attr_name, value); + return Status::OK(); +} + Status EagerOperation::SetAttrString(const char* attr_name, const char* data, size_t length) { MutableAttrs()->Set(attr_name, StringPiece(data, length)); diff --git a/tensorflow/core/common_runtime/eager/eager_operation.h b/tensorflow/core/common_runtime/eager/eager_operation.h index b92a144a796..14268ef2630 100644 --- a/tensorflow/core/common_runtime/eager/eager_operation.h +++ b/tensorflow/core/common_runtime/eager/eager_operation.h @@ -74,6 +74,8 @@ class EagerOperation : public AbstractOperationInterface { last_set_device_name_ = "\177"; // DEL (an invalid value) } + Status SetAttrValue(const char* attr_name, const AttrValue& value); + Status AddInput(AbstractTensorHandleInterface* input) override; Status AddInputList( absl::Span inputs) override;