Add error payload in status.

PiperOrigin-RevId: 350501179
Change-Id: I97b328eccf138fac19c37d91b3cb59a9aa97359e
This commit is contained in:
A. Unique TensorFlower 2021-01-07 00:05:53 -08:00 committed by TensorFlower Gardener
parent 8b65e66f2f
commit 302135ad72
6 changed files with 39 additions and 4 deletions

View File

@ -57,6 +57,18 @@ namespace tensorflow {
} \
} while (0)
#define OP_REQUIRES_OK_OR_SET_PAYLOAD(CTX, PAYLOAD_KEY, PAYLOAD_VALUE, STATUS) \
do { \
if (!TF_PREDICT_TRUE(STATUS.ok())) { \
CheckNotInComputeAsync((CTX), "OP_REQUIRES_OK_ASYNC"); \
if (!PAYLOAD_VALUE.empty()) { \
STATUS.SetPayload(PAYLOAD_KEY, PAYLOAD_VALUE); \
} \
(CTX)->CtxFailureWithWarning(__FILE__, __LINE__, STATUS); \
return; \
} \
} while (0)
#define OP_REQUIRES_ASYNC(CTX, EXP, STATUS, CALLBACK) \
do { \
if (!TF_PREDICT_TRUE(EXP)) { \

View File

@ -7,7 +7,8 @@ import "tensorflow/core/protobuf/error_codes.proto";
option cc_enable_arenas = true;
// Describes the result of a TPU compilation.
// Describes the result of a TPU compilation. This is also used as TPU
// compilation result status payload.
message CompilationResultProto {
// The error message, if any, returned during compilation.
error.Code status_code = 1;

View File

@ -56,6 +56,7 @@ cc_library(
":tpu_op_util",
":tpu_program_group_interface",
":tpu_util",
"//tensorflow/core/tpu:tpu_compile_interface",
"//tensorflow/core/tpu:tpu_ops_c_api_hdrs",
":tpu_util_hdrs",
"@com_google_absl//absl/strings",

View File

@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/core/framework/metrics.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/protobuf/tpu/compilation_result.pb.h"
#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
#include "tensorflow/core/protobuf/tpu/dynamic_padding.pb.h"
@ -40,6 +41,7 @@ limitations under the License.
#include "tensorflow/core/tpu/kernels/tpu_program_group_interface.h"
#include "tensorflow/core/tpu/kernels/tpu_util.h"
#include "tensorflow/core/tpu/tpu_api.h"
#include "tensorflow/core/tpu/tpu_compile_interface.h"
#include "tensorflow/core/tpu/tpu_configuration.h"
#include "tensorflow/core/tpu/tpu_defs.h"
#include "tensorflow/core/tpu/tpu_ops_c_api.h"
@ -568,7 +570,21 @@ void TpuCompileOpKernelCommon::Compute(OpKernelContext* ctx) {
done->store(true);
});
OP_REQUIRES_OK(ctx, ComputeInternal(ctx));
Status compile_status = ComputeInternal(ctx);
string status_payload;
// Construct payload if compile_status is not ok and there's no payload for
// compilation yet.
if (!compile_status.ok() &&
compile_status.GetPayload(TpuCompileInterface::kTpuCompileErrorPayloadKey)
.empty()) {
tpu::CompilationResultProto proto;
proto.set_status_code(compile_status.code());
proto.set_status_error_message(compile_status.error_message());
status_payload = proto.SerializeAsString();
}
OP_REQUIRES_OK_OR_SET_PAYLOAD(ctx,
TpuCompileInterface::kTpuCompileErrorPayloadKey,
status_payload, compile_status);
}
Status TpuCompileOpKernelCommon::CompileLocallyAndFillHostCache(
@ -788,6 +804,8 @@ Status TpuCompileOpKernelCommon::ComputeInternal(OpKernelContext* ctx) {
}
SerializeToTString(proto, &output.scalar<tstring>()());
ctx->set_output(0, output);
status.SetPayload(TpuCompileInterface::kTpuCompileErrorPayloadKey,
output.scalar<tstring>()());
}
if (status.ok()) {
@ -841,7 +859,7 @@ Status TpuCompileOpKernelCommon::ComputeInternal(OpKernelContext* ctx) {
}
}
}
return Status::OK();
return status;
}
} // namespace tpu
} // namespace tensorflow

View File

@ -28,6 +28,9 @@ class TpuCompileInterface {
static bool RegisterImplementation(TpuCompileInterface* impl);
virtual uint64_t FingerprintString(absl::string_view str) = 0;
static inline constexpr char kTpuCompileErrorPayloadKey[] =
"https://www.tensorflow.org/tpu-compile-error";
};
#endif // TENSORFLOW_CORE_TPU_TPU_COMPILE_INTERFACE_H_

View File

@ -300,7 +300,7 @@ class TPUStrategyTest(test.TestCase, parameterized.TestCase):
return strategy.experimental_local_results(
strategy.run(step_fn, args=(next(iterator),)))
with self.assertRaisesRegex(errors.InternalError, "Compilation failure"):
with self.assertRaises(errors.InternalError):
logging.info(train_fn(iterator))
def test_computation_on_subset_cores(self, enable_packed_var):