From a5218435ecbf7e9694455d30302e9b066a512f89 Mon Sep 17 00:00:00 2001 From: Frank Chen Date: Wed, 15 Jan 2020 18:06:46 -0800 Subject: [PATCH] Fix XLA Status generation (using the 2-parameter construct will create a error status even if error::OK is passed in) PiperOrigin-RevId: 289976821 Change-Id: I0ef719d7373969db3334a01a018db3fd1ce0a1a9 --- .../python/tpu_driver/direct_tpu_driver.cc | 28 ++++++++++--------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/tensorflow/compiler/xla/python/tpu_driver/direct_tpu_driver.cc b/tensorflow/compiler/xla/python/tpu_driver/direct_tpu_driver.cc index 0dc42e8f23c..6031c1f64b7 100644 --- a/tensorflow/compiler/xla/python/tpu_driver/direct_tpu_driver.cc +++ b/tensorflow/compiler/xla/python/tpu_driver/direct_tpu_driver.cc @@ -27,6 +27,15 @@ namespace tpu_driver { namespace { +xla::Status CreateXlaStatus(::TpuStatus* status) { + if (status->code == tensorflow::error::OK) { + return xla::Status::OK(); + } else { + return xla::Status(tensorflow::error::Code(status->code), + absl::StrFormat("%s", status->msg)); + } +} + constexpr char kDirectProtocol[] = "direct://"; ::TpuAllocationShape GetTpuAllocationShape(const xla::ShapeProto& shape) { @@ -53,8 +62,7 @@ class DirectEvent : public Event { xla::Status Await() override { auto tpu_status = driver_fn_->TpuDriver_EventAwait(event_, -1); - auto ret = xla::Status(tensorflow::error::Code(tpu_status->code), - absl::StrFormat("%s", tpu_status->msg)); + auto ret = CreateXlaStatus(tpu_status); driver_fn_->TpuDriver_FreeStatus(tpu_status); return ret; } @@ -66,8 +74,7 @@ class DirectEvent : public Event { if (tpu_status_or == nullptr) { return absl::nullopt; } else { - auto ret = xla::Status(tensorflow::error::Code(tpu_status_or->code), - absl::StrFormat("%s", tpu_status_or->msg)); + auto ret = CreateXlaStatus(tpu_status_or); driver_fn_->TpuDriver_FreeStatus(tpu_status_or); return ret; } @@ -85,8 +92,7 @@ class DirectEvent : public Event { [](struct TpuStatus* status, void* additional_info) { auto callback_addr = static_cast*>(additional_info); - auto xla_status = xla::Status(tensorflow::error::Code(status->code), - absl::StrFormat("%s", status->msg)); + auto xla_status = CreateXlaStatus(status); (*callback_addr)(xla_status); delete callback_addr; }, @@ -142,10 +148,8 @@ class DirectCompiledProgramHandle : public CompiledProgramHandle { driver_fn_->TpuDriver_GetCompiledProgramShape(handle_); program_shape->ParseFromArray(shape->bytes, shape->size); - auto status = xla::Status(tensorflow::error::Code(shape->status->code), - absl::StrFormat("%s", shape->status->msg)); + auto status = CreateXlaStatus(shape->status); driver_fn_->TpuDriver_FreeCompiledProgramShape(shape); - return status; } @@ -196,8 +200,7 @@ class DirectTpuLinearizer : public TpuLinearizer { auto tpu_status = driver_fn_->TpuDriver_LinearizeShape(driver_, dst, src, shape_); - auto status = xla::Status(tensorflow::error::Code(tpu_status->code), - absl::StrFormat("%s", tpu_status->msg)); + auto status = CreateXlaStatus(tpu_status); driver_fn_->TpuDriver_FreeStatus(tpu_status); free(shape_.bytes); return status; @@ -209,8 +212,7 @@ class DirectTpuLinearizer : public TpuLinearizer { auto tpu_status = driver_fn_->TpuDriver_DelinearizeShape(driver_, dst, src, shape_); - auto status = xla::Status(tensorflow::error::Code(tpu_status->code), - absl::StrFormat("%s", tpu_status->msg)); + auto status = CreateXlaStatus(tpu_status); driver_fn_->TpuDriver_FreeStatus(tpu_status); free(shape_.bytes); return status;