Add StatefulErrorReporter and make error_reporter() public from Interpreter API.
PiperOrigin-RevId: 335505537 Change-Id: Ic48a841ca1ace3d664969baa12a9b2e7681a3bb3
This commit is contained in:
parent
b34984f33a
commit
1ad94c17b4
tensorflow/lite
@ -669,6 +669,12 @@ cc_library(
|
||||
hdrs = ["core/macros.h"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "stateful_error_reporter",
|
||||
hdrs = ["stateful_error_reporter.h"],
|
||||
deps = ["//tensorflow/lite/core/api"],
|
||||
)
|
||||
|
||||
# Shared lib target for convenience, pulls in the core runtime and builtin ops.
|
||||
# Note: This target is not yet finalized, and the exact set of exported (C/C++)
|
||||
# APIs is subject to change. The output library name is platform dependent:
|
||||
|
@ -576,6 +576,11 @@ class Interpreter {
|
||||
const Subgraph& primary_subgraph() const {
|
||||
return *subgraphs_.front(); // Safe as subgraphs_ always has 1 entry.
|
||||
}
|
||||
|
||||
/// WARNING: Experimental interface, subject to change
|
||||
// Get the error reporter associated with this interpreter.
|
||||
ErrorReporter* error_reporter() const { return error_reporter_; }
|
||||
|
||||
#endif // DOXYGEN_SKIP
|
||||
|
||||
private:
|
||||
@ -602,9 +607,6 @@ class Interpreter {
|
||||
// Returns true if cancellation function returns true.
|
||||
bool IsCancelled();
|
||||
|
||||
// Get the error reporter associated with this interpreter.
|
||||
ErrorReporter* error_reporter() { return error_reporter_; }
|
||||
|
||||
// A pure C data structure used to communicate with the pure C plugin
|
||||
// interface. To avoid copying tensor metadata, this is also the definitive
|
||||
// structure to store tensors.
|
||||
|
@ -46,7 +46,7 @@ cc_library(
|
||||
srcs = ["python_error_reporter.cc"],
|
||||
hdrs = ["python_error_reporter.h"],
|
||||
deps = [
|
||||
"//tensorflow/lite/core/api",
|
||||
"//tensorflow/lite:stateful_error_reporter",
|
||||
"//third_party/python_runtime:headers", # buildcleaner: keep
|
||||
],
|
||||
)
|
||||
|
@ -21,12 +21,12 @@ limitations under the License.
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
|
||||
#include "tensorflow/lite/core/api/error_reporter.h"
|
||||
#include "tensorflow/lite/stateful_error_reporter.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace interpreter_wrapper {
|
||||
|
||||
class PythonErrorReporter : public tflite::ErrorReporter {
|
||||
class PythonErrorReporter : public tflite::StatefulErrorReporter {
|
||||
public:
|
||||
PythonErrorReporter() {}
|
||||
|
||||
@ -38,7 +38,7 @@ class PythonErrorReporter : public tflite::ErrorReporter {
|
||||
PyObject* exception();
|
||||
|
||||
// Gets the last error message and clears the buffer.
|
||||
std::string message();
|
||||
std::string message() override;
|
||||
|
||||
private:
|
||||
std::stringstream buffer_;
|
||||
|
34
tensorflow/lite/stateful_error_reporter.h
Normal file
34
tensorflow/lite/stateful_error_reporter.h
Normal file
@ -0,0 +1,34 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_LITE_STATEFUL_ERROR_REPORTER_H_
|
||||
#define TENSORFLOW_LITE_STATEFUL_ERROR_REPORTER_H_
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "tensorflow/lite/core/api/error_reporter.h"
|
||||
|
||||
namespace tflite {
|
||||
|
||||
// Similar to tflite::ErrorReporter, except that it allows callers to get the
|
||||
// last error message.
|
||||
class StatefulErrorReporter : public ErrorReporter {
|
||||
public:
|
||||
// Returns last error message. Returns empty string if no error is reported.
|
||||
virtual std::string message() = 0;
|
||||
};
|
||||
|
||||
} // namespace tflite
|
||||
|
||||
#endif // TENSORFLOW_LITE_STATEFUL_ERROR_REPORTER_H_
|
Loading…
Reference in New Issue
Block a user