Converted py_func.i to pybind11
This is part of a larger effort to deprecate swig and eventually with modularization break pywrap_tensorflow into smaller components. Please refer to https://github.com/tensorflow/community/blob/master/rfcs/20190208-pybind11.md for more information. PiperOrigin-RevId: 271409413
This commit is contained in:
parent
5ac1558cdb
commit
afa2418f90
@ -19,7 +19,7 @@ visibility = [
|
|||||||
"//bazel_pip/tensorflow/lite/toco/python:__pkg__",
|
"//bazel_pip/tensorflow/lite/toco/python:__pkg__",
|
||||||
]
|
]
|
||||||
|
|
||||||
load("//tensorflow:tensorflow.bzl", "if_mlir", "if_not_v2", "if_not_windows", "py_test", "py_tests", "tf_cc_shared_object", "tf_cuda_library", "tf_gen_op_wrapper_py", "tf_py_build_info_genrule", "tf_py_test")
|
load("//tensorflow:tensorflow.bzl", "if_mlir", "if_not_windows", "py_test", "py_tests", "tf_cc_shared_object", "tf_cuda_library", "tf_gen_op_wrapper_py", "tf_py_build_info_genrule", "tf_py_test", "cc_header_only_library")
|
||||||
load("//tensorflow:tensorflow.bzl", "tf_python_pybind_extension")
|
load("//tensorflow:tensorflow.bzl", "tf_python_pybind_extension")
|
||||||
load("//tensorflow:tensorflow.bzl", "pybind_extension")
|
load("//tensorflow:tensorflow.bzl", "pybind_extension")
|
||||||
load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc")
|
load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc")
|
||||||
@ -518,6 +518,24 @@ cc_library(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_header_only_library(
|
||||||
|
name = "py_func_headers_lib",
|
||||||
|
deps = [
|
||||||
|
":py_func_lib",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
tf_python_pybind_extension(
|
||||||
|
name = "_pywrap_py_func",
|
||||||
|
srcs = ["lib/core/py_func_wrapper.cc"],
|
||||||
|
module_name = "_pywrap_py_func",
|
||||||
|
deps = [
|
||||||
|
":py_func_headers_lib",
|
||||||
|
"//third_party/python_runtime:headers",
|
||||||
|
"@pybind11",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "safe_ptr",
|
name = "safe_ptr",
|
||||||
srcs = ["lib/core/safe_ptr.cc"],
|
srcs = ["lib/core/safe_ptr.cc"],
|
||||||
@ -3639,6 +3657,7 @@ py_library(
|
|||||||
srcs = ["ops/script_ops.py"],
|
srcs = ["ops/script_ops.py"],
|
||||||
srcs_version = "PY2AND3",
|
srcs_version = "PY2AND3",
|
||||||
deps = [
|
deps = [
|
||||||
|
":_pywrap_py_func",
|
||||||
":array_ops",
|
":array_ops",
|
||||||
":framework_for_generated_wrappers",
|
":framework_for_generated_wrappers",
|
||||||
":script_ops_gen",
|
":script_ops_gen",
|
||||||
@ -5008,7 +5027,6 @@ tf_py_wrap_cc(
|
|||||||
"grappler/tf_optimizer.i",
|
"grappler/tf_optimizer.i",
|
||||||
"lib/core/bfloat16.i",
|
"lib/core/bfloat16.i",
|
||||||
"lib/core/py_exception_registry.i",
|
"lib/core/py_exception_registry.i",
|
||||||
"lib/core/py_func.i",
|
|
||||||
"lib/core/strings.i",
|
"lib/core/strings.i",
|
||||||
"lib/io/file_io.i",
|
"lib/io/file_io.i",
|
||||||
"lib/io/py_record_reader.i",
|
"lib/io/py_record_reader.i",
|
||||||
@ -5117,6 +5135,7 @@ genrule(
|
|||||||
name = "pybind_symbol_target_libs_file",
|
name = "pybind_symbol_target_libs_file",
|
||||||
srcs = [
|
srcs = [
|
||||||
":cpp_python_util", # util
|
":cpp_python_util", # util
|
||||||
|
":py_func_lib", # py_func
|
||||||
"//tensorflow/core:util_port", # util_port
|
"//tensorflow/core:util_port", # util_port
|
||||||
"//tensorflow/stream_executor:stream_executor_pimpl", # stat_summarizer
|
"//tensorflow/stream_executor:stream_executor_pimpl", # stat_summarizer
|
||||||
"//tensorflow/core/grappler/graph_analyzer:graph_analyzer_tool", # graph_analyzer
|
"//tensorflow/core/grappler/graph_analyzer:graph_analyzer_tool", # graph_analyzer
|
||||||
|
@ -17,6 +17,11 @@ limitations under the License.
|
|||||||
|
|
||||||
#include <Python.h>
|
#include <Python.h>
|
||||||
|
|
||||||
|
// clang-format: off
|
||||||
|
// Must be inlcluded first.
|
||||||
|
#include "tensorflow/python/lib/core/numpy.h"
|
||||||
|
// clang-format: on
|
||||||
|
|
||||||
#include <array>
|
#include <array>
|
||||||
|
|
||||||
#include "numpy/arrayobject.h"
|
#include "numpy/arrayobject.h"
|
||||||
@ -25,7 +30,9 @@ limitations under the License.
|
|||||||
#include "tensorflow/c/tf_status_helper.h"
|
#include "tensorflow/c/tf_status_helper.h"
|
||||||
#include "tensorflow/core/framework/allocation_description.pb.h"
|
#include "tensorflow/core/framework/allocation_description.pb.h"
|
||||||
#include "tensorflow/core/framework/op_kernel.h"
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
|
#include "tensorflow/core/framework/tensor.h"
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
#include "tensorflow/core/lib/core/threadpool.h"
|
#include "tensorflow/core/lib/core/threadpool.h"
|
||||||
#include "tensorflow/core/platform/macros.h"
|
#include "tensorflow/core/platform/macros.h"
|
||||||
#include "tensorflow/core/platform/mutex.h"
|
#include "tensorflow/core/platform/mutex.h"
|
||||||
|
@ -16,11 +16,7 @@ limitations under the License.
|
|||||||
#ifndef TENSORFLOW_PYTHON_LIB_CORE_PY_FUNC_H_
|
#ifndef TENSORFLOW_PYTHON_LIB_CORE_PY_FUNC_H_
|
||||||
#define TENSORFLOW_PYTHON_LIB_CORE_PY_FUNC_H_
|
#define TENSORFLOW_PYTHON_LIB_CORE_PY_FUNC_H_
|
||||||
|
|
||||||
// Must be included first
|
#include <Python.h>
|
||||||
#include "tensorflow/python/lib/core/numpy.h"
|
|
||||||
|
|
||||||
#include "tensorflow/core/framework/tensor.h"
|
|
||||||
#include "tensorflow/core/lib/core/status.h"
|
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
@ -13,17 +13,13 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
%include "tensorflow/python/platform/base.i"
|
#include "include/pybind11/pybind11.h"
|
||||||
|
|
||||||
%{
|
|
||||||
#include "tensorflow/python/lib/core/py_func.h"
|
#include "tensorflow/python/lib/core/py_func.h"
|
||||||
%}
|
|
||||||
|
|
||||||
%ignoreall
|
namespace py = pybind11;
|
||||||
|
|
||||||
%unignore tensorflow;
|
PYBIND11_MODULE(_pywrap_py_func, m) {
|
||||||
%unignore tensorflow::InitializePyTrampoline;
|
m.def("initialize_py_trampoline", [](py::object trampoline) {
|
||||||
|
return tensorflow::InitializePyTrampoline(trampoline.ptr());
|
||||||
%include "tensorflow/python/lib/core/py_func.h"
|
});
|
||||||
|
}
|
||||||
%unignoreall
|
|
@ -28,7 +28,7 @@ import weakref
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import six
|
import six
|
||||||
|
|
||||||
from tensorflow.python import pywrap_tensorflow
|
from tensorflow.python import _pywrap_py_func
|
||||||
from tensorflow.python.eager import backprop
|
from tensorflow.python.eager import backprop
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
@ -259,7 +259,7 @@ class FuncRegistry(object):
|
|||||||
# Global registry for py functions.
|
# Global registry for py functions.
|
||||||
_py_funcs = FuncRegistry()
|
_py_funcs = FuncRegistry()
|
||||||
|
|
||||||
pywrap_tensorflow.InitializePyTrampoline(_py_funcs)
|
_pywrap_py_func.initialize_py_trampoline(_py_funcs)
|
||||||
|
|
||||||
|
|
||||||
def _internal_py_func(func,
|
def _internal_py_func(func,
|
||||||
|
@ -21,7 +21,6 @@ limitations under the License.
|
|||||||
|
|
||||||
%include "tensorflow/python/pywrap_tfe.i"
|
%include "tensorflow/python/pywrap_tfe.i"
|
||||||
|
|
||||||
%include "tensorflow/python/lib/core/py_func.i"
|
|
||||||
%include "tensorflow/python/lib/core/py_exception_registry.i"
|
%include "tensorflow/python/lib/core/py_exception_registry.i"
|
||||||
|
|
||||||
%include "tensorflow/python/lib/io/py_record_reader.i"
|
%include "tensorflow/python/lib/io/py_record_reader.i"
|
||||||
|
@ -20,7 +20,6 @@ limitations under the License.
|
|||||||
#include "tensorflow/c/checkpoint_reader.h"
|
#include "tensorflow/c/checkpoint_reader.h"
|
||||||
#include "tensorflow/core/lib/core/status.h"
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
#include "tensorflow/python/lib/core/ndarray_tensor.h"
|
#include "tensorflow/python/lib/core/ndarray_tensor.h"
|
||||||
#include "tensorflow/python/lib/core/py_func.h"
|
|
||||||
#include "tensorflow/python/lib/core/safe_ptr.h"
|
#include "tensorflow/python/lib/core/safe_ptr.h"
|
||||||
%}
|
%}
|
||||||
|
|
||||||
|
@ -20,7 +20,6 @@ limitations under the License.
|
|||||||
%{
|
%{
|
||||||
#include "tensorflow/core/lib/core/status.h"
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
#include "tensorflow/core/util/stat_summarizer.h"
|
#include "tensorflow/core/util/stat_summarizer.h"
|
||||||
#include "tensorflow/python/lib/core/py_func.h"
|
|
||||||
|
|
||||||
#include "tensorflow/core/framework/graph.pb.h"
|
#include "tensorflow/core/framework/graph.pb.h"
|
||||||
#include "tensorflow/core/framework/step_stats.pb.h"
|
#include "tensorflow/core/framework/step_stats.pb.h"
|
||||||
|
@ -51,3 +51,6 @@ tensorflow::EventsWriter::InitWithSuffix
|
|||||||
tensorflow::EventsWriter::WriteSerializedEvent
|
tensorflow::EventsWriter::WriteSerializedEvent
|
||||||
tensorflow::EventsWriter::Flush
|
tensorflow::EventsWriter::Flush
|
||||||
tensorflow::EventsWriter::Close
|
tensorflow::EventsWriter::Close
|
||||||
|
|
||||||
|
[py_func_lib] # py_func
|
||||||
|
tensorflow::InitializePyTrampoline
|
||||||
|
Loading…
x
Reference in New Issue
Block a user