Converted py_func.i to pybind11
This is part of a larger effort to deprecate swig and eventually with modularization break pywrap_tensorflow into smaller components. Please refer to https://github.com/tensorflow/community/blob/master/rfcs/20190208-pybind11.md for more information. PiperOrigin-RevId: 271409413
This commit is contained in:
parent
5ac1558cdb
commit
afa2418f90
tensorflow
@ -19,7 +19,7 @@ visibility = [
|
||||
"//bazel_pip/tensorflow/lite/toco/python:__pkg__",
|
||||
]
|
||||
|
||||
load("//tensorflow:tensorflow.bzl", "if_mlir", "if_not_v2", "if_not_windows", "py_test", "py_tests", "tf_cc_shared_object", "tf_cuda_library", "tf_gen_op_wrapper_py", "tf_py_build_info_genrule", "tf_py_test")
|
||||
load("//tensorflow:tensorflow.bzl", "if_mlir", "if_not_windows", "py_test", "py_tests", "tf_cc_shared_object", "tf_cuda_library", "tf_gen_op_wrapper_py", "tf_py_build_info_genrule", "tf_py_test", "cc_header_only_library")
|
||||
load("//tensorflow:tensorflow.bzl", "tf_python_pybind_extension")
|
||||
load("//tensorflow:tensorflow.bzl", "pybind_extension")
|
||||
load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc")
|
||||
@ -518,6 +518,24 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_header_only_library(
|
||||
name = "py_func_headers_lib",
|
||||
deps = [
|
||||
":py_func_lib",
|
||||
],
|
||||
)
|
||||
|
||||
tf_python_pybind_extension(
|
||||
name = "_pywrap_py_func",
|
||||
srcs = ["lib/core/py_func_wrapper.cc"],
|
||||
module_name = "_pywrap_py_func",
|
||||
deps = [
|
||||
":py_func_headers_lib",
|
||||
"//third_party/python_runtime:headers",
|
||||
"@pybind11",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "safe_ptr",
|
||||
srcs = ["lib/core/safe_ptr.cc"],
|
||||
@ -3639,6 +3657,7 @@ py_library(
|
||||
srcs = ["ops/script_ops.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":_pywrap_py_func",
|
||||
":array_ops",
|
||||
":framework_for_generated_wrappers",
|
||||
":script_ops_gen",
|
||||
@ -5008,7 +5027,6 @@ tf_py_wrap_cc(
|
||||
"grappler/tf_optimizer.i",
|
||||
"lib/core/bfloat16.i",
|
||||
"lib/core/py_exception_registry.i",
|
||||
"lib/core/py_func.i",
|
||||
"lib/core/strings.i",
|
||||
"lib/io/file_io.i",
|
||||
"lib/io/py_record_reader.i",
|
||||
@ -5117,6 +5135,7 @@ genrule(
|
||||
name = "pybind_symbol_target_libs_file",
|
||||
srcs = [
|
||||
":cpp_python_util", # util
|
||||
":py_func_lib", # py_func
|
||||
"//tensorflow/core:util_port", # util_port
|
||||
"//tensorflow/stream_executor:stream_executor_pimpl", # stat_summarizer
|
||||
"//tensorflow/core/grappler/graph_analyzer:graph_analyzer_tool", # graph_analyzer
|
||||
|
@ -17,6 +17,11 @@ limitations under the License.
|
||||
|
||||
#include <Python.h>
|
||||
|
||||
// clang-format: off
|
||||
// Must be inlcluded first.
|
||||
#include "tensorflow/python/lib/core/numpy.h"
|
||||
// clang-format: on
|
||||
|
||||
#include <array>
|
||||
|
||||
#include "numpy/arrayobject.h"
|
||||
@ -25,7 +30,9 @@ limitations under the License.
|
||||
#include "tensorflow/c/tf_status_helper.h"
|
||||
#include "tensorflow/core/framework/allocation_description.pb.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/lib/core/threadpool.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
|
@ -16,11 +16,7 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_PYTHON_LIB_CORE_PY_FUNC_H_
|
||||
#define TENSORFLOW_PYTHON_LIB_CORE_PY_FUNC_H_
|
||||
|
||||
// Must be included first
|
||||
#include "tensorflow/python/lib/core/numpy.h"
|
||||
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include <Python.h>
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
|
@ -1,4 +1,4 @@
|
||||
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@ -13,17 +13,13 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
%include "tensorflow/python/platform/base.i"
|
||||
|
||||
%{
|
||||
#include "include/pybind11/pybind11.h"
|
||||
#include "tensorflow/python/lib/core/py_func.h"
|
||||
%}
|
||||
|
||||
%ignoreall
|
||||
namespace py = pybind11;
|
||||
|
||||
%unignore tensorflow;
|
||||
%unignore tensorflow::InitializePyTrampoline;
|
||||
|
||||
%include "tensorflow/python/lib/core/py_func.h"
|
||||
|
||||
%unignoreall
|
||||
PYBIND11_MODULE(_pywrap_py_func, m) {
|
||||
m.def("initialize_py_trampoline", [](py::object trampoline) {
|
||||
return tensorflow::InitializePyTrampoline(trampoline.ptr());
|
||||
});
|
||||
}
|
@ -28,7 +28,7 @@ import weakref
|
||||
import numpy as np
|
||||
import six
|
||||
|
||||
from tensorflow.python import pywrap_tensorflow
|
||||
from tensorflow.python import _pywrap_py_func
|
||||
from tensorflow.python.eager import backprop
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import constant_op
|
||||
@ -259,7 +259,7 @@ class FuncRegistry(object):
|
||||
# Global registry for py functions.
|
||||
_py_funcs = FuncRegistry()
|
||||
|
||||
pywrap_tensorflow.InitializePyTrampoline(_py_funcs)
|
||||
_pywrap_py_func.initialize_py_trampoline(_py_funcs)
|
||||
|
||||
|
||||
def _internal_py_func(func,
|
||||
|
@ -21,7 +21,6 @@ limitations under the License.
|
||||
|
||||
%include "tensorflow/python/pywrap_tfe.i"
|
||||
|
||||
%include "tensorflow/python/lib/core/py_func.i"
|
||||
%include "tensorflow/python/lib/core/py_exception_registry.i"
|
||||
|
||||
%include "tensorflow/python/lib/io/py_record_reader.i"
|
||||
|
@ -20,7 +20,6 @@ limitations under the License.
|
||||
#include "tensorflow/c/checkpoint_reader.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/python/lib/core/ndarray_tensor.h"
|
||||
#include "tensorflow/python/lib/core/py_func.h"
|
||||
#include "tensorflow/python/lib/core/safe_ptr.h"
|
||||
%}
|
||||
|
||||
|
@ -20,7 +20,6 @@ limitations under the License.
|
||||
%{
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/util/stat_summarizer.h"
|
||||
#include "tensorflow/python/lib/core/py_func.h"
|
||||
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
#include "tensorflow/core/framework/step_stats.pb.h"
|
||||
|
@ -51,3 +51,6 @@ tensorflow::EventsWriter::InitWithSuffix
|
||||
tensorflow::EventsWriter::WriteSerializedEvent
|
||||
tensorflow::EventsWriter::Flush
|
||||
tensorflow::EventsWriter::Close
|
||||
|
||||
[py_func_lib] # py_func
|
||||
tensorflow::InitializePyTrampoline
|
||||
|
Loading…
Reference in New Issue
Block a user