Fix XLA Status generation (using the 2-parameter construct will create a error status even if error::OK is passed in)
PiperOrigin-RevId: 289976821 Change-Id: I0ef719d7373969db3334a01a018db3fd1ce0a1a9
This commit is contained in:
parent
52b8ba5463
commit
a5218435ec
@ -27,6 +27,15 @@
|
|||||||
namespace tpu_driver {
|
namespace tpu_driver {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
xla::Status CreateXlaStatus(::TpuStatus* status) {
|
||||||
|
if (status->code == tensorflow::error::OK) {
|
||||||
|
return xla::Status::OK();
|
||||||
|
} else {
|
||||||
|
return xla::Status(tensorflow::error::Code(status->code),
|
||||||
|
absl::StrFormat("%s", status->msg));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
constexpr char kDirectProtocol[] = "direct://";
|
constexpr char kDirectProtocol[] = "direct://";
|
||||||
|
|
||||||
::TpuAllocationShape GetTpuAllocationShape(const xla::ShapeProto& shape) {
|
::TpuAllocationShape GetTpuAllocationShape(const xla::ShapeProto& shape) {
|
||||||
@ -53,8 +62,7 @@ class DirectEvent : public Event {
|
|||||||
|
|
||||||
xla::Status Await() override {
|
xla::Status Await() override {
|
||||||
auto tpu_status = driver_fn_->TpuDriver_EventAwait(event_, -1);
|
auto tpu_status = driver_fn_->TpuDriver_EventAwait(event_, -1);
|
||||||
auto ret = xla::Status(tensorflow::error::Code(tpu_status->code),
|
auto ret = CreateXlaStatus(tpu_status);
|
||||||
absl::StrFormat("%s", tpu_status->msg));
|
|
||||||
driver_fn_->TpuDriver_FreeStatus(tpu_status);
|
driver_fn_->TpuDriver_FreeStatus(tpu_status);
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
@ -66,8 +74,7 @@ class DirectEvent : public Event {
|
|||||||
if (tpu_status_or == nullptr) {
|
if (tpu_status_or == nullptr) {
|
||||||
return absl::nullopt;
|
return absl::nullopt;
|
||||||
} else {
|
} else {
|
||||||
auto ret = xla::Status(tensorflow::error::Code(tpu_status_or->code),
|
auto ret = CreateXlaStatus(tpu_status_or);
|
||||||
absl::StrFormat("%s", tpu_status_or->msg));
|
|
||||||
driver_fn_->TpuDriver_FreeStatus(tpu_status_or);
|
driver_fn_->TpuDriver_FreeStatus(tpu_status_or);
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
@ -85,8 +92,7 @@ class DirectEvent : public Event {
|
|||||||
[](struct TpuStatus* status, void* additional_info) {
|
[](struct TpuStatus* status, void* additional_info) {
|
||||||
auto callback_addr =
|
auto callback_addr =
|
||||||
static_cast<std::function<void(xla::Status)>*>(additional_info);
|
static_cast<std::function<void(xla::Status)>*>(additional_info);
|
||||||
auto xla_status = xla::Status(tensorflow::error::Code(status->code),
|
auto xla_status = CreateXlaStatus(status);
|
||||||
absl::StrFormat("%s", status->msg));
|
|
||||||
(*callback_addr)(xla_status);
|
(*callback_addr)(xla_status);
|
||||||
delete callback_addr;
|
delete callback_addr;
|
||||||
},
|
},
|
||||||
@ -142,10 +148,8 @@ class DirectCompiledProgramHandle : public CompiledProgramHandle {
|
|||||||
driver_fn_->TpuDriver_GetCompiledProgramShape(handle_);
|
driver_fn_->TpuDriver_GetCompiledProgramShape(handle_);
|
||||||
program_shape->ParseFromArray(shape->bytes, shape->size);
|
program_shape->ParseFromArray(shape->bytes, shape->size);
|
||||||
|
|
||||||
auto status = xla::Status(tensorflow::error::Code(shape->status->code),
|
auto status = CreateXlaStatus(shape->status);
|
||||||
absl::StrFormat("%s", shape->status->msg));
|
|
||||||
driver_fn_->TpuDriver_FreeCompiledProgramShape(shape);
|
driver_fn_->TpuDriver_FreeCompiledProgramShape(shape);
|
||||||
|
|
||||||
return status;
|
return status;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -196,8 +200,7 @@ class DirectTpuLinearizer : public TpuLinearizer {
|
|||||||
|
|
||||||
auto tpu_status =
|
auto tpu_status =
|
||||||
driver_fn_->TpuDriver_LinearizeShape(driver_, dst, src, shape_);
|
driver_fn_->TpuDriver_LinearizeShape(driver_, dst, src, shape_);
|
||||||
auto status = xla::Status(tensorflow::error::Code(tpu_status->code),
|
auto status = CreateXlaStatus(tpu_status);
|
||||||
absl::StrFormat("%s", tpu_status->msg));
|
|
||||||
driver_fn_->TpuDriver_FreeStatus(tpu_status);
|
driver_fn_->TpuDriver_FreeStatus(tpu_status);
|
||||||
free(shape_.bytes);
|
free(shape_.bytes);
|
||||||
return status;
|
return status;
|
||||||
@ -209,8 +212,7 @@ class DirectTpuLinearizer : public TpuLinearizer {
|
|||||||
|
|
||||||
auto tpu_status =
|
auto tpu_status =
|
||||||
driver_fn_->TpuDriver_DelinearizeShape(driver_, dst, src, shape_);
|
driver_fn_->TpuDriver_DelinearizeShape(driver_, dst, src, shape_);
|
||||||
auto status = xla::Status(tensorflow::error::Code(tpu_status->code),
|
auto status = CreateXlaStatus(tpu_status);
|
||||||
absl::StrFormat("%s", tpu_status->msg));
|
|
||||||
driver_fn_->TpuDriver_FreeStatus(tpu_status);
|
driver_fn_->TpuDriver_FreeStatus(tpu_status);
|
||||||
free(shape_.bytes);
|
free(shape_.bytes);
|
||||||
return status;
|
return status;
|
||||||
|
Loading…
Reference in New Issue
Block a user