Add error payload in status.
PiperOrigin-RevId: 350501179 Change-Id: I97b328eccf138fac19c37d91b3cb59a9aa97359e
This commit is contained in:
parent
8b65e66f2f
commit
302135ad72
@ -57,6 +57,18 @@ namespace tensorflow {
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
#define OP_REQUIRES_OK_OR_SET_PAYLOAD(CTX, PAYLOAD_KEY, PAYLOAD_VALUE, STATUS) \
|
||||
do { \
|
||||
if (!TF_PREDICT_TRUE(STATUS.ok())) { \
|
||||
CheckNotInComputeAsync((CTX), "OP_REQUIRES_OK_ASYNC"); \
|
||||
if (!PAYLOAD_VALUE.empty()) { \
|
||||
STATUS.SetPayload(PAYLOAD_KEY, PAYLOAD_VALUE); \
|
||||
} \
|
||||
(CTX)->CtxFailureWithWarning(__FILE__, __LINE__, STATUS); \
|
||||
return; \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
#define OP_REQUIRES_ASYNC(CTX, EXP, STATUS, CALLBACK) \
|
||||
do { \
|
||||
if (!TF_PREDICT_TRUE(EXP)) { \
|
||||
|
@ -7,7 +7,8 @@ import "tensorflow/core/protobuf/error_codes.proto";
|
||||
|
||||
option cc_enable_arenas = true;
|
||||
|
||||
// Describes the result of a TPU compilation.
|
||||
// Describes the result of a TPU compilation. This is also used as TPU
|
||||
// compilation result status payload.
|
||||
message CompilationResultProto {
|
||||
// The error message, if any, returned during compilation.
|
||||
error.Code status_code = 1;
|
||||
|
@ -56,6 +56,7 @@ cc_library(
|
||||
":tpu_op_util",
|
||||
":tpu_program_group_interface",
|
||||
":tpu_util",
|
||||
"//tensorflow/core/tpu:tpu_compile_interface",
|
||||
"//tensorflow/core/tpu:tpu_ops_c_api_hdrs",
|
||||
":tpu_util_hdrs",
|
||||
"@com_google_absl//absl/strings",
|
||||
|
@ -28,6 +28,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/metrics.h"
|
||||
#include "tensorflow/core/framework/resource_mgr.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
#include "tensorflow/core/protobuf/tpu/compilation_result.pb.h"
|
||||
#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
|
||||
#include "tensorflow/core/protobuf/tpu/dynamic_padding.pb.h"
|
||||
@ -40,6 +41,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/tpu/kernels/tpu_program_group_interface.h"
|
||||
#include "tensorflow/core/tpu/kernels/tpu_util.h"
|
||||
#include "tensorflow/core/tpu/tpu_api.h"
|
||||
#include "tensorflow/core/tpu/tpu_compile_interface.h"
|
||||
#include "tensorflow/core/tpu/tpu_configuration.h"
|
||||
#include "tensorflow/core/tpu/tpu_defs.h"
|
||||
#include "tensorflow/core/tpu/tpu_ops_c_api.h"
|
||||
@ -568,7 +570,21 @@ void TpuCompileOpKernelCommon::Compute(OpKernelContext* ctx) {
|
||||
done->store(true);
|
||||
});
|
||||
|
||||
OP_REQUIRES_OK(ctx, ComputeInternal(ctx));
|
||||
Status compile_status = ComputeInternal(ctx);
|
||||
string status_payload;
|
||||
// Construct payload if compile_status is not ok and there's no payload for
|
||||
// compilation yet.
|
||||
if (!compile_status.ok() &&
|
||||
compile_status.GetPayload(TpuCompileInterface::kTpuCompileErrorPayloadKey)
|
||||
.empty()) {
|
||||
tpu::CompilationResultProto proto;
|
||||
proto.set_status_code(compile_status.code());
|
||||
proto.set_status_error_message(compile_status.error_message());
|
||||
status_payload = proto.SerializeAsString();
|
||||
}
|
||||
OP_REQUIRES_OK_OR_SET_PAYLOAD(ctx,
|
||||
TpuCompileInterface::kTpuCompileErrorPayloadKey,
|
||||
status_payload, compile_status);
|
||||
}
|
||||
|
||||
Status TpuCompileOpKernelCommon::CompileLocallyAndFillHostCache(
|
||||
@ -788,6 +804,8 @@ Status TpuCompileOpKernelCommon::ComputeInternal(OpKernelContext* ctx) {
|
||||
}
|
||||
SerializeToTString(proto, &output.scalar<tstring>()());
|
||||
ctx->set_output(0, output);
|
||||
status.SetPayload(TpuCompileInterface::kTpuCompileErrorPayloadKey,
|
||||
output.scalar<tstring>()());
|
||||
}
|
||||
|
||||
if (status.ok()) {
|
||||
@ -841,7 +859,7 @@ Status TpuCompileOpKernelCommon::ComputeInternal(OpKernelContext* ctx) {
|
||||
}
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
return status;
|
||||
}
|
||||
} // namespace tpu
|
||||
} // namespace tensorflow
|
||||
|
@ -28,6 +28,9 @@ class TpuCompileInterface {
|
||||
static bool RegisterImplementation(TpuCompileInterface* impl);
|
||||
|
||||
virtual uint64_t FingerprintString(absl::string_view str) = 0;
|
||||
|
||||
static inline constexpr char kTpuCompileErrorPayloadKey[] =
|
||||
"https://www.tensorflow.org/tpu-compile-error";
|
||||
};
|
||||
|
||||
#endif // TENSORFLOW_CORE_TPU_TPU_COMPILE_INTERFACE_H_
|
||||
|
@ -300,7 +300,7 @@ class TPUStrategyTest(test.TestCase, parameterized.TestCase):
|
||||
return strategy.experimental_local_results(
|
||||
strategy.run(step_fn, args=(next(iterator),)))
|
||||
|
||||
with self.assertRaisesRegex(errors.InternalError, "Compilation failure"):
|
||||
with self.assertRaises(errors.InternalError):
|
||||
logging.info(train_fn(iterator))
|
||||
|
||||
def test_computation_on_subset_cores(self, enable_packed_var):
|
||||
|
Loading…
Reference in New Issue
Block a user