Fix XLA Status generation (using the 2-parameter construct will create a error status even if error::OK is passed in)

PiperOrigin-RevId: 289976821
Change-Id: I0ef719d7373969db3334a01a018db3fd1ce0a1a9
This commit is contained in:
Frank Chen 2020-01-15 18:06:46 -08:00 committed by TensorFlower Gardener
parent 52b8ba5463
commit a5218435ec

View File

@ -27,6 +27,15 @@
namespace tpu_driver {
namespace {
xla::Status CreateXlaStatus(::TpuStatus* status) {
if (status->code == tensorflow::error::OK) {
return xla::Status::OK();
} else {
return xla::Status(tensorflow::error::Code(status->code),
absl::StrFormat("%s", status->msg));
}
}
constexpr char kDirectProtocol[] = "direct://";
::TpuAllocationShape GetTpuAllocationShape(const xla::ShapeProto& shape) {
@ -53,8 +62,7 @@ class DirectEvent : public Event {
xla::Status Await() override {
auto tpu_status = driver_fn_->TpuDriver_EventAwait(event_, -1);
auto ret = xla::Status(tensorflow::error::Code(tpu_status->code),
absl::StrFormat("%s", tpu_status->msg));
auto ret = CreateXlaStatus(tpu_status);
driver_fn_->TpuDriver_FreeStatus(tpu_status);
return ret;
}
@ -66,8 +74,7 @@ class DirectEvent : public Event {
if (tpu_status_or == nullptr) {
return absl::nullopt;
} else {
auto ret = xla::Status(tensorflow::error::Code(tpu_status_or->code),
absl::StrFormat("%s", tpu_status_or->msg));
auto ret = CreateXlaStatus(tpu_status_or);
driver_fn_->TpuDriver_FreeStatus(tpu_status_or);
return ret;
}
@ -85,8 +92,7 @@ class DirectEvent : public Event {
[](struct TpuStatus* status, void* additional_info) {
auto callback_addr =
static_cast<std::function<void(xla::Status)>*>(additional_info);
auto xla_status = xla::Status(tensorflow::error::Code(status->code),
absl::StrFormat("%s", status->msg));
auto xla_status = CreateXlaStatus(status);
(*callback_addr)(xla_status);
delete callback_addr;
},
@ -142,10 +148,8 @@ class DirectCompiledProgramHandle : public CompiledProgramHandle {
driver_fn_->TpuDriver_GetCompiledProgramShape(handle_);
program_shape->ParseFromArray(shape->bytes, shape->size);
auto status = xla::Status(tensorflow::error::Code(shape->status->code),
absl::StrFormat("%s", shape->status->msg));
auto status = CreateXlaStatus(shape->status);
driver_fn_->TpuDriver_FreeCompiledProgramShape(shape);
return status;
}
@ -196,8 +200,7 @@ class DirectTpuLinearizer : public TpuLinearizer {
auto tpu_status =
driver_fn_->TpuDriver_LinearizeShape(driver_, dst, src, shape_);
auto status = xla::Status(tensorflow::error::Code(tpu_status->code),
absl::StrFormat("%s", tpu_status->msg));
auto status = CreateXlaStatus(tpu_status);
driver_fn_->TpuDriver_FreeStatus(tpu_status);
free(shape_.bytes);
return status;
@ -209,8 +212,7 @@ class DirectTpuLinearizer : public TpuLinearizer {
auto tpu_status =
driver_fn_->TpuDriver_DelinearizeShape(driver_, dst, src, shape_);
auto status = xla::Status(tensorflow::error::Code(tpu_status->code),
absl::StrFormat("%s", tpu_status->msg));
auto status = CreateXlaStatus(tpu_status);
driver_fn_->TpuDriver_FreeStatus(tpu_status);
free(shape_.bytes);
return status;