Fix XLA Status generation (using the 2-parameter construct will create a error status even if error::OK is passed in)
PiperOrigin-RevId: 289976821 Change-Id: I0ef719d7373969db3334a01a018db3fd1ce0a1a9
This commit is contained in:
parent
52b8ba5463
commit
a5218435ec
@ -27,6 +27,15 @@
|
||||
namespace tpu_driver {
|
||||
namespace {
|
||||
|
||||
xla::Status CreateXlaStatus(::TpuStatus* status) {
|
||||
if (status->code == tensorflow::error::OK) {
|
||||
return xla::Status::OK();
|
||||
} else {
|
||||
return xla::Status(tensorflow::error::Code(status->code),
|
||||
absl::StrFormat("%s", status->msg));
|
||||
}
|
||||
}
|
||||
|
||||
constexpr char kDirectProtocol[] = "direct://";
|
||||
|
||||
::TpuAllocationShape GetTpuAllocationShape(const xla::ShapeProto& shape) {
|
||||
@ -53,8 +62,7 @@ class DirectEvent : public Event {
|
||||
|
||||
xla::Status Await() override {
|
||||
auto tpu_status = driver_fn_->TpuDriver_EventAwait(event_, -1);
|
||||
auto ret = xla::Status(tensorflow::error::Code(tpu_status->code),
|
||||
absl::StrFormat("%s", tpu_status->msg));
|
||||
auto ret = CreateXlaStatus(tpu_status);
|
||||
driver_fn_->TpuDriver_FreeStatus(tpu_status);
|
||||
return ret;
|
||||
}
|
||||
@ -66,8 +74,7 @@ class DirectEvent : public Event {
|
||||
if (tpu_status_or == nullptr) {
|
||||
return absl::nullopt;
|
||||
} else {
|
||||
auto ret = xla::Status(tensorflow::error::Code(tpu_status_or->code),
|
||||
absl::StrFormat("%s", tpu_status_or->msg));
|
||||
auto ret = CreateXlaStatus(tpu_status_or);
|
||||
driver_fn_->TpuDriver_FreeStatus(tpu_status_or);
|
||||
return ret;
|
||||
}
|
||||
@ -85,8 +92,7 @@ class DirectEvent : public Event {
|
||||
[](struct TpuStatus* status, void* additional_info) {
|
||||
auto callback_addr =
|
||||
static_cast<std::function<void(xla::Status)>*>(additional_info);
|
||||
auto xla_status = xla::Status(tensorflow::error::Code(status->code),
|
||||
absl::StrFormat("%s", status->msg));
|
||||
auto xla_status = CreateXlaStatus(status);
|
||||
(*callback_addr)(xla_status);
|
||||
delete callback_addr;
|
||||
},
|
||||
@ -142,10 +148,8 @@ class DirectCompiledProgramHandle : public CompiledProgramHandle {
|
||||
driver_fn_->TpuDriver_GetCompiledProgramShape(handle_);
|
||||
program_shape->ParseFromArray(shape->bytes, shape->size);
|
||||
|
||||
auto status = xla::Status(tensorflow::error::Code(shape->status->code),
|
||||
absl::StrFormat("%s", shape->status->msg));
|
||||
auto status = CreateXlaStatus(shape->status);
|
||||
driver_fn_->TpuDriver_FreeCompiledProgramShape(shape);
|
||||
|
||||
return status;
|
||||
}
|
||||
|
||||
@ -196,8 +200,7 @@ class DirectTpuLinearizer : public TpuLinearizer {
|
||||
|
||||
auto tpu_status =
|
||||
driver_fn_->TpuDriver_LinearizeShape(driver_, dst, src, shape_);
|
||||
auto status = xla::Status(tensorflow::error::Code(tpu_status->code),
|
||||
absl::StrFormat("%s", tpu_status->msg));
|
||||
auto status = CreateXlaStatus(tpu_status);
|
||||
driver_fn_->TpuDriver_FreeStatus(tpu_status);
|
||||
free(shape_.bytes);
|
||||
return status;
|
||||
@ -209,8 +212,7 @@ class DirectTpuLinearizer : public TpuLinearizer {
|
||||
|
||||
auto tpu_status =
|
||||
driver_fn_->TpuDriver_DelinearizeShape(driver_, dst, src, shape_);
|
||||
auto status = xla::Status(tensorflow::error::Code(tpu_status->code),
|
||||
absl::StrFormat("%s", tpu_status->msg));
|
||||
auto status = CreateXlaStatus(tpu_status);
|
||||
driver_fn_->TpuDriver_FreeStatus(tpu_status);
|
||||
free(shape_.bytes);
|
||||
return status;
|
||||
|
Loading…
Reference in New Issue
Block a user