diff --git a/tensorflow/lite/build_def.bzl b/tensorflow/lite/build_def.bzl index 9b78aa4df1f..f372e69df00 100644 --- a/tensorflow/lite/build_def.bzl +++ b/tensorflow/lite/build_def.bzl @@ -718,6 +718,7 @@ def tflite_experimental_runtime_linkopts(if_eager = [], if_non_eager = [], if_no if_eager = [ # "//tensorflow/lite/experimental/tf_runtime:eager_interpreter", # "//tensorflow/lite/experimental/tf_runtime:eager_model", + # "//tensorflow/lite/experimental/tf_runtime:subgraph", ] + if_eager, if_non_eager = [ # "//tensorflow/lite/experimental/tf_runtime:interpreter", diff --git a/tensorflow/lite/core/subgraph.h b/tensorflow/lite/core/subgraph.h index d2d5eaf2cbf..70ca4d24b61 100644 --- a/tensorflow/lite/core/subgraph.h +++ b/tensorflow/lite/core/subgraph.h @@ -29,6 +29,10 @@ limitations under the License. #include "tensorflow/lite/memory_planner.h" #include "tensorflow/lite/util.h" +#if TFLITE_EXPERIMENTAL_RUNTIME_EAGER +#include "tensorflow/lite/experimental/tf_runtime/public/subgraph.h" +#endif + namespace tflite { namespace impl { @@ -679,7 +683,11 @@ class Subgraph { } // namespace impl +#if TFLITE_EXPERIMENTAL_RUNTIME_EAGER +using Subgraph = tflrt::Subgraph; +#else using Subgraph = impl::Subgraph; +#endif } // namespace tflite #endif // TENSORFLOW_LITE_CORE_SUBGRAPH_H_ diff --git a/tensorflow/lite/interpreter.h b/tensorflow/lite/interpreter.h index dd183b2a98f..a869c1368d2 100644 --- a/tensorflow/lite/interpreter.h +++ b/tensorflow/lite/interpreter.h @@ -34,6 +34,10 @@ limitations under the License. #include "tensorflow/lite/stderr_reporter.h" #include "tensorflow/lite/type_to_tflitetype.h" +#if TFLITE_EXPERIMENTAL_RUNTIME_EAGER +#include "tensorflow/lite/experimental/tf_runtime/public/eager_interpreter.h" +#endif + namespace tflite { class InterpreterTest; @@ -548,7 +552,11 @@ class Interpreter { } // namespace impl +#if TFLITE_EXPERIMENTAL_RUNTIME_EAGER +using Interpreter = tflrt::EagerInterpreter; +#else using Interpreter = impl::Interpreter; +#endif } // namespace tflite #endif // TENSORFLOW_LITE_INTERPRETER_H_ diff --git a/tensorflow/lite/model.h b/tensorflow/lite/model.h index 1db7828f736..5819142ee25 100644 --- a/tensorflow/lite/model.h +++ b/tensorflow/lite/model.h @@ -19,12 +19,21 @@ limitations under the License. #ifndef TENSORFLOW_LITE_MODEL_H_ #define TENSORFLOW_LITE_MODEL_H_ -#include "tensorflow/lite/interpreter_builder.h" #include "tensorflow/lite/model_builder.h" +#if TFLITE_EXPERIMENTAL_RUNTIME_EAGER +#include "tensorflow/lite/experimental/tf_runtime/lib/eager_model.h" +#else +#include "tensorflow/lite/interpreter_builder.h" +#endif + namespace tflite { +#if TFLITE_EXPERIMENTAL_RUNTIME_EAGER +using InterpreterBuilder = tflrt::EagerTfLiteInterpreterBuilderAPI; +#else using InterpreterBuilder = impl::InterpreterBuilder; +#endif } // namespace tflite diff --git a/tensorflow/lite/model_builder.cc b/tensorflow/lite/model_builder.cc index 784c39f00c8..c63ba47b5cf 100644 --- a/tensorflow/lite/model_builder.cc +++ b/tensorflow/lite/model_builder.cc @@ -21,7 +21,6 @@ limitations under the License. #include #include -#include "tensorflow/lite/allocation.h" #include "tensorflow/lite/c/builtin_op_data.h" #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/core/api/error_reporter.h" diff --git a/tensorflow/lite/model_builder.h b/tensorflow/lite/model_builder.h index ac05223b6a8..ed5e626a77b 100644 --- a/tensorflow/lite/model_builder.h +++ b/tensorflow/lite/model_builder.h @@ -21,12 +21,13 @@ limitations under the License. #include +#include "tensorflow/lite/allocation.h" #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/core/api/error_reporter.h" #include "tensorflow/lite/core/api/op_resolver.h" -#include "tensorflow/lite/interpreter.h" #include "tensorflow/lite/mutable_op_resolver.h" #include "tensorflow/lite/schema/schema_generated.h" +#include "tensorflow/lite/stderr_reporter.h" namespace tflite { diff --git a/tensorflow/lite/tools/optimize/calibration/calibrator.h b/tensorflow/lite/tools/optimize/calibration/calibrator.h index ef7cea528d9..c726116d29b 100644 --- a/tensorflow/lite/tools/optimize/calibration/calibrator.h +++ b/tensorflow/lite/tools/optimize/calibration/calibrator.h @@ -19,6 +19,7 @@ limitations under the License. #include "flatbuffers/flatbuffers.h" // from @flatbuffers #include "tensorflow/lite/core/api/op_resolver.h" +#include "tensorflow/lite/interpreter.h" #include "tensorflow/lite/model.h" #include "tensorflow/lite/tools/optimize/calibration/calibration_reader.h"