Add Status payload API

PiperOrigin-RevId: 344277932
Change-Id: I7aef1022c119dbaced5ed986c23625775883e224
This commit is contained in:
A. Unique TensorFlower 2020-11-25 10:34:59 -08:00 committed by TensorFlower Gardener
parent aac6f14020
commit 9270fdcada
3 changed files with 70 additions and 0 deletions

View File

@ -175,6 +175,33 @@ TEST(StatusGroup, AggregateWithMultipleErrorStatus) {
aborted.error_message()));
}
TEST(Status, InvalidPayloadGetsIgnored) {
Status s = Status();
s.SetPayload("Invalid", "Invalid Val");
ASSERT_EQ(s.GetPayload("Invalid"), tensorflow::StringPiece());
bool is_err_erased = s.ErasePayload("Invalid");
ASSERT_EQ(is_err_erased, false);
}
TEST(Status, SetPayloadSetsOrUpdatesIt) {
Status s(error::INTERNAL, "Error message");
s.SetPayload("Error key", "Original");
ASSERT_EQ(s.GetPayload("Error key"), tensorflow::StringPiece("Original"));
s.SetPayload("Error key", "Updated");
ASSERT_EQ(s.GetPayload("Error key"), tensorflow::StringPiece("Updated"));
}
TEST(Status, ErasePayloadRemovesIt) {
Status s(error::INTERNAL, "Error message");
s.SetPayload("Error key", "Original");
bool is_err_erased = s.ErasePayload("Error key");
ASSERT_EQ(is_err_erased, true);
is_err_erased = s.ErasePayload("Error key");
ASSERT_EQ(is_err_erased, false);
ASSERT_EQ(s.GetPayload("Error key"), tensorflow::StringPiece());
}
static void BM_TF_CHECK_OK(int iters) {
tensorflow::Status s =
(iters < 0) ? errors::InvalidArgument("Invalid") : Status::OK();

View File

@ -200,6 +200,28 @@ void Status::IgnoreError() const {
// no-op
}
void Status::SetPayload(tensorflow::StringPiece type_url,
tensorflow::StringPiece payload) {
if (ok()) return;
state_->payloads[std::string(type_url)] = std::string(payload);
}
tensorflow::StringPiece Status::GetPayload(
tensorflow::StringPiece type_url) const {
if (ok()) return tensorflow::StringPiece();
auto payload_iter = state_->payloads.find(std::string(type_url));
if (payload_iter == state_->payloads.end()) return tensorflow::StringPiece();
return tensorflow::StringPiece(payload_iter->second);
}
bool Status::ErasePayload(tensorflow::StringPiece type_url) {
if (ok()) return false;
auto payload_iter = state_->payloads.find(std::string(type_url));
if (payload_iter == state_->payloads.end()) return false;
state_->payloads.erase(payload_iter);
return true;
}
std::ostream& operator<<(std::ostream& os, const Status& x) {
os << x.ToString();
return os;

View File

@ -20,6 +20,7 @@ limitations under the License.
#include <iosfwd>
#include <memory>
#include <string>
#include <unordered_map>
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
@ -115,6 +116,24 @@ class Status {
// the floor.
void IgnoreError() const;
// The Payload-related APIs are cloned from absl::Status.
//
// Returns the payload of a status given its unique `type_url` key, if
// present. Returns an empty StringPiece if the status is ok, or if the key is
// not present.
tensorflow::StringPiece GetPayload(tensorflow::StringPiece type_url) const;
// Sets the payload for a non-ok status using a `type_url` key, overwriting
// any existing payload for that `type_url`.
//
// NOTE: This function does nothing if the Status is ok.
void SetPayload(tensorflow::StringPiece type_url,
tensorflow::StringPiece payload);
// Erases the payload corresponding to the `type_url` key. Returns `true` if
// the payload was present.
bool ErasePayload(tensorflow::StringPiece type_url);
private:
static const std::string& empty_string();
static const std::vector<StackFrame>& empty_stack_trace();
@ -122,7 +141,9 @@ class Status {
tensorflow::error::Code code;
std::string msg;
std::vector<StackFrame> stack_trace;
std::unordered_map<std::string, std::string> payloads;
};
// OK status has a `NULL` state_. Otherwise, `state_` points to
// a `State` structure containing the error code and message(s)
std::unique_ptr<State> state_;