Add Status
payload API
PiperOrigin-RevId: 344277932 Change-Id: I7aef1022c119dbaced5ed986c23625775883e224
This commit is contained in:
parent
aac6f14020
commit
9270fdcada
@ -175,6 +175,33 @@ TEST(StatusGroup, AggregateWithMultipleErrorStatus) {
|
||||
aborted.error_message()));
|
||||
}
|
||||
|
||||
TEST(Status, InvalidPayloadGetsIgnored) {
|
||||
Status s = Status();
|
||||
s.SetPayload("Invalid", "Invalid Val");
|
||||
ASSERT_EQ(s.GetPayload("Invalid"), tensorflow::StringPiece());
|
||||
bool is_err_erased = s.ErasePayload("Invalid");
|
||||
ASSERT_EQ(is_err_erased, false);
|
||||
}
|
||||
|
||||
TEST(Status, SetPayloadSetsOrUpdatesIt) {
|
||||
Status s(error::INTERNAL, "Error message");
|
||||
s.SetPayload("Error key", "Original");
|
||||
ASSERT_EQ(s.GetPayload("Error key"), tensorflow::StringPiece("Original"));
|
||||
s.SetPayload("Error key", "Updated");
|
||||
ASSERT_EQ(s.GetPayload("Error key"), tensorflow::StringPiece("Updated"));
|
||||
}
|
||||
|
||||
TEST(Status, ErasePayloadRemovesIt) {
|
||||
Status s(error::INTERNAL, "Error message");
|
||||
s.SetPayload("Error key", "Original");
|
||||
|
||||
bool is_err_erased = s.ErasePayload("Error key");
|
||||
ASSERT_EQ(is_err_erased, true);
|
||||
is_err_erased = s.ErasePayload("Error key");
|
||||
ASSERT_EQ(is_err_erased, false);
|
||||
ASSERT_EQ(s.GetPayload("Error key"), tensorflow::StringPiece());
|
||||
}
|
||||
|
||||
static void BM_TF_CHECK_OK(int iters) {
|
||||
tensorflow::Status s =
|
||||
(iters < 0) ? errors::InvalidArgument("Invalid") : Status::OK();
|
||||
|
@ -200,6 +200,28 @@ void Status::IgnoreError() const {
|
||||
// no-op
|
||||
}
|
||||
|
||||
void Status::SetPayload(tensorflow::StringPiece type_url,
|
||||
tensorflow::StringPiece payload) {
|
||||
if (ok()) return;
|
||||
state_->payloads[std::string(type_url)] = std::string(payload);
|
||||
}
|
||||
|
||||
tensorflow::StringPiece Status::GetPayload(
|
||||
tensorflow::StringPiece type_url) const {
|
||||
if (ok()) return tensorflow::StringPiece();
|
||||
auto payload_iter = state_->payloads.find(std::string(type_url));
|
||||
if (payload_iter == state_->payloads.end()) return tensorflow::StringPiece();
|
||||
return tensorflow::StringPiece(payload_iter->second);
|
||||
}
|
||||
|
||||
bool Status::ErasePayload(tensorflow::StringPiece type_url) {
|
||||
if (ok()) return false;
|
||||
auto payload_iter = state_->payloads.find(std::string(type_url));
|
||||
if (payload_iter == state_->payloads.end()) return false;
|
||||
state_->payloads.erase(payload_iter);
|
||||
return true;
|
||||
}
|
||||
|
||||
std::ostream& operator<<(std::ostream& os, const Status& x) {
|
||||
os << x.ToString();
|
||||
return os;
|
||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
||||
#include <iosfwd>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
@ -115,6 +116,24 @@ class Status {
|
||||
// the floor.
|
||||
void IgnoreError() const;
|
||||
|
||||
// The Payload-related APIs are cloned from absl::Status.
|
||||
//
|
||||
// Returns the payload of a status given its unique `type_url` key, if
|
||||
// present. Returns an empty StringPiece if the status is ok, or if the key is
|
||||
// not present.
|
||||
tensorflow::StringPiece GetPayload(tensorflow::StringPiece type_url) const;
|
||||
|
||||
// Sets the payload for a non-ok status using a `type_url` key, overwriting
|
||||
// any existing payload for that `type_url`.
|
||||
//
|
||||
// NOTE: This function does nothing if the Status is ok.
|
||||
void SetPayload(tensorflow::StringPiece type_url,
|
||||
tensorflow::StringPiece payload);
|
||||
|
||||
// Erases the payload corresponding to the `type_url` key. Returns `true` if
|
||||
// the payload was present.
|
||||
bool ErasePayload(tensorflow::StringPiece type_url);
|
||||
|
||||
private:
|
||||
static const std::string& empty_string();
|
||||
static const std::vector<StackFrame>& empty_stack_trace();
|
||||
@ -122,7 +141,9 @@ class Status {
|
||||
tensorflow::error::Code code;
|
||||
std::string msg;
|
||||
std::vector<StackFrame> stack_trace;
|
||||
std::unordered_map<std::string, std::string> payloads;
|
||||
};
|
||||
|
||||
// OK status has a `NULL` state_. Otherwise, `state_` points to
|
||||
// a `State` structure containing the error code and message(s)
|
||||
std::unique_ptr<State> state_;
|
||||
|
Loading…
Reference in New Issue
Block a user