Add StatefulErrorReporter and make error_reporter() public from Interpreter API.

PiperOrigin-RevId: 335505537
Change-Id: Ic48a841ca1ace3d664969baa12a9b2e7681a3bb3
This commit is contained in:
A. Unique TensorFlower 2020-10-05 14:53:09 -07:00 committed by TensorFlower Gardener
parent b34984f33a
commit 1ad94c17b4
5 changed files with 49 additions and 7 deletions

View File

@ -669,6 +669,12 @@ cc_library(
hdrs = ["core/macros.h"],
)
cc_library(
name = "stateful_error_reporter",
hdrs = ["stateful_error_reporter.h"],
deps = ["//tensorflow/lite/core/api"],
)
# Shared lib target for convenience, pulls in the core runtime and builtin ops.
# Note: This target is not yet finalized, and the exact set of exported (C/C++)
# APIs is subject to change. The output library name is platform dependent:

View File

@ -576,6 +576,11 @@ class Interpreter {
const Subgraph& primary_subgraph() const {
return *subgraphs_.front(); // Safe as subgraphs_ always has 1 entry.
}
/// WARNING: Experimental interface, subject to change
// Get the error reporter associated with this interpreter.
ErrorReporter* error_reporter() const { return error_reporter_; }
#endif // DOXYGEN_SKIP
private:
@ -602,9 +607,6 @@ class Interpreter {
// Returns true if cancellation function returns true.
bool IsCancelled();
// Get the error reporter associated with this interpreter.
ErrorReporter* error_reporter() { return error_reporter_; }
// A pure C data structure used to communicate with the pure C plugin
// interface. To avoid copying tensor metadata, this is also the definitive
// structure to store tensors.

View File

@ -46,7 +46,7 @@ cc_library(
srcs = ["python_error_reporter.cc"],
hdrs = ["python_error_reporter.h"],
deps = [
"//tensorflow/lite/core/api",
"//tensorflow/lite:stateful_error_reporter",
"//third_party/python_runtime:headers", # buildcleaner: keep
],
)

View File

@ -21,12 +21,12 @@ limitations under the License.
#include <sstream>
#include <string>
#include "tensorflow/lite/core/api/error_reporter.h"
#include "tensorflow/lite/stateful_error_reporter.h"
namespace tflite {
namespace interpreter_wrapper {
class PythonErrorReporter : public tflite::ErrorReporter {
class PythonErrorReporter : public tflite::StatefulErrorReporter {
public:
PythonErrorReporter() {}
@ -38,7 +38,7 @@ class PythonErrorReporter : public tflite::ErrorReporter {
PyObject* exception();
// Gets the last error message and clears the buffer.
std::string message();
std::string message() override;
private:
std::stringstream buffer_;

View File

@ -0,0 +1,34 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_STATEFUL_ERROR_REPORTER_H_
#define TENSORFLOW_LITE_STATEFUL_ERROR_REPORTER_H_
#include <string>
#include "tensorflow/lite/core/api/error_reporter.h"
namespace tflite {
// Similar to tflite::ErrorReporter, except that it allows callers to get the
// last error message.
class StatefulErrorReporter : public ErrorReporter {
public:
// Returns last error message. Returns empty string if no error is reported.
virtual std::string message() = 0;
};
} // namespace tflite
#endif // TENSORFLOW_LITE_STATEFUL_ERROR_REPORTER_H_