diff --git a/tensorflow/core/framework/op_requires.h b/tensorflow/core/framework/op_requires.h index b7cf6e859fb..d186df337d7 100644 --- a/tensorflow/core/framework/op_requires.h +++ b/tensorflow/core/framework/op_requires.h @@ -57,6 +57,18 @@ namespace tensorflow { } \ } while (0) +#define OP_REQUIRES_OK_OR_SET_PAYLOAD(CTX, PAYLOAD_KEY, PAYLOAD_VALUE, STATUS) \ + do { \ + if (!TF_PREDICT_TRUE(STATUS.ok())) { \ + CheckNotInComputeAsync((CTX), "OP_REQUIRES_OK_ASYNC"); \ + if (!PAYLOAD_VALUE.empty()) { \ + STATUS.SetPayload(PAYLOAD_KEY, PAYLOAD_VALUE); \ + } \ + (CTX)->CtxFailureWithWarning(__FILE__, __LINE__, STATUS); \ + return; \ + } \ + } while (0) + #define OP_REQUIRES_ASYNC(CTX, EXP, STATUS, CALLBACK) \ do { \ if (!TF_PREDICT_TRUE(EXP)) { \ diff --git a/tensorflow/core/protobuf/tpu/compilation_result.proto b/tensorflow/core/protobuf/tpu/compilation_result.proto index 5dd74dead49..d88bddf9173 100644 --- a/tensorflow/core/protobuf/tpu/compilation_result.proto +++ b/tensorflow/core/protobuf/tpu/compilation_result.proto @@ -7,7 +7,8 @@ import "tensorflow/core/protobuf/error_codes.proto"; option cc_enable_arenas = true; -// Describes the result of a TPU compilation. +// Describes the result of a TPU compilation. This is also used as TPU +// compilation result status payload. message CompilationResultProto { // The error message, if any, returned during compilation. error.Code status_code = 1; diff --git a/tensorflow/core/tpu/kernels/BUILD b/tensorflow/core/tpu/kernels/BUILD index 6461db0cfd0..b187ca31970 100644 --- a/tensorflow/core/tpu/kernels/BUILD +++ b/tensorflow/core/tpu/kernels/BUILD @@ -56,6 +56,7 @@ cc_library( ":tpu_op_util", ":tpu_program_group_interface", ":tpu_util", + "//tensorflow/core/tpu:tpu_compile_interface", "//tensorflow/core/tpu:tpu_ops_c_api_hdrs", ":tpu_util_hdrs", "@com_google_absl//absl/strings", diff --git a/tensorflow/core/tpu/kernels/tpu_compile_op_common.cc b/tensorflow/core/tpu/kernels/tpu_compile_op_common.cc index fa98ab763ca..55d34622912 100644 --- a/tensorflow/core/tpu/kernels/tpu_compile_op_common.cc +++ b/tensorflow/core/tpu/kernels/tpu_compile_op_common.cc @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/core/framework/metrics.h" #include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/status.h" #include "tensorflow/core/protobuf/tpu/compilation_result.pb.h" #include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h" #include "tensorflow/core/protobuf/tpu/dynamic_padding.pb.h" @@ -40,6 +41,7 @@ limitations under the License. #include "tensorflow/core/tpu/kernels/tpu_program_group_interface.h" #include "tensorflow/core/tpu/kernels/tpu_util.h" #include "tensorflow/core/tpu/tpu_api.h" +#include "tensorflow/core/tpu/tpu_compile_interface.h" #include "tensorflow/core/tpu/tpu_configuration.h" #include "tensorflow/core/tpu/tpu_defs.h" #include "tensorflow/core/tpu/tpu_ops_c_api.h" @@ -568,7 +570,21 @@ void TpuCompileOpKernelCommon::Compute(OpKernelContext* ctx) { done->store(true); }); - OP_REQUIRES_OK(ctx, ComputeInternal(ctx)); + Status compile_status = ComputeInternal(ctx); + string status_payload; + // Construct payload if compile_status is not ok and there's no payload for + // compilation yet. + if (!compile_status.ok() && + compile_status.GetPayload(TpuCompileInterface::kTpuCompileErrorPayloadKey) + .empty()) { + tpu::CompilationResultProto proto; + proto.set_status_code(compile_status.code()); + proto.set_status_error_message(compile_status.error_message()); + status_payload = proto.SerializeAsString(); + } + OP_REQUIRES_OK_OR_SET_PAYLOAD(ctx, + TpuCompileInterface::kTpuCompileErrorPayloadKey, + status_payload, compile_status); } Status TpuCompileOpKernelCommon::CompileLocallyAndFillHostCache( @@ -788,6 +804,8 @@ Status TpuCompileOpKernelCommon::ComputeInternal(OpKernelContext* ctx) { } SerializeToTString(proto, &output.scalar()()); ctx->set_output(0, output); + status.SetPayload(TpuCompileInterface::kTpuCompileErrorPayloadKey, + output.scalar()()); } if (status.ok()) { @@ -841,7 +859,7 @@ Status TpuCompileOpKernelCommon::ComputeInternal(OpKernelContext* ctx) { } } } - return Status::OK(); + return status; } } // namespace tpu } // namespace tensorflow diff --git a/tensorflow/core/tpu/tpu_compile_interface.h b/tensorflow/core/tpu/tpu_compile_interface.h index 7e7b1f8315a..8f73b2dcacd 100644 --- a/tensorflow/core/tpu/tpu_compile_interface.h +++ b/tensorflow/core/tpu/tpu_compile_interface.h @@ -28,6 +28,9 @@ class TpuCompileInterface { static bool RegisterImplementation(TpuCompileInterface* impl); virtual uint64_t FingerprintString(absl::string_view str) = 0; + + static inline constexpr char kTpuCompileErrorPayloadKey[] = + "https://www.tensorflow.org/tpu-compile-error"; }; #endif // TENSORFLOW_CORE_TPU_TPU_COMPILE_INTERFACE_H_ diff --git a/tensorflow/python/distribute/tpu_strategy_test.py b/tensorflow/python/distribute/tpu_strategy_test.py index 9f5fdb04d1e..dcb9cce1f66 100644 --- a/tensorflow/python/distribute/tpu_strategy_test.py +++ b/tensorflow/python/distribute/tpu_strategy_test.py @@ -300,7 +300,7 @@ class TPUStrategyTest(test.TestCase, parameterized.TestCase): return strategy.experimental_local_results( strategy.run(step_fn, args=(next(iterator),))) - with self.assertRaisesRegex(errors.InternalError, "Compilation failure"): + with self.assertRaises(errors.InternalError): logging.info(train_fn(iterator)) def test_computation_on_subset_cores(self, enable_packed_var):