Export the Eager classes and functions from C++ to Python with pybind11 instead of swig. This is part of a larger effort to deprecate swig and eventually with modularization break pywrap_tensorflow into smaller components. It will also make exporting C++ ops to Python significantly easier. XLA is using the pybind11 macros already. Please refer to https://github.com/tensorflow/community/blob/master/rfcs/20190208-pybind11.md for more information.
PiperOrigin-RevId: 286110711 Change-Id: I7bf6f6f4ce1d6bf3e8a3e40ef4a83f82333800f6
This commit is contained in:
parent
03341c4342
commit
7bd345bcbb
@ -53,6 +53,20 @@ filegroup(
|
|||||||
visibility = ["//visibility:public"],
|
visibility = ["//visibility:public"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
filegroup(
|
||||||
|
name = "pywrap_eager_hdrs",
|
||||||
|
srcs = [
|
||||||
|
"c_api_internal.h",
|
||||||
|
"tf_status_helper.h",
|
||||||
|
"tf_status_internal.h",
|
||||||
|
"tf_tensor_internal.h",
|
||||||
|
],
|
||||||
|
visibility = [
|
||||||
|
"//tensorflow/core:__pkg__",
|
||||||
|
"//tensorflow/python:__pkg__",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
tf_cuda_library(
|
tf_cuda_library(
|
||||||
name = "c_api_internal",
|
name = "c_api_internal",
|
||||||
hdrs = [
|
hdrs = [
|
||||||
|
@ -88,6 +88,18 @@ tf_cuda_library(
|
|||||||
alwayslink = 1,
|
alwayslink = 1,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
filegroup(
|
||||||
|
name = "pywrap_eager_hdrs",
|
||||||
|
srcs = [
|
||||||
|
"c_api_experimental.h",
|
||||||
|
"c_api_internal.h",
|
||||||
|
],
|
||||||
|
visibility = [
|
||||||
|
"//tensorflow/core:__pkg__",
|
||||||
|
"//tensorflow/python:__pkg__",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
tf_cuda_library(
|
tf_cuda_library(
|
||||||
name = "c_api_internal",
|
name = "c_api_internal",
|
||||||
srcs = ["c_api_experimental.h"],
|
srcs = ["c_api_experimental.h"],
|
||||||
|
@ -439,6 +439,23 @@ tf_cc_test(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
filegroup(
|
||||||
|
name = "pywrap_eager_hdrs",
|
||||||
|
srcs = [
|
||||||
|
"attr_builder.h",
|
||||||
|
"context.h",
|
||||||
|
"eager_executor.h",
|
||||||
|
"eager_operation.h",
|
||||||
|
"kernel_and_device.h",
|
||||||
|
"tensor_handle.h",
|
||||||
|
"tensor_handle_data.h",
|
||||||
|
],
|
||||||
|
visibility = [
|
||||||
|
"//tensorflow/core:__pkg__",
|
||||||
|
"//tensorflow/python:__pkg__",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
filegroup(
|
filegroup(
|
||||||
name = "srcs",
|
name = "srcs",
|
||||||
srcs = glob(
|
srcs = glob(
|
||||||
|
@ -783,3 +783,20 @@ tf_cc_test(
|
|||||||
"//tensorflow/core:worker_proto_cc",
|
"//tensorflow/core:worker_proto_cc",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
filegroup(
|
||||||
|
name = "pywrap_eager_hdrs",
|
||||||
|
srcs = [
|
||||||
|
"call_options.h",
|
||||||
|
"message_wrappers.h",
|
||||||
|
"rendezvous_mgr_interface.h",
|
||||||
|
"server_lib.h",
|
||||||
|
"worker_cache.h",
|
||||||
|
"worker_env.h",
|
||||||
|
"worker_interface.h",
|
||||||
|
],
|
||||||
|
visibility = [
|
||||||
|
"//tensorflow/core:__pkg__",
|
||||||
|
"//tensorflow/python:__pkg__",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
@ -216,3 +216,16 @@ cc_library(
|
|||||||
"@com_google_absl//absl/types:optional",
|
"@com_google_absl//absl/types:optional",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
filegroup(
|
||||||
|
name = "pywrap_eager_hdrs",
|
||||||
|
srcs = [
|
||||||
|
"eager_client.h",
|
||||||
|
"remote_tensor_handle.h",
|
||||||
|
"remote_tensor_handle_data.h",
|
||||||
|
],
|
||||||
|
visibility = [
|
||||||
|
"//tensorflow/core:__pkg__",
|
||||||
|
"//tensorflow/python:__pkg__",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
@ -853,6 +853,18 @@ tf_cc_tests(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
filegroup(
|
||||||
|
name = "pywrap_eager_hdrs",
|
||||||
|
srcs = [
|
||||||
|
"op_gen_lib.h",
|
||||||
|
"rendezvous.h",
|
||||||
|
],
|
||||||
|
visibility = [
|
||||||
|
"//tensorflow/core:__pkg__",
|
||||||
|
"//tensorflow/python:__pkg__",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
# All framewrok protos are self-contained, i.e. they only import other
|
# All framewrok protos are self-contained, i.e. they only import other
|
||||||
# protos from the same package, so we can build the protos here and then
|
# protos from the same package, so we can build the protos here and then
|
||||||
# link them from core:protos_all without circular dependencies.
|
# link them from core:protos_all without circular dependencies.
|
||||||
|
@ -523,3 +523,14 @@ tf_cc_test(
|
|||||||
"//tensorflow/core:testlib",
|
"//tensorflow/core:testlib",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
filegroup(
|
||||||
|
name = "pywrap_eager_hdrs",
|
||||||
|
srcs = [
|
||||||
|
"profiler_interface.h",
|
||||||
|
],
|
||||||
|
visibility = [
|
||||||
|
"//tensorflow/core:__pkg__",
|
||||||
|
"//tensorflow/python:__pkg__",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
@ -43,6 +43,17 @@ tf_cuda_library(
|
|||||||
alwayslink = True,
|
alwayslink = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
filegroup(
|
||||||
|
name = "pywrap_eager_hdrs",
|
||||||
|
srcs = [
|
||||||
|
"profiler_session.h",
|
||||||
|
],
|
||||||
|
visibility = [
|
||||||
|
"//tensorflow/core:__pkg__",
|
||||||
|
"//tensorflow/python:__pkg__",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "traceme",
|
name = "traceme",
|
||||||
hdrs = ["traceme.h"],
|
hdrs = ["traceme.h"],
|
||||||
|
@ -171,6 +171,7 @@ py_library(
|
|||||||
":platform",
|
":platform",
|
||||||
":proto_ops",
|
":proto_ops",
|
||||||
":pywrap_tensorflow",
|
":pywrap_tensorflow",
|
||||||
|
":pywrap_tfe",
|
||||||
":rnn_ops_gen",
|
":rnn_ops_gen",
|
||||||
":saver_test_utils",
|
":saver_test_utils",
|
||||||
":script_ops",
|
":script_ops",
|
||||||
@ -251,6 +252,7 @@ py_library(
|
|||||||
deps = [
|
deps = [
|
||||||
":_pywrap_util_port",
|
":_pywrap_util_port",
|
||||||
":lib",
|
":lib",
|
||||||
|
":pywrap_tfe",
|
||||||
":util",
|
":util",
|
||||||
"//tensorflow/core:protos_all_py",
|
"//tensorflow/core:protos_all_py",
|
||||||
"@absl_py//absl:app",
|
"@absl_py//absl:app",
|
||||||
@ -477,13 +479,13 @@ cc_library(
|
|||||||
cc_library(
|
cc_library(
|
||||||
name = "pybind11_status",
|
name = "pybind11_status",
|
||||||
hdrs = [
|
hdrs = [
|
||||||
|
"lib/core/py_exception_registry.h",
|
||||||
"lib/core/pybind11_status.h",
|
"lib/core/pybind11_status.h",
|
||||||
"//tensorflow/c:headers",
|
"//tensorflow/c:headers",
|
||||||
],
|
],
|
||||||
features = ["-parse_headers"],
|
features = ["-parse_headers"],
|
||||||
visibility = tf_external_workspace_visible(visibility),
|
visibility = tf_external_workspace_visible(visibility),
|
||||||
deps = [
|
deps = [
|
||||||
":py_exception_registry",
|
|
||||||
"//tensorflow/c:tf_status_headers",
|
"//tensorflow/c:tf_status_headers",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core:protos_all_cc",
|
"//tensorflow/core:protos_all_cc",
|
||||||
@ -1110,6 +1112,7 @@ py_library(
|
|||||||
":lib",
|
":lib",
|
||||||
":platform",
|
":platform",
|
||||||
":pywrap_tensorflow",
|
":pywrap_tensorflow",
|
||||||
|
":pywrap_tfe",
|
||||||
":random_seed",
|
":random_seed",
|
||||||
":sparse_tensor",
|
":sparse_tensor",
|
||||||
":tensor_spec",
|
":tensor_spec",
|
||||||
@ -5492,7 +5495,6 @@ tf_py_wrap_cc(
|
|||||||
"lib/io/py_record_reader.i",
|
"lib/io/py_record_reader.i",
|
||||||
"lib/io/py_record_writer.i",
|
"lib/io/py_record_writer.i",
|
||||||
"platform/base.i",
|
"platform/base.i",
|
||||||
"pywrap_tfe.i",
|
|
||||||
"//tensorflow/compiler/mlir/python:mlir.i",
|
"//tensorflow/compiler/mlir/python:mlir.i",
|
||||||
],
|
],
|
||||||
# add win_def_file for pywrap_tensorflow
|
# add win_def_file for pywrap_tensorflow
|
||||||
@ -5573,7 +5575,12 @@ WIN_LIB_FILES_FOR_EXPORTED_SYMBOLS = [
|
|||||||
":safe_ptr", # checkpoint_reader
|
":safe_ptr", # checkpoint_reader
|
||||||
":python_op_gen", # python_op_gen
|
":python_op_gen", # python_op_gen
|
||||||
":bfloat16_lib", # bfloat16
|
":bfloat16_lib", # bfloat16
|
||||||
|
"//tensorflow/python/eager:pywrap_tfe_lib", # pywrap_tfe_lib
|
||||||
"//tensorflow/core/util/tensor_bundle", # checkpoint_reader
|
"//tensorflow/core/util/tensor_bundle", # checkpoint_reader
|
||||||
|
"//tensorflow/core/common_runtime/eager:eager_executor", # tfe
|
||||||
|
"//tensorflow/core/common_runtime/eager:context", # tfe
|
||||||
|
"//tensorflow/core/profiler/lib:profiler_session", # tfe
|
||||||
|
"//tensorflow/c:tf_status_helper", # tfe
|
||||||
]
|
]
|
||||||
|
|
||||||
# Filter the DEF file to reduce the number of symbols to 64K or less.
|
# Filter the DEF file to reduce the number of symbols to 64K or less.
|
||||||
@ -7555,6 +7562,67 @@ py_library(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "pywrap_tfe",
|
||||||
|
srcs = ["pywrap_tfe.py"],
|
||||||
|
visibility = ["//visibility:public"],
|
||||||
|
deps = [
|
||||||
|
":_pywrap_tfe",
|
||||||
|
":pywrap_tensorflow",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
tf_python_pybind_extension(
|
||||||
|
name = "_pywrap_tfe",
|
||||||
|
srcs = ["tfe_wrapper.cc"],
|
||||||
|
hdrs = [
|
||||||
|
"lib/core/safe_ptr.h",
|
||||||
|
"util/util.h",
|
||||||
|
":py_exception_registry_hdr",
|
||||||
|
"//tensorflow/c:headers",
|
||||||
|
"//tensorflow/c:pywrap_eager_hdrs",
|
||||||
|
"//tensorflow/c/eager:headers",
|
||||||
|
"//tensorflow/c/eager:pywrap_eager_hdrs",
|
||||||
|
"//tensorflow/core/common_runtime/eager:pywrap_eager_hdrs",
|
||||||
|
"//tensorflow/core/distributed_runtime:pywrap_eager_hdrs",
|
||||||
|
"//tensorflow/core/distributed_runtime/eager:pywrap_eager_hdrs",
|
||||||
|
"//tensorflow/core/framework:pywrap_eager_hdrs",
|
||||||
|
"//tensorflow/core/profiler/internal:pywrap_eager_hdrs",
|
||||||
|
"//tensorflow/core/profiler/lib:pywrap_eager_hdrs",
|
||||||
|
"//tensorflow/python/eager:pywrap_eager_hdrs",
|
||||||
|
],
|
||||||
|
module_name = "_pywrap_tfe",
|
||||||
|
deps = [
|
||||||
|
":pybind11_lib",
|
||||||
|
":pybind11_status",
|
||||||
|
"@com_google_absl//absl/container:flat_hash_map",
|
||||||
|
"@com_google_absl//absl/hash",
|
||||||
|
"@com_google_absl//absl/memory",
|
||||||
|
"@com_google_absl//absl/strings",
|
||||||
|
"@com_google_absl//absl/types:optional",
|
||||||
|
"@pybind11",
|
||||||
|
"//third_party/python_runtime:headers",
|
||||||
|
"//tensorflow/core:core_cpu_headers_lib",
|
||||||
|
"//tensorflow/core:framework",
|
||||||
|
"//tensorflow/core:lib",
|
||||||
|
"//tensorflow/core:lib_internal",
|
||||||
|
"//tensorflow/core:protos_all_cc",
|
||||||
|
"//tensorflow/core/platform:platform",
|
||||||
|
"//tensorflow/core/profiler/protobuf:xplane_proto_cc",
|
||||||
|
] + if_static(
|
||||||
|
extra_deps = [
|
||||||
|
"//tensorflow/core:eager_service_proto_cc",
|
||||||
|
"//tensorflow/core:master_proto_cc",
|
||||||
|
"//tensorflow/core:worker_proto_cc",
|
||||||
|
],
|
||||||
|
otherwise = [
|
||||||
|
"//tensorflow/core:eager_service_proto_cc_headers_only",
|
||||||
|
"//tensorflow/core:master_proto_cc_headers_only",
|
||||||
|
"//tensorflow/core:worker_proto_cc_headers_only",
|
||||||
|
],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
tf_python_pybind_extension(
|
tf_python_pybind_extension(
|
||||||
name = "_pywrap_graph_analyzer",
|
name = "_pywrap_graph_analyzer",
|
||||||
srcs = ["grappler/graph_analyzer_tool_wrapper.cc"],
|
srcs = ["grappler/graph_analyzer_tool_wrapper.cc"],
|
||||||
|
@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
|
%include "tensorflow/python/lib/core/strings.i"
|
||||||
%include "tensorflow/python/platform/base.i"
|
%include "tensorflow/python/platform/base.i"
|
||||||
|
|
||||||
%{
|
%{
|
||||||
@ -23,6 +24,13 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/lib/strings/stringprintf.h"
|
#include "tensorflow/core/lib/strings/stringprintf.h"
|
||||||
#include "tensorflow/core/public/version.h"
|
#include "tensorflow/core/public/version.h"
|
||||||
#include "tensorflow/python/client/tf_session_helper.h"
|
#include "tensorflow/python/client/tf_session_helper.h"
|
||||||
|
#include "tensorflow/c/c_api_experimental.h"
|
||||||
|
#include "tensorflow/python/lib/core/safe_ptr.h"
|
||||||
|
#include "tensorflow/python/eager/pywrap_tfe.h"
|
||||||
|
// We were getting lucky on imports with safe_ptr.h being placed prior to
|
||||||
|
// tf_session which imported safe_ptr. We also need pywrap_tfe.h to cast
|
||||||
|
// one of the inputs to a graph function from a Python string to const char*.
|
||||||
|
|
||||||
|
|
||||||
// Helper function to convert a Python list of Tensors to a C++ vector of
|
// Helper function to convert a Python list of Tensors to a C++ vector of
|
||||||
// TF_Outputs.
|
// TF_Outputs.
|
||||||
@ -78,6 +86,9 @@ void PyInt64ListToVector(PyObject* py_int_seq, std::vector<int64_t>* vec) {
|
|||||||
|
|
||||||
%}
|
%}
|
||||||
|
|
||||||
|
%include "tensorflow/c/tf_datatype.h"
|
||||||
|
%include "tensorflow/c/tf_status.h"
|
||||||
|
|
||||||
%include "tensorflow/python/client/tf_sessionrun_wrapper.i"
|
%include "tensorflow/python/client/tf_sessionrun_wrapper.i"
|
||||||
|
|
||||||
// Required to use PyArray_* functions.
|
// Required to use PyArray_* functions.
|
||||||
@ -85,6 +96,14 @@ void PyInt64ListToVector(PyObject* py_int_seq, std::vector<int64_t>* vec) {
|
|||||||
tensorflow::ImportNumpy();
|
tensorflow::ImportNumpy();
|
||||||
%}
|
%}
|
||||||
|
|
||||||
|
// For const parameters in a function, SWIG pretty much ignores the const.
|
||||||
|
// See: http://www.swig.org/Doc2.0/SWIG.html#SWIG_nn13
|
||||||
|
// Hence the 'const_cast'.
|
||||||
|
%typemap(in) const char* op_name {
|
||||||
|
$1 = const_cast<char*>(TFE_GetPythonString($input));
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
// TensorFlow version and GraphDef versions
|
// TensorFlow version and GraphDef versions
|
||||||
%constant const char* __version__ = TF_VERSION_STRING;
|
%constant const char* __version__ = TF_VERSION_STRING;
|
||||||
%constant int GRAPH_DEF_VERSION = TF_GRAPH_DEF_VERSION;
|
%constant int GRAPH_DEF_VERSION = TF_GRAPH_DEF_VERSION;
|
||||||
@ -174,6 +193,12 @@ tensorflow::ImportNumpy();
|
|||||||
// See comment for "%noexception TF_SessionRun_wrapper;"
|
// See comment for "%noexception TF_SessionRun_wrapper;"
|
||||||
%noexception TF_OperationGetControlInputs_wrapper;
|
%noexception TF_OperationGetControlInputs_wrapper;
|
||||||
|
|
||||||
|
|
||||||
|
// Migrate one function from pywrap_tfe.i
|
||||||
|
%include "tensorflow/c/c_api_experimental.h"
|
||||||
|
%unignore TF_ImportGraphDefOptionsSetValidateColocationConstraints;
|
||||||
|
%noexception TF_ImportGraphDefOptionsSetValidateColocationConstraints;
|
||||||
|
|
||||||
// Build a Python list of TF_Operation* and return it.
|
// Build a Python list of TF_Operation* and return it.
|
||||||
%typemap(out) std::vector<TF_Operation*> tensorflow::TF_OperationGetControlInputs_wrapper {
|
%typemap(out) std::vector<TF_Operation*> tensorflow::TF_OperationGetControlInputs_wrapper {
|
||||||
$result = PyList_New($1.size());
|
$result = PyList_New($1.size());
|
||||||
|
@ -268,7 +268,7 @@ py_library(
|
|||||||
"//tensorflow/python:device",
|
"//tensorflow/python:device",
|
||||||
"//tensorflow/python:dtypes",
|
"//tensorflow/python:dtypes",
|
||||||
"//tensorflow/python:framework_ops",
|
"//tensorflow/python:framework_ops",
|
||||||
"//tensorflow/python:pywrap_tensorflow",
|
"//tensorflow/python:pywrap_tfe",
|
||||||
"//tensorflow/python:summary_ops_v2",
|
"//tensorflow/python:summary_ops_v2",
|
||||||
"//tensorflow/python:tensor_util",
|
"//tensorflow/python:tensor_util",
|
||||||
"//tensorflow/python:training",
|
"//tensorflow/python:training",
|
||||||
|
@ -24,7 +24,7 @@ import functools
|
|||||||
import threading
|
import threading
|
||||||
import weakref
|
import weakref
|
||||||
|
|
||||||
from tensorflow.python import pywrap_tensorflow
|
from tensorflow.python import pywrap_tfe
|
||||||
from tensorflow.python.autograph.core import ag_ctx
|
from tensorflow.python.autograph.core import ag_ctx
|
||||||
from tensorflow.python.autograph.impl import api as autograph
|
from tensorflow.python.autograph.impl import api as autograph
|
||||||
from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
|
from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
|
||||||
@ -944,7 +944,7 @@ class _MirroredReplicaThread(threading.Thread):
|
|||||||
self.record_thread_local_summary_state()
|
self.record_thread_local_summary_state()
|
||||||
self.record_thread_local_eager_context_state()
|
self.record_thread_local_eager_context_state()
|
||||||
self.context_device_policy = (
|
self.context_device_policy = (
|
||||||
pywrap_tensorflow.TFE_ContextGetDevicePlacementPolicy(
|
pywrap_tfe.TFE_ContextGetDevicePlacementPolicy(
|
||||||
ctx._context_handle)) # pylint: disable=protected-access
|
ctx._context_handle)) # pylint: disable=protected-access
|
||||||
self.graph = ops.get_default_graph()
|
self.graph = ops.get_default_graph()
|
||||||
with ops.init_scope():
|
with ops.init_scope():
|
||||||
|
@ -56,6 +56,18 @@ cc_library(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
filegroup(
|
||||||
|
name = "pywrap_eager_hdrs",
|
||||||
|
srcs = [
|
||||||
|
"pywrap_tensor_conversion.h",
|
||||||
|
"pywrap_tfe.h",
|
||||||
|
],
|
||||||
|
visibility = [
|
||||||
|
"//tensorflow/core:__pkg__",
|
||||||
|
"//tensorflow/python:__pkg__",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
# Transitive dependencies of this target will be included in the pip package.
|
# Transitive dependencies of this target will be included in the pip package.
|
||||||
py_library(
|
py_library(
|
||||||
name = "eager_pip",
|
name = "eager_pip",
|
||||||
@ -90,7 +102,7 @@ py_library(
|
|||||||
deps = [
|
deps = [
|
||||||
":context",
|
":context",
|
||||||
"//tensorflow/python:errors",
|
"//tensorflow/python:errors",
|
||||||
"//tensorflow/python:pywrap_tensorflow",
|
"//tensorflow/python:pywrap_tfe",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -100,7 +112,7 @@ py_library(
|
|||||||
srcs_version = "PY2AND3",
|
srcs_version = "PY2AND3",
|
||||||
visibility = ["//tensorflow:internal"],
|
visibility = ["//tensorflow:internal"],
|
||||||
deps = [
|
deps = [
|
||||||
"//tensorflow/python:pywrap_tensorflow",
|
"//tensorflow/python:pywrap_tfe",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -121,7 +133,7 @@ py_library(
|
|||||||
srcs_version = "PY2AND3",
|
srcs_version = "PY2AND3",
|
||||||
visibility = ["//tensorflow:internal"],
|
visibility = ["//tensorflow:internal"],
|
||||||
deps = [
|
deps = [
|
||||||
"//tensorflow/python:pywrap_tensorflow",
|
"//tensorflow/python:pywrap_tfe",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -131,13 +143,14 @@ py_library(
|
|||||||
srcs_version = "PY2AND3",
|
srcs_version = "PY2AND3",
|
||||||
visibility = ["//tensorflow:internal"],
|
visibility = ["//tensorflow:internal"],
|
||||||
deps = [
|
deps = [
|
||||||
|
":eager_util",
|
||||||
":executor",
|
":executor",
|
||||||
":monitoring",
|
":monitoring",
|
||||||
"//tensorflow/python:device",
|
"//tensorflow/python:device",
|
||||||
"//tensorflow/python:device_spec",
|
"//tensorflow/python:device_spec",
|
||||||
"//tensorflow/python:errors",
|
"//tensorflow/python:errors",
|
||||||
"//tensorflow/python:platform",
|
"//tensorflow/python:platform",
|
||||||
"//tensorflow/python:pywrap_tensorflow",
|
"//tensorflow/python:pywrap_tfe",
|
||||||
"//tensorflow/python:tf2",
|
"//tensorflow/python:tf2",
|
||||||
"//tensorflow/python:util",
|
"//tensorflow/python:util",
|
||||||
"//third_party/py/numpy",
|
"//third_party/py/numpy",
|
||||||
@ -164,8 +177,8 @@ py_library(
|
|||||||
"//third_party/py/tf_agents:__subpackages__",
|
"//third_party/py/tf_agents:__subpackages__",
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
"//tensorflow/python:c_api_util",
|
":eager_util",
|
||||||
"//tensorflow/python:pywrap_tensorflow",
|
"//tensorflow/python:pywrap_tfe",
|
||||||
"//tensorflow/python:util",
|
"//tensorflow/python:util",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@ -187,7 +200,8 @@ py_library(
|
|||||||
visibility = ["//tensorflow:internal"],
|
visibility = ["//tensorflow:internal"],
|
||||||
deps = [
|
deps = [
|
||||||
":context",
|
":context",
|
||||||
"//tensorflow/python:pywrap_tensorflow",
|
":eager_util",
|
||||||
|
"//tensorflow/python:pywrap_tfe",
|
||||||
"//tensorflow/python:util",
|
"//tensorflow/python:util",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@ -209,7 +223,8 @@ py_library(
|
|||||||
srcs_version = "PY2AND3",
|
srcs_version = "PY2AND3",
|
||||||
visibility = ["//tensorflow:internal"],
|
visibility = ["//tensorflow:internal"],
|
||||||
deps = [
|
deps = [
|
||||||
"//tensorflow/python:pywrap_tensorflow",
|
":eager_util",
|
||||||
|
"//tensorflow/python:pywrap_tfe",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -298,7 +313,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:errors",
|
"//tensorflow/python:errors",
|
||||||
"//tensorflow/python:framework_ops",
|
"//tensorflow/python:framework_ops",
|
||||||
"//tensorflow/python:framework_test_lib",
|
"//tensorflow/python:framework_test_lib",
|
||||||
"//tensorflow/python:pywrap_tensorflow",
|
"//tensorflow/python:pywrap_tfe",
|
||||||
"//third_party/py/numpy",
|
"//third_party/py/numpy",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@ -410,7 +425,7 @@ py_library(
|
|||||||
"//tensorflow/core:protos_all_py",
|
"//tensorflow/core:protos_all_py",
|
||||||
"//tensorflow/python:dtypes",
|
"//tensorflow/python:dtypes",
|
||||||
"//tensorflow/python:lib",
|
"//tensorflow/python:lib",
|
||||||
"//tensorflow/python:pywrap_tensorflow",
|
"//tensorflow/python:pywrap_tfe",
|
||||||
"//tensorflow/python:tensor_shape",
|
"//tensorflow/python:tensor_shape",
|
||||||
"//tensorflow/python:util",
|
"//tensorflow/python:util",
|
||||||
"@six_archive//:six",
|
"@six_archive//:six",
|
||||||
@ -496,7 +511,7 @@ py_library(
|
|||||||
"//tensorflow/python:errors",
|
"//tensorflow/python:errors",
|
||||||
"//tensorflow/python:framework_ops",
|
"//tensorflow/python:framework_ops",
|
||||||
"//tensorflow/python:math_ops",
|
"//tensorflow/python:math_ops",
|
||||||
"//tensorflow/python:pywrap_tensorflow",
|
"//tensorflow/python:pywrap_tfe",
|
||||||
"//tensorflow/python:tensor_shape",
|
"//tensorflow/python:tensor_shape",
|
||||||
"//tensorflow/python:unconnected_gradients",
|
"//tensorflow/python:unconnected_gradients",
|
||||||
"//tensorflow/python:util",
|
"//tensorflow/python:util",
|
||||||
@ -524,7 +539,7 @@ py_library(
|
|||||||
deps = [
|
deps = [
|
||||||
":forwardprop_util",
|
":forwardprop_util",
|
||||||
"//tensorflow/python:platform",
|
"//tensorflow/python:platform",
|
||||||
"//tensorflow/python:pywrap_tensorflow",
|
"//tensorflow/python:pywrap_tfe",
|
||||||
"//tensorflow/python:util",
|
"//tensorflow/python:util",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@ -535,7 +550,18 @@ py_library(
|
|||||||
srcs_version = "PY2AND3",
|
srcs_version = "PY2AND3",
|
||||||
visibility = ["//tensorflow:internal"],
|
visibility = ["//tensorflow:internal"],
|
||||||
deps = [
|
deps = [
|
||||||
"//tensorflow/python:pywrap_tensorflow",
|
"//tensorflow/python:pywrap_tfe",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "eager_util",
|
||||||
|
srcs = ["eager_util.py"],
|
||||||
|
srcs_version = "PY2AND3",
|
||||||
|
visibility = ["//tensorflow:internal"],
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/python:pywrap_tfe",
|
||||||
|
"//tensorflow/python:util",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -552,7 +578,7 @@ cuda_py_test(
|
|||||||
":remote",
|
":remote",
|
||||||
":test",
|
":test",
|
||||||
"//tensorflow/python:math_ops",
|
"//tensorflow/python:math_ops",
|
||||||
"//tensorflow/python:pywrap_tensorflow",
|
"//tensorflow/python:pywrap_tfe",
|
||||||
"//tensorflow/python:random_ops",
|
"//tensorflow/python:random_ops",
|
||||||
"//tensorflow/python/keras",
|
"//tensorflow/python/keras",
|
||||||
"//third_party/py/numpy",
|
"//third_party/py/numpy",
|
||||||
@ -637,7 +663,7 @@ tf_py_test(
|
|||||||
":test",
|
":test",
|
||||||
"//tensorflow/python:framework_test_lib",
|
"//tensorflow/python:framework_test_lib",
|
||||||
"//tensorflow/python:math_ops",
|
"//tensorflow/python:math_ops",
|
||||||
"//tensorflow/python:pywrap_tensorflow",
|
"//tensorflow/python:pywrap_tfe",
|
||||||
"//tensorflow/python:random_ops",
|
"//tensorflow/python:random_ops",
|
||||||
"//tensorflow/python:test_ops",
|
"//tensorflow/python:test_ops",
|
||||||
"//third_party/py/numpy",
|
"//third_party/py/numpy",
|
||||||
@ -649,7 +675,7 @@ py_library(
|
|||||||
srcs = ["imperative_grad.py"],
|
srcs = ["imperative_grad.py"],
|
||||||
srcs_version = "PY2AND3",
|
srcs_version = "PY2AND3",
|
||||||
deps = [
|
deps = [
|
||||||
"//tensorflow/python:pywrap_tensorflow",
|
"//tensorflow/python:pywrap_tfe",
|
||||||
"//tensorflow/python:unconnected_gradients",
|
"//tensorflow/python:unconnected_gradients",
|
||||||
"//tensorflow/python:util",
|
"//tensorflow/python:util",
|
||||||
],
|
],
|
||||||
|
@ -24,7 +24,7 @@ import sys
|
|||||||
|
|
||||||
import six
|
import six
|
||||||
|
|
||||||
from tensorflow.python import pywrap_tensorflow
|
from tensorflow.python import pywrap_tfe
|
||||||
from tensorflow.python import _pywrap_utils
|
from tensorflow.python import _pywrap_utils
|
||||||
from tensorflow.python.eager import backprop_util
|
from tensorflow.python.eager import backprop_util
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
@ -71,19 +71,25 @@ def op_attr_type(op_type, attr_name):
|
|||||||
except KeyError:
|
except KeyError:
|
||||||
context.ensure_initialized()
|
context.ensure_initialized()
|
||||||
h = context.context()._handle # pylint: disable=protected-access
|
h = context.context()._handle # pylint: disable=protected-access
|
||||||
attr_type = pywrap_tensorflow.TFE_OpNameGetAttrType(h, op_type, attr_name)
|
attr_type = pywrap_tfe.TFE_OpNameGetAttrType(h, op_type, attr_name)
|
||||||
_op_attr_type_cache[(op_type, attr_name)] = attr_type
|
_op_attr_type_cache[(op_type, attr_name)] = attr_type
|
||||||
return attr_type
|
return attr_type
|
||||||
|
|
||||||
|
|
||||||
def make_attr(attr_type, value):
|
def make_attr(attr_type, value):
|
||||||
if attr_type == pywrap_tensorflow.TF_ATTR_TYPE:
|
# pybind11 enums do not return the raw value like SWIG enums do. They are
|
||||||
|
# useful when comparing amongst each other but not direct integers as we are
|
||||||
|
# doing in most tests.
|
||||||
|
# https://pybind11.readthedocs.io/en/stable/classes.html#enumerations-and-internal-types
|
||||||
|
# TODO(amitpatankar): After all SWIG transitions, convert the enum comparisons
|
||||||
|
# from integer value to class.
|
||||||
|
if attr_type == int(pywrap_tfe.TF_ATTR_TYPE):
|
||||||
return dtypes.as_dtype(value)
|
return dtypes.as_dtype(value)
|
||||||
elif attr_type == [pywrap_tensorflow.TF_ATTR_TYPE]:
|
elif attr_type == [int(pywrap_tfe.TF_ATTR_TYPE)]:
|
||||||
return [dtypes.as_dtype(v) for v in value]
|
return [dtypes.as_dtype(v) for v in value]
|
||||||
elif attr_type == pywrap_tensorflow.TF_ATTR_SHAPE:
|
elif attr_type == int(pywrap_tfe.TF_ATTR_SHAPE):
|
||||||
return tensor_shape.as_shape(value).as_proto()
|
return tensor_shape.as_shape(value).as_proto()
|
||||||
elif attr_type == [pywrap_tensorflow.TF_ATTR_SHAPE]:
|
elif attr_type == [int(pywrap_tfe.TF_ATTR_SHAPE)]:
|
||||||
return [tensor_shape.as_shape(v).as_proto() for v in value]
|
return [tensor_shape.as_shape(v).as_proto() for v in value]
|
||||||
elif isinstance(value, str):
|
elif isinstance(value, str):
|
||||||
return value.encode()
|
return value.encode()
|
||||||
@ -141,16 +147,15 @@ def _gradient_function(op_name, attr_tuple, num_inputs, inputs, outputs,
|
|||||||
return grad_fn(mock_op, *out_grads)
|
return grad_fn(mock_op, *out_grads)
|
||||||
|
|
||||||
|
|
||||||
pywrap_tensorflow.TFE_Py_RegisterGradientFunction(_gradient_function)
|
pywrap_tfe.TFE_Py_RegisterGradientFunction(_gradient_function)
|
||||||
|
|
||||||
|
|
||||||
def _must_record_gradient():
|
def _must_record_gradient():
|
||||||
return not pywrap_tensorflow.TFE_Py_TapeSetIsEmpty()
|
return not pywrap_tfe.TFE_Py_TapeSetIsEmpty()
|
||||||
|
|
||||||
|
|
||||||
def _record_gradient(op_name, inputs, attrs, results):
|
def _record_gradient(op_name, inputs, attrs, results):
|
||||||
return pywrap_tensorflow.TFE_Py_RecordGradient(op_name, inputs, attrs,
|
return pywrap_tfe.TFE_Py_RecordGradient(op_name, inputs, attrs, results)
|
||||||
results)
|
|
||||||
|
|
||||||
|
|
||||||
execute.must_record_gradient = _must_record_gradient
|
execute.must_record_gradient = _must_record_gradient
|
||||||
@ -688,7 +693,7 @@ _default_vspace = imperative_grad.VSpace(
|
|||||||
zeros_like_fn=default_gradient.zeros_like,
|
zeros_like_fn=default_gradient.zeros_like,
|
||||||
ones_like_fn=default_gradient.ones_like,
|
ones_like_fn=default_gradient.ones_like,
|
||||||
graph_shape_fn=gen_array_ops.shape)
|
graph_shape_fn=gen_array_ops.shape)
|
||||||
pywrap_tensorflow.TFE_Py_RegisterVSpace(_default_vspace)
|
pywrap_tfe.TFE_Py_RegisterVSpace(_default_vspace)
|
||||||
|
|
||||||
|
|
||||||
def _handle_or_self(x):
|
def _handle_or_self(x):
|
||||||
|
@ -21,7 +21,7 @@ import functools
|
|||||||
from absl.testing import parameterized
|
from absl.testing import parameterized
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from tensorflow.python import pywrap_tensorflow
|
from tensorflow.python import pywrap_tfe
|
||||||
from tensorflow.python.eager import backprop
|
from tensorflow.python.eager import backprop
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.eager import def_function
|
from tensorflow.python.eager import def_function
|
||||||
@ -1014,19 +1014,19 @@ class BackpropTest(test.TestCase, parameterized.TestCase):
|
|||||||
|
|
||||||
def testGetAttrType(self):
|
def testGetAttrType(self):
|
||||||
typ = backprop.op_attr_type('Add', 'T')
|
typ = backprop.op_attr_type('Add', 'T')
|
||||||
self.assertEqual(typ, pywrap_tensorflow.TF_ATTR_TYPE)
|
self.assertEqual(typ, int(pywrap_tfe.TF_ATTR_TYPE))
|
||||||
|
|
||||||
def testGetAttrList(self):
|
def testGetAttrList(self):
|
||||||
typ = backprop.op_attr_type('MaxPool', 'ksize')
|
typ = backprop.op_attr_type('MaxPool', 'ksize')
|
||||||
self.assertEqual(typ, [pywrap_tensorflow.TF_ATTR_INT])
|
self.assertEqual(typ, [int(pywrap_tfe.TF_ATTR_INT)])
|
||||||
|
|
||||||
def testMakeAttrType(self):
|
def testMakeAttrType(self):
|
||||||
self.assertEqual(dtypes.float32,
|
self.assertEqual(dtypes.float32,
|
||||||
backprop.make_attr(pywrap_tensorflow.TF_ATTR_TYPE, 1))
|
backprop.make_attr(int(pywrap_tfe.TF_ATTR_TYPE), 1))
|
||||||
|
|
||||||
def testMakeAttrTypeList(self):
|
def testMakeAttrTypeList(self):
|
||||||
self.assertEqual([dtypes.float32],
|
self.assertEqual([dtypes.float32],
|
||||||
backprop.make_attr([pywrap_tensorflow.TF_ATTR_TYPE], [1]))
|
backprop.make_attr([int(pywrap_tfe.TF_ATTR_TYPE)], [1]))
|
||||||
|
|
||||||
def testMulType(self):
|
def testMulType(self):
|
||||||
|
|
||||||
@ -1040,7 +1040,7 @@ class BackpropTest(test.TestCase, parameterized.TestCase):
|
|||||||
def testMakeAttrShape(self):
|
def testMakeAttrShape(self):
|
||||||
for s in ([], None, [1, 2, 3], [None, None], [1, None, 3]):
|
for s in ([], None, [1, 2, 3], [None, None], [1, None, 3]):
|
||||||
expected = tensor_shape.TensorShape(s).as_proto()
|
expected = tensor_shape.TensorShape(s).as_proto()
|
||||||
actual = backprop.make_attr(pywrap_tensorflow.TF_ATTR_SHAPE, s)
|
actual = backprop.make_attr(int(pywrap_tfe.TF_ATTR_SHAPE), s)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
expected,
|
expected,
|
||||||
actual,
|
actual,
|
||||||
@ -1051,7 +1051,7 @@ class BackpropTest(test.TestCase, parameterized.TestCase):
|
|||||||
shape_list = [[], None, [1, 2, 3], [None, None], [1, None, 3]]
|
shape_list = [[], None, [1, 2, 3], [None, None], [1, None, 3]]
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
[tensor_shape.TensorShape(s).as_proto() for s in shape_list],
|
[tensor_shape.TensorShape(s).as_proto() for s in shape_list],
|
||||||
backprop.make_attr([pywrap_tensorflow.TF_ATTR_SHAPE], shape_list))
|
backprop.make_attr([int(pywrap_tfe.TF_ATTR_SHAPE)], shape_list))
|
||||||
|
|
||||||
def testArgsGradientFunction(self):
|
def testArgsGradientFunction(self):
|
||||||
|
|
||||||
|
@ -39,7 +39,7 @@ import six
|
|||||||
from six.moves import xrange # pylint: disable=redefined-builtin
|
from six.moves import xrange # pylint: disable=redefined-builtin
|
||||||
|
|
||||||
from tensorflow.python import keras
|
from tensorflow.python import keras
|
||||||
from tensorflow.python import pywrap_tensorflow
|
from tensorflow.python import pywrap_tfe
|
||||||
from tensorflow.python.data.ops import dataset_ops
|
from tensorflow.python.data.ops import dataset_ops
|
||||||
from tensorflow.python.eager import backprop # pylint: disable=unused-import
|
from tensorflow.python.eager import backprop # pylint: disable=unused-import
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
@ -76,10 +76,10 @@ def c_tfe_py_fastpath_execute(a,
|
|||||||
assert ctx.executing_eagerly(
|
assert ctx.executing_eagerly(
|
||||||
), "The prototype doesn't contain C code for graph construction"
|
), "The prototype doesn't contain C code for graph construction"
|
||||||
try:
|
try:
|
||||||
return pywrap_tensorflow.TFE_Py_FastPathExecute(
|
return pywrap_tfe.TFE_Py_FastPathExecute(ctx._handle, ctx.device_name,
|
||||||
ctx._handle, ctx.device_name, "MatMul", name,
|
"MatMul", name, ctx.op_callbacks,
|
||||||
ctx.op_callbacks, a, b, "transpose_a", transpose_a,
|
a, b, "transpose_a", transpose_a,
|
||||||
"transpose_b", transpose_b)
|
"transpose_b", transpose_b)
|
||||||
except core._NotOkStatusException as e:
|
except core._NotOkStatusException as e:
|
||||||
if name is not None:
|
if name is not None:
|
||||||
message = e.message + " name: " + name
|
message = e.message + " name: " + name
|
||||||
@ -339,8 +339,7 @@ class MicroBenchmarks(test.Benchmark):
|
|||||||
inputs = [m]
|
inputs = [m]
|
||||||
|
|
||||||
def f():
|
def f():
|
||||||
pywrap_tensorflow.TFE_Py_Execute(ctx_handle, None, "Identity", inputs,
|
pywrap_tfe.TFE_Py_Execute(ctx_handle, None, "Identity", inputs, attrs, 1)
|
||||||
attrs, 1)
|
|
||||||
|
|
||||||
self._run(f, 30000)
|
self._run(f, 30000)
|
||||||
|
|
||||||
@ -406,8 +405,7 @@ class MicroBenchmarks(test.Benchmark):
|
|||||||
m.dtype.as_datatype_enum)
|
m.dtype.as_datatype_enum)
|
||||||
|
|
||||||
def func():
|
def func():
|
||||||
pywrap_tensorflow.TFE_Py_Execute(ctx_handle, device, "MatMul", inputs,
|
pywrap_tfe.TFE_Py_Execute(ctx_handle, device, "MatMul", inputs, attrs, 1)
|
||||||
attrs, 1)
|
|
||||||
|
|
||||||
self._run(func, num_iters)
|
self._run(func, num_iters)
|
||||||
|
|
||||||
|
@ -18,27 +18,27 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
from tensorflow.python import pywrap_tensorflow
|
from tensorflow.python import pywrap_tfe
|
||||||
|
|
||||||
|
|
||||||
class CancellationManager(object):
|
class CancellationManager(object):
|
||||||
"""A mechanism for cancelling blocking computation."""
|
"""A mechanism for cancelling blocking computation."""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self._impl = pywrap_tensorflow.TFE_NewCancellationManager()
|
self._impl = pywrap_tfe.TFE_NewCancellationManager()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_cancelled(self):
|
def is_cancelled(self):
|
||||||
"""Returns `True` if `CancellationManager.start_cancel` has been called."""
|
"""Returns `True` if `CancellationManager.start_cancel` has been called."""
|
||||||
return pywrap_tensorflow.TFE_CancellationManagerIsCancelled(self._impl)
|
return pywrap_tfe.TFE_CancellationManagerIsCancelled(self._impl)
|
||||||
|
|
||||||
def start_cancel(self):
|
def start_cancel(self):
|
||||||
"""Cancels blocking operations that have been registered with this object."""
|
"""Cancels blocking operations that have been registered with this object."""
|
||||||
pywrap_tensorflow.TFE_CancellationManagerStartCancel(self._impl)
|
pywrap_tfe.TFE_CancellationManagerStartCancel(self._impl)
|
||||||
|
|
||||||
def get_cancelable_function(self, concrete_function):
|
def get_cancelable_function(self, concrete_function):
|
||||||
# pylint: disable=protected-access
|
# pylint: disable=protected-access
|
||||||
return concrete_function._experimental_with_cancellation_manager(self)
|
return concrete_function._experimental_with_cancellation_manager(self)
|
||||||
|
|
||||||
def __del__(self):
|
def __del__(self):
|
||||||
pywrap_tensorflow.TFE_DeleteCancellationManager(self._impl)
|
pywrap_tfe.TFE_DeleteCancellationManager(self._impl)
|
||||||
|
@ -29,11 +29,11 @@ import six
|
|||||||
|
|
||||||
from tensorflow.core.protobuf import config_pb2
|
from tensorflow.core.protobuf import config_pb2
|
||||||
from tensorflow.core.protobuf import rewriter_config_pb2
|
from tensorflow.core.protobuf import rewriter_config_pb2
|
||||||
from tensorflow.python import pywrap_tensorflow
|
from tensorflow.python import pywrap_tfe
|
||||||
from tensorflow.python import tf2
|
from tensorflow.python import tf2
|
||||||
|
from tensorflow.python.eager import eager_util as c_api_util
|
||||||
from tensorflow.python.eager import executor
|
from tensorflow.python.eager import executor
|
||||||
from tensorflow.python.eager import monitoring
|
from tensorflow.python.eager import monitoring
|
||||||
from tensorflow.python.framework import c_api_util
|
|
||||||
from tensorflow.python.framework import device as pydev
|
from tensorflow.python.framework import device as pydev
|
||||||
from tensorflow.python.util import compat
|
from tensorflow.python.util import compat
|
||||||
from tensorflow.python.util import is_in_graph_mode
|
from tensorflow.python.util import is_in_graph_mode
|
||||||
@ -54,17 +54,17 @@ _starting_device_spec = pydev.DeviceSpec.from_string("")
|
|||||||
|
|
||||||
_MAXINT32 = 2**31 - 1
|
_MAXINT32 = 2**31 - 1
|
||||||
|
|
||||||
DEVICE_PLACEMENT_EXPLICIT = pywrap_tensorflow.TFE_DEVICE_PLACEMENT_EXPLICIT
|
DEVICE_PLACEMENT_EXPLICIT = pywrap_tfe.TFE_DEVICE_PLACEMENT_EXPLICIT
|
||||||
DEVICE_PLACEMENT_WARN = pywrap_tensorflow.TFE_DEVICE_PLACEMENT_WARN
|
DEVICE_PLACEMENT_WARN = pywrap_tfe.TFE_DEVICE_PLACEMENT_WARN
|
||||||
DEVICE_PLACEMENT_SILENT = pywrap_tensorflow.TFE_DEVICE_PLACEMENT_SILENT
|
DEVICE_PLACEMENT_SILENT = pywrap_tfe.TFE_DEVICE_PLACEMENT_SILENT
|
||||||
DEVICE_PLACEMENT_SILENT_FOR_INT32 = (
|
DEVICE_PLACEMENT_SILENT_FOR_INT32 = (
|
||||||
pywrap_tensorflow.TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32)
|
pywrap_tfe.TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32)
|
||||||
|
|
||||||
SYNC = 0
|
SYNC = 0
|
||||||
ASYNC = 1
|
ASYNC = 1
|
||||||
|
|
||||||
MIRRORING_NONE = pywrap_tensorflow.TFE_MIRRORING_NONE
|
MIRRORING_NONE = pywrap_tfe.TFE_MIRRORING_NONE
|
||||||
MIRRORING_ALL = pywrap_tensorflow.TFE_MIRRORING_ALL
|
MIRRORING_ALL = pywrap_tfe.TFE_MIRRORING_ALL
|
||||||
|
|
||||||
_KEEP_ALIVE_SECS = 600
|
_KEEP_ALIVE_SECS = 600
|
||||||
|
|
||||||
@ -444,7 +444,7 @@ class Context(object):
|
|||||||
self._rng = random.Random(seed)
|
self._rng = random.Random(seed)
|
||||||
# Also clear the kernel cache, to reset any existing seeds
|
# Also clear the kernel cache, to reset any existing seeds
|
||||||
if self._context_handle is not None:
|
if self._context_handle is not None:
|
||||||
pywrap_tensorflow.TFE_ContextClearCaches(self._context_handle)
|
pywrap_tfe.TFE_ContextClearCaches(self._context_handle)
|
||||||
|
|
||||||
def _internal_operation_seed(self):
|
def _internal_operation_seed(self):
|
||||||
"""Returns a fake operation seed.
|
"""Returns a fake operation seed.
|
||||||
@ -463,12 +463,11 @@ class Context(object):
|
|||||||
# Store list of devices
|
# Store list of devices
|
||||||
logical_devices = []
|
logical_devices = []
|
||||||
context_devices = []
|
context_devices = []
|
||||||
device_list = pywrap_tensorflow.TFE_ContextListDevices(
|
device_list = pywrap_tfe.TFE_ContextListDevices(self._context_handle)
|
||||||
self._context_handle)
|
|
||||||
try:
|
try:
|
||||||
self._num_gpus = 0
|
self._num_gpus = 0
|
||||||
for i in range(pywrap_tensorflow.TF_DeviceListCount(device_list)):
|
for i in range(pywrap_tfe.TF_DeviceListCount(device_list)):
|
||||||
dev_name = pywrap_tensorflow.TF_DeviceListName(device_list, i)
|
dev_name = pywrap_tfe.TF_DeviceListName(device_list, i)
|
||||||
context_devices.append(pydev.canonical_name(dev_name))
|
context_devices.append(pydev.canonical_name(dev_name))
|
||||||
spec = pydev.DeviceSpec.from_string(dev_name)
|
spec = pydev.DeviceSpec.from_string(dev_name)
|
||||||
# If the job is localhost, we assume that the cluster has not yet been
|
# If the job is localhost, we assume that the cluster has not yet been
|
||||||
@ -477,14 +476,14 @@ class Context(object):
|
|||||||
spec = spec.replace(job=None, replica=None, task=None)
|
spec = spec.replace(job=None, replica=None, task=None)
|
||||||
logical_devices.append(
|
logical_devices.append(
|
||||||
LogicalDevice(name=spec.to_string(), device_type=spec.device_type))
|
LogicalDevice(name=spec.to_string(), device_type=spec.device_type))
|
||||||
dev_type = pywrap_tensorflow.TF_DeviceListType(device_list, i)
|
dev_type = pywrap_tfe.TF_DeviceListType(device_list, i)
|
||||||
if dev_type == "GPU":
|
if dev_type == "GPU":
|
||||||
self._num_gpus += 1
|
self._num_gpus += 1
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
self._logical_devices = logical_devices
|
self._logical_devices = logical_devices
|
||||||
self._context_devices = context_devices
|
self._context_devices = context_devices
|
||||||
pywrap_tensorflow.TF_DeleteDeviceList(device_list)
|
pywrap_tfe.TF_DeleteDeviceList(device_list)
|
||||||
|
|
||||||
def ensure_initialized(self):
|
def ensure_initialized(self):
|
||||||
"""Initialize handle and devices if not already done so."""
|
"""Initialize handle and devices if not already done so."""
|
||||||
@ -494,36 +493,34 @@ class Context(object):
|
|||||||
if self._initialized:
|
if self._initialized:
|
||||||
return
|
return
|
||||||
assert self._context_devices is None
|
assert self._context_devices is None
|
||||||
opts = pywrap_tensorflow.TFE_NewContextOptions()
|
opts = pywrap_tfe.TFE_NewContextOptions()
|
||||||
try:
|
try:
|
||||||
config_str = self.config.SerializeToString()
|
config_str = self.config.SerializeToString()
|
||||||
pywrap_tensorflow.TFE_ContextOptionsSetConfig(opts, config_str)
|
pywrap_tfe.TFE_ContextOptionsSetConfig(opts, config_str)
|
||||||
if self._device_policy is not None:
|
if self._device_policy is not None:
|
||||||
pywrap_tensorflow.TFE_ContextOptionsSetDevicePlacementPolicy(
|
pywrap_tfe.TFE_ContextOptionsSetDevicePlacementPolicy(
|
||||||
opts, self._device_policy)
|
opts, self._device_policy)
|
||||||
if self._mirroring_policy is not None:
|
if self._mirroring_policy is not None:
|
||||||
pywrap_tensorflow.TFE_ContextOptionsSetMirroringPolicy(
|
pywrap_tfe.TFE_ContextOptionsSetMirroringPolicy(
|
||||||
opts, self._mirroring_policy)
|
opts, self._mirroring_policy)
|
||||||
if self._default_is_async == ASYNC:
|
if self._default_is_async == ASYNC:
|
||||||
pywrap_tensorflow.TFE_ContextOptionsSetAsync(opts, True)
|
pywrap_tfe.TFE_ContextOptionsSetAsync(opts, True)
|
||||||
if self._lazy_remote_inputs_copy is not None:
|
if self._lazy_remote_inputs_copy is not None:
|
||||||
pywrap_tensorflow.TFE_ContextOptionsSetLazyRemoteInputsCopy(
|
pywrap_tfe.TFE_ContextOptionsSetLazyRemoteInputsCopy(
|
||||||
opts, self._lazy_remote_inputs_copy)
|
opts, self._lazy_remote_inputs_copy)
|
||||||
context_handle = pywrap_tensorflow.TFE_NewContext(opts)
|
context_handle = pywrap_tfe.TFE_NewContext(opts)
|
||||||
finally:
|
finally:
|
||||||
pywrap_tensorflow.TFE_DeleteContextOptions(opts)
|
pywrap_tfe.TFE_DeleteContextOptions(opts)
|
||||||
assert not (self._server_def and self._collective_ops_server_def), (
|
assert not (self._server_def and self._collective_ops_server_def), (
|
||||||
"Cannot enable remote execution as well as collective ops at the "
|
"Cannot enable remote execution as well as collective ops at the "
|
||||||
"moment. If this is important to you, please file an issue.")
|
"moment. If this is important to you, please file an issue.")
|
||||||
if self._server_def is not None:
|
if self._server_def is not None:
|
||||||
server_def_str = self._server_def.SerializeToString()
|
server_def_str = self._server_def.SerializeToString()
|
||||||
pywrap_tensorflow.TFE_ContextSetServerDef(context_handle,
|
pywrap_tfe.TFE_ContextSetServerDef(context_handle, _KEEP_ALIVE_SECS,
|
||||||
_KEEP_ALIVE_SECS,
|
server_def_str)
|
||||||
server_def_str)
|
|
||||||
elif self._collective_ops_server_def is not None:
|
elif self._collective_ops_server_def is not None:
|
||||||
server_def_str = self._collective_ops_server_def.SerializeToString()
|
server_def_str = self._collective_ops_server_def.SerializeToString()
|
||||||
pywrap_tensorflow.TFE_EnableCollectiveOps(context_handle,
|
pywrap_tfe.TFE_EnableCollectiveOps(context_handle, server_def_str)
|
||||||
server_def_str)
|
|
||||||
|
|
||||||
self._context_handle = context_handle
|
self._context_handle = context_handle
|
||||||
self._initialize_logical_devices()
|
self._initialize_logical_devices()
|
||||||
@ -532,7 +529,7 @@ class Context(object):
|
|||||||
def _clear_caches(self):
|
def _clear_caches(self):
|
||||||
self.ones_rank_cache().flush()
|
self.ones_rank_cache().flush()
|
||||||
self.zeros_cache().flush()
|
self.zeros_cache().flush()
|
||||||
pywrap_tensorflow.TFE_ClearScalarCache()
|
pywrap_tfe.TFE_ClearScalarCache()
|
||||||
|
|
||||||
def get_server_def(self):
|
def get_server_def(self):
|
||||||
return self._server_def
|
return self._server_def
|
||||||
@ -563,8 +560,8 @@ class Context(object):
|
|||||||
|
|
||||||
if self._context_handle:
|
if self._context_handle:
|
||||||
server_def_str = server_def.SerializeToString()
|
server_def_str = server_def.SerializeToString()
|
||||||
pywrap_tensorflow.TFE_ContextSetServerDef(self._context_handle,
|
pywrap_tfe.TFE_ContextSetServerDef(self._context_handle, keep_alive_secs,
|
||||||
keep_alive_secs, server_def_str)
|
server_def_str)
|
||||||
self._initialize_logical_devices()
|
self._initialize_logical_devices()
|
||||||
|
|
||||||
# Clear all the caches in case there are remote tensors in them.
|
# Clear all the caches in case there are remote tensors in them.
|
||||||
@ -592,9 +589,8 @@ class Context(object):
|
|||||||
|
|
||||||
if self._context_handle:
|
if self._context_handle:
|
||||||
server_def_str = server_def.SerializeToString()
|
server_def_str = server_def.SerializeToString()
|
||||||
pywrap_tensorflow.TFE_ContextUpdateServerDef(self._context_handle,
|
pywrap_tfe.TFE_ContextUpdateServerDef(self._context_handle,
|
||||||
keep_alive_secs,
|
keep_alive_secs, server_def_str)
|
||||||
server_def_str)
|
|
||||||
self._initialize_logical_devices()
|
self._initialize_logical_devices()
|
||||||
|
|
||||||
self._clear_caches()
|
self._clear_caches()
|
||||||
@ -614,8 +610,7 @@ class Context(object):
|
|||||||
"""
|
"""
|
||||||
# TODO(yuefengz): support checking multiple workers.
|
# TODO(yuefengz): support checking multiple workers.
|
||||||
if self._context_handle:
|
if self._context_handle:
|
||||||
return pywrap_tensorflow.TFE_ContextCheckAlive(self._context_handle,
|
return pywrap_tfe.TFE_ContextCheckAlive(self._context_handle, worker_name)
|
||||||
worker_name)
|
|
||||||
else:
|
else:
|
||||||
raise ValueError("Context is not initialized.")
|
raise ValueError("Context is not initialized.")
|
||||||
|
|
||||||
@ -808,8 +803,8 @@ class Context(object):
|
|||||||
self.executor.wait()
|
self.executor.wait()
|
||||||
executor_new = executor.new_executor(enable_async)
|
executor_new = executor.new_executor(enable_async)
|
||||||
self._thread_local_data.executor = executor_new
|
self._thread_local_data.executor = executor_new
|
||||||
pywrap_tensorflow.TFE_ContextSetExecutorForThread(
|
pywrap_tfe.TFE_ContextSetExecutorForThread(self._context_handle,
|
||||||
self._context_handle, executor_new.handle())
|
executor_new.handle())
|
||||||
else:
|
else:
|
||||||
self._default_is_async = enable_async
|
self._default_is_async = enable_async
|
||||||
|
|
||||||
@ -823,13 +818,12 @@ class Context(object):
|
|||||||
def executor(self):
|
def executor(self):
|
||||||
ensure_initialized()
|
ensure_initialized()
|
||||||
return executor.Executor(
|
return executor.Executor(
|
||||||
pywrap_tensorflow.TFE_ContextGetExecutorForThread(self._context_handle))
|
pywrap_tfe.TFE_ContextGetExecutorForThread(self._context_handle))
|
||||||
|
|
||||||
@executor.setter
|
@executor.setter
|
||||||
def executor(self, e):
|
def executor(self, e):
|
||||||
ensure_initialized()
|
ensure_initialized()
|
||||||
pywrap_tensorflow.TFE_ContextSetExecutorForThread(self._context_handle,
|
pywrap_tfe.TFE_ContextSetExecutorForThread(self._context_handle, e.handle())
|
||||||
e.handle())
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def config(self):
|
def config(self):
|
||||||
@ -1015,7 +1009,7 @@ class Context(object):
|
|||||||
fn: A wrapped TF_Function (returned from TF_GraphToFunction_wrapper).
|
fn: A wrapped TF_Function (returned from TF_GraphToFunction_wrapper).
|
||||||
"""
|
"""
|
||||||
self.ensure_initialized()
|
self.ensure_initialized()
|
||||||
pywrap_tensorflow.TFE_ContextAddFunction(self._handle, fn)
|
pywrap_tfe.TFE_ContextAddFunction(self._handle, fn)
|
||||||
|
|
||||||
def add_function_def(self, fdef):
|
def add_function_def(self, fdef):
|
||||||
"""Add a function definition to the context.
|
"""Add a function definition to the context.
|
||||||
@ -1028,8 +1022,8 @@ class Context(object):
|
|||||||
"""
|
"""
|
||||||
self.ensure_initialized()
|
self.ensure_initialized()
|
||||||
fdef_string = fdef.SerializeToString()
|
fdef_string = fdef.SerializeToString()
|
||||||
pywrap_tensorflow.TFE_ContextAddFunctionDef(
|
pywrap_tfe.TFE_ContextAddFunctionDef(self._handle, fdef_string,
|
||||||
self._handle, fdef_string, len(fdef_string))
|
len(fdef_string))
|
||||||
|
|
||||||
def remove_function(self, name):
|
def remove_function(self, name):
|
||||||
"""Remove a function from the context.
|
"""Remove a function from the context.
|
||||||
@ -1040,12 +1034,12 @@ class Context(object):
|
|||||||
name: function signature name.
|
name: function signature name.
|
||||||
"""
|
"""
|
||||||
self.ensure_initialized()
|
self.ensure_initialized()
|
||||||
pywrap_tensorflow.TFE_ContextRemoveFunction(self._handle, name)
|
pywrap_tfe.TFE_ContextRemoveFunction(self._handle, name)
|
||||||
|
|
||||||
def has_function(self, name):
|
def has_function(self, name):
|
||||||
"""Check if a function `name` is registered."""
|
"""Check if a function `name` is registered."""
|
||||||
self.ensure_initialized()
|
self.ensure_initialized()
|
||||||
return bool(pywrap_tensorflow.TFE_ContextHasFunction(self._handle, name))
|
return bool(pywrap_tfe.TFE_ContextHasFunction(self._handle, name))
|
||||||
|
|
||||||
def add_op_callback(self, callback):
|
def add_op_callback(self, callback):
|
||||||
"""Add a post-op callback to the context.
|
"""Add a post-op callback to the context.
|
||||||
@ -1101,7 +1095,7 @@ class Context(object):
|
|||||||
if self._physical_devices is not None:
|
if self._physical_devices is not None:
|
||||||
return
|
return
|
||||||
|
|
||||||
devs = pywrap_tensorflow.TF_ListPhysicalDevices()
|
devs = pywrap_tfe.TF_ListPhysicalDevices()
|
||||||
self._physical_devices = [
|
self._physical_devices = [
|
||||||
PhysicalDevice(name=d.decode(),
|
PhysicalDevice(name=d.decode(),
|
||||||
device_type=d.decode().split(":")[1]) for d in devs]
|
device_type=d.decode().split(":")[1]) for d in devs]
|
||||||
@ -1434,7 +1428,7 @@ class Context(object):
|
|||||||
def device_policy(self):
|
def device_policy(self):
|
||||||
# Only get the policy from the context if it has already been initialized
|
# Only get the policy from the context if it has already been initialized
|
||||||
if self._context_handle is not None:
|
if self._context_handle is not None:
|
||||||
return pywrap_tensorflow.TFE_ContextGetDevicePlacementPolicy(self._handle)
|
return pywrap_tfe.TFE_ContextGetDevicePlacementPolicy(self._handle)
|
||||||
|
|
||||||
return self._device_policy
|
return self._device_policy
|
||||||
|
|
||||||
@ -1448,14 +1442,14 @@ class Context(object):
|
|||||||
|
|
||||||
# Only set the policy if the context has already been initialized
|
# Only set the policy if the context has already been initialized
|
||||||
if self._context_handle is not None:
|
if self._context_handle is not None:
|
||||||
pywrap_tensorflow.TFE_ContextSetThreadLocalDevicePlacementPolicy(
|
pywrap_tfe.TFE_ContextSetThreadLocalDevicePlacementPolicy(
|
||||||
self._handle, self._device_policy)
|
self._handle, self._device_policy)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def mirroring_policy(self):
|
def mirroring_policy(self):
|
||||||
# Only get the policy from the context if it has already been initialized
|
# Only get the policy from the context if it has already been initialized
|
||||||
if self._context_handle is not None:
|
if self._context_handle is not None:
|
||||||
return pywrap_tensorflow.TFE_ContextGetMirroringPolicy(self._handle)
|
return pywrap_tfe.TFE_ContextGetMirroringPolicy(self._handle)
|
||||||
|
|
||||||
return self._mirroring_policy
|
return self._mirroring_policy
|
||||||
|
|
||||||
@ -1469,7 +1463,7 @@ class Context(object):
|
|||||||
|
|
||||||
# Only set the policy if the context has already been initialized
|
# Only set the policy if the context has already been initialized
|
||||||
if self._context_handle is not None:
|
if self._context_handle is not None:
|
||||||
pywrap_tensorflow.TFE_ContextSetThreadLocalMirroringPolicy(
|
pywrap_tfe.TFE_ContextSetThreadLocalMirroringPolicy(
|
||||||
self._handle, self._mirroring_policy)
|
self._handle, self._mirroring_policy)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -1495,13 +1489,13 @@ class Context(object):
|
|||||||
and to stop tracing call context.disable_run_metadata().
|
and to stop tracing call context.disable_run_metadata().
|
||||||
"""
|
"""
|
||||||
self.ensure_initialized()
|
self.ensure_initialized()
|
||||||
pywrap_tensorflow.TFE_ContextEnableRunMetadata(self._handle)
|
pywrap_tfe.TFE_ContextEnableRunMetadata(self._handle)
|
||||||
|
|
||||||
def disable_run_metadata(self):
|
def disable_run_metadata(self):
|
||||||
"""Disables tracing of op execution via RunMetadata."""
|
"""Disables tracing of op execution via RunMetadata."""
|
||||||
if not self._context_handle:
|
if not self._context_handle:
|
||||||
return
|
return
|
||||||
pywrap_tensorflow.TFE_ContextDisableRunMetadata(self._context_handle)
|
pywrap_tfe.TFE_ContextDisableRunMetadata(self._context_handle)
|
||||||
|
|
||||||
def enable_graph_collection(self):
|
def enable_graph_collection(self):
|
||||||
"""Enables graph collection of executed functions.
|
"""Enables graph collection of executed functions.
|
||||||
@ -1510,13 +1504,13 @@ class Context(object):
|
|||||||
and to stop collecting graphs call context.disable_graph_collection().
|
and to stop collecting graphs call context.disable_graph_collection().
|
||||||
"""
|
"""
|
||||||
self.ensure_initialized()
|
self.ensure_initialized()
|
||||||
pywrap_tensorflow.TFE_ContextEnableGraphCollection(self._handle)
|
pywrap_tfe.TFE_ContextEnableGraphCollection(self._handle)
|
||||||
|
|
||||||
def disable_graph_collection(self):
|
def disable_graph_collection(self):
|
||||||
"""Disables graph collection of executed functions."""
|
"""Disables graph collection of executed functions."""
|
||||||
if not self._context_handle:
|
if not self._context_handle:
|
||||||
return
|
return
|
||||||
pywrap_tensorflow.TFE_ContextDisableGraphCollection(self._context_handle)
|
pywrap_tfe.TFE_ContextDisableGraphCollection(self._context_handle)
|
||||||
|
|
||||||
def export_run_metadata(self):
|
def export_run_metadata(self):
|
||||||
"""Returns a RunMetadata proto with accumulated information.
|
"""Returns a RunMetadata proto with accumulated information.
|
||||||
@ -1530,9 +1524,8 @@ class Context(object):
|
|||||||
if not self._context_handle:
|
if not self._context_handle:
|
||||||
return None
|
return None
|
||||||
with c_api_util.tf_buffer() as buffer_:
|
with c_api_util.tf_buffer() as buffer_:
|
||||||
pywrap_tensorflow.TFE_ContextExportRunMetadata(
|
pywrap_tfe.TFE_ContextExportRunMetadata(self._context_handle, buffer_)
|
||||||
self._context_handle, buffer_)
|
proto_data = pywrap_tfe.TF_GetBuffer(buffer_)
|
||||||
proto_data = pywrap_tensorflow.TF_GetBuffer(buffer_)
|
|
||||||
run_metadata = config_pb2.RunMetadata()
|
run_metadata = config_pb2.RunMetadata()
|
||||||
run_metadata.ParseFromString(compat.as_bytes(proto_data))
|
run_metadata.ParseFromString(compat.as_bytes(proto_data))
|
||||||
return run_metadata
|
return run_metadata
|
||||||
@ -1543,10 +1536,10 @@ class Context(object):
|
|||||||
return self._context_switches
|
return self._context_switches
|
||||||
|
|
||||||
def start_step(self):
|
def start_step(self):
|
||||||
pywrap_tensorflow.TFE_ContextStartStep(self._handle)
|
pywrap_tfe.TFE_ContextStartStep(self._handle)
|
||||||
|
|
||||||
def end_step(self):
|
def end_step(self):
|
||||||
pywrap_tensorflow.TFE_ContextEndStep(self._handle)
|
pywrap_tfe.TFE_ContextEndStep(self._handle)
|
||||||
|
|
||||||
|
|
||||||
class _EagerDeviceContext(object):
|
class _EagerDeviceContext(object):
|
||||||
@ -1608,7 +1601,7 @@ _context_lock = threading.Lock()
|
|||||||
|
|
||||||
def _set_context_locked(ctx):
|
def _set_context_locked(ctx):
|
||||||
global _context
|
global _context
|
||||||
pywrap_tensorflow.TFE_Py_SetEagerContext(ctx)
|
pywrap_tfe.TFE_Py_SetEagerContext(ctx)
|
||||||
_context = ctx
|
_context = ctx
|
||||||
|
|
||||||
|
|
||||||
|
@ -18,7 +18,7 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
from tensorflow.python import pywrap_tensorflow
|
from tensorflow.python import pywrap_tfe
|
||||||
from tensorflow.python.framework import errors
|
from tensorflow.python.framework import errors
|
||||||
|
|
||||||
# Trace of execution and memory usage.
|
# Trace of execution and memory usage.
|
||||||
@ -46,7 +46,7 @@ class _NotOkStatusException(Exception):
|
|||||||
return "%s: %s" % (e.__class__.__name__, e)
|
return "%s: %s" % (e.__class__.__name__, e)
|
||||||
|
|
||||||
|
|
||||||
pywrap_tensorflow.TFE_Py_RegisterExceptionClass(_NotOkStatusException)
|
pywrap_tfe.TFE_Py_RegisterExceptionClass(_NotOkStatusException)
|
||||||
|
|
||||||
|
|
||||||
class _FallbackException(Exception):
|
class _FallbackException(Exception):
|
||||||
@ -71,4 +71,4 @@ class _SymbolicException(Exception):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
pywrap_tensorflow.TFE_Py_RegisterFallbackExceptionClass(_FallbackException)
|
pywrap_tfe.TFE_Py_RegisterFallbackExceptionClass(_FallbackException)
|
||||||
|
@ -26,7 +26,7 @@ import threading
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from tensorflow.core.protobuf import config_pb2
|
from tensorflow.core.protobuf import config_pb2
|
||||||
from tensorflow.python import pywrap_tensorflow
|
from tensorflow.python import pywrap_tfe
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.eager import core
|
from tensorflow.python.eager import core
|
||||||
from tensorflow.python.eager import def_function
|
from tensorflow.python.eager import def_function
|
||||||
@ -602,8 +602,8 @@ class TFETest(test_util.TensorFlowTestCase):
|
|||||||
|
|
||||||
def testRegisterExceptionClass(self):
|
def testRegisterExceptionClass(self):
|
||||||
with self.assertRaises(TypeError):
|
with self.assertRaises(TypeError):
|
||||||
pywrap_tensorflow.TFE_Py_RegisterExceptionClass(str)
|
pywrap_tfe.TFE_Py_RegisterExceptionClass(str)
|
||||||
pywrap_tensorflow.TFE_Py_RegisterExceptionClass(core._NotOkStatusException) # pylint: disable=protected-access
|
pywrap_tfe.TFE_Py_RegisterExceptionClass(core._NotOkStatusException) # pylint: disable=protected-access
|
||||||
|
|
||||||
# TODO(agarwal): add tests passing incorrect typed values to attrs.
|
# TODO(agarwal): add tests passing incorrect typed values to attrs.
|
||||||
def testExecuteBasic(self):
|
def testExecuteBasic(self):
|
||||||
|
61
tensorflow/python/eager/eager_util.py
Normal file
61
tensorflow/python/eager/eager_util.py
Normal file
@ -0,0 +1,61 @@
|
|||||||
|
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Utilities for using the TensorFlow Eager using the C API."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from tensorflow.python import pywrap_tfe as c_api
|
||||||
|
from tensorflow.python.util import compat
|
||||||
|
from tensorflow.python.util import tf_contextlib
|
||||||
|
|
||||||
|
|
||||||
|
# We temporarily need a duplicate tf_buffer function in eager_util. The
|
||||||
|
# c_api_util is still relying on SWIG and is thus incompatible until
|
||||||
|
# we migrate over. We can delete this once we migrate tf_session.i
|
||||||
|
|
||||||
|
|
||||||
|
@tf_contextlib.contextmanager
|
||||||
|
def tf_buffer(data=None):
|
||||||
|
"""Context manager that creates and deletes TF_Buffer.
|
||||||
|
|
||||||
|
Example usage:
|
||||||
|
with tf_buffer() as buf:
|
||||||
|
# get serialized graph def into buf
|
||||||
|
...
|
||||||
|
proto_data = c_api.TF_GetBuffer(buf)
|
||||||
|
graph_def.ParseFromString(compat.as_bytes(proto_data))
|
||||||
|
# buf has been deleted
|
||||||
|
|
||||||
|
with tf_buffer(some_string) as buf:
|
||||||
|
c_api.TF_SomeFunction(buf)
|
||||||
|
# buf has been deleted
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: An optional `bytes`, `str`, or `unicode` object. If not None, the
|
||||||
|
yielded buffer will contain this data.
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
Created TF_Buffer
|
||||||
|
"""
|
||||||
|
if data:
|
||||||
|
buf = c_api.TF_NewBufferFromString(compat.as_bytes(data))
|
||||||
|
else:
|
||||||
|
buf = c_api.TF_NewBuffer()
|
||||||
|
try:
|
||||||
|
yield buf
|
||||||
|
finally:
|
||||||
|
c_api.TF_DeleteBuffer(buf)
|
@ -22,7 +22,7 @@ import six
|
|||||||
|
|
||||||
from google.protobuf import text_format
|
from google.protobuf import text_format
|
||||||
from tensorflow.core.framework import tensor_pb2
|
from tensorflow.core.framework import tensor_pb2
|
||||||
from tensorflow.python import pywrap_tensorflow
|
from tensorflow.python import pywrap_tfe
|
||||||
from tensorflow.python.eager import core
|
from tensorflow.python.eager import core
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
@ -56,9 +56,8 @@ def quick_execute(op_name, num_outputs, inputs, attrs, ctx, name=None):
|
|||||||
# pylint: disable=protected-access
|
# pylint: disable=protected-access
|
||||||
try:
|
try:
|
||||||
ctx.ensure_initialized()
|
ctx.ensure_initialized()
|
||||||
tensors = pywrap_tensorflow.TFE_Py_Execute(ctx._handle, device_name,
|
tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
|
||||||
op_name, inputs, attrs,
|
inputs, attrs, num_outputs)
|
||||||
num_outputs)
|
|
||||||
except core._NotOkStatusException as e:
|
except core._NotOkStatusException as e:
|
||||||
if name is not None:
|
if name is not None:
|
||||||
message = e.message + " name: " + name
|
message = e.message + " name: " + name
|
||||||
@ -111,9 +110,10 @@ def execute_with_cancellation(op_name,
|
|||||||
# pylint: disable=protected-access
|
# pylint: disable=protected-access
|
||||||
try:
|
try:
|
||||||
ctx.ensure_initialized()
|
ctx.ensure_initialized()
|
||||||
tensors = pywrap_tensorflow.TFE_Py_ExecuteCancelable(
|
tensors = pywrap_tfe.TFE_Py_ExecuteCancelable(ctx._handle, device_name,
|
||||||
ctx._handle, device_name, op_name, inputs, attrs,
|
op_name, inputs, attrs,
|
||||||
cancellation_manager._impl, num_outputs)
|
cancellation_manager._impl,
|
||||||
|
num_outputs)
|
||||||
except core._NotOkStatusException as e:
|
except core._NotOkStatusException as e:
|
||||||
if name is not None:
|
if name is not None:
|
||||||
message = e.message + " name: " + name
|
message = e.message + " name: " + name
|
||||||
|
@ -18,7 +18,7 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
from tensorflow.python import pywrap_tensorflow
|
from tensorflow.python import pywrap_tfe
|
||||||
|
|
||||||
|
|
||||||
class Executor(object):
|
class Executor(object):
|
||||||
@ -45,8 +45,8 @@ class Executor(object):
|
|||||||
|
|
||||||
def __del__(self):
|
def __del__(self):
|
||||||
try:
|
try:
|
||||||
# pywrap_tensorflow.TFE_ExecutorWaitForAllPendingNodes(self._handle)
|
# pywrap_tfe.TFE_ExecutorWaitForAllPendingNodes(self._handle)
|
||||||
pywrap_tensorflow.TFE_DeleteExecutor(self._handle)
|
pywrap_tfe.TFE_DeleteExecutor(self._handle)
|
||||||
except TypeError:
|
except TypeError:
|
||||||
# Suppress some exceptions, mainly for the case when we're running on
|
# Suppress some exceptions, mainly for the case when we're running on
|
||||||
# module deletion. Things that can go wrong include the pywrap module
|
# module deletion. Things that can go wrong include the pywrap module
|
||||||
@ -57,20 +57,20 @@ class Executor(object):
|
|||||||
# partially unloaded.
|
# partially unloaded.
|
||||||
|
|
||||||
def is_async(self):
|
def is_async(self):
|
||||||
return pywrap_tensorflow.TFE_ExecutorIsAsync(self._handle)
|
return pywrap_tfe.TFE_ExecutorIsAsync(self._handle)
|
||||||
|
|
||||||
def handle(self):
|
def handle(self):
|
||||||
return self._handle
|
return self._handle
|
||||||
|
|
||||||
def wait(self):
|
def wait(self):
|
||||||
"""Waits for ops dispatched in this executor to finish."""
|
"""Waits for ops dispatched in this executor to finish."""
|
||||||
pywrap_tensorflow.TFE_ExecutorWaitForAllPendingNodes(self._handle)
|
pywrap_tfe.TFE_ExecutorWaitForAllPendingNodes(self._handle)
|
||||||
|
|
||||||
def clear_error(self):
|
def clear_error(self):
|
||||||
"""Clears errors raised in this executor during execution."""
|
"""Clears errors raised in this executor during execution."""
|
||||||
pywrap_tensorflow.TFE_ExecutorClearError(self._handle)
|
pywrap_tfe.TFE_ExecutorClearError(self._handle)
|
||||||
|
|
||||||
|
|
||||||
def new_executor(enable_async):
|
def new_executor(enable_async):
|
||||||
handle = pywrap_tensorflow.TFE_NewExecutor(enable_async)
|
handle = pywrap_tfe.TFE_NewExecutor(enable_async)
|
||||||
return Executor(handle)
|
return Executor(handle)
|
||||||
|
@ -20,7 +20,7 @@ from __future__ import print_function
|
|||||||
|
|
||||||
import threading
|
import threading
|
||||||
|
|
||||||
from tensorflow.python import pywrap_tensorflow
|
from tensorflow.python import pywrap_tfe
|
||||||
from tensorflow.python.eager import backprop
|
from tensorflow.python.eager import backprop
|
||||||
from tensorflow.python.eager import backprop_util
|
from tensorflow.python.eager import backprop_util
|
||||||
from tensorflow.python.eager import def_function
|
from tensorflow.python.eager import def_function
|
||||||
@ -166,7 +166,8 @@ def _jvp_dispatch(op_name, attr_tuple, inputs, outputs, tangents):
|
|||||||
return _jvp_relaxed_shapes(
|
return _jvp_relaxed_shapes(
|
||||||
op_name, attr_tuple, inputs, outputs, tangents)
|
op_name, attr_tuple, inputs, outputs, tangents)
|
||||||
|
|
||||||
pywrap_tensorflow.TFE_Py_RegisterJVPFunction(_jvp_dispatch)
|
|
||||||
|
pywrap_tfe.TFE_Py_RegisterJVPFunction(_jvp_dispatch)
|
||||||
|
|
||||||
|
|
||||||
@tf_export("autodiff.ForwardAccumulator", v1=[])
|
@tf_export("autodiff.ForwardAccumulator", v1=[])
|
||||||
@ -300,7 +301,7 @@ class ForwardAccumulator(object):
|
|||||||
ValueError: If the same tensor or variable is specified multiple times in
|
ValueError: If the same tensor or variable is specified multiple times in
|
||||||
`primals`.
|
`primals`.
|
||||||
"""
|
"""
|
||||||
self._accumulator = pywrap_tensorflow.TFE_Py_ForwardAccumulatorNew()
|
self._accumulator = pywrap_tfe.TFE_Py_ForwardAccumulatorNew()
|
||||||
self._recording = False
|
self._recording = False
|
||||||
primal_ids = set()
|
primal_ids = set()
|
||||||
for primal in nest.flatten(primals):
|
for primal in nest.flatten(primals):
|
||||||
@ -323,13 +324,13 @@ class ForwardAccumulator(object):
|
|||||||
def _push_accumulator(self):
|
def _push_accumulator(self):
|
||||||
if self._recording:
|
if self._recording:
|
||||||
raise ValueError("Accumulator is already recording.")
|
raise ValueError("Accumulator is already recording.")
|
||||||
pywrap_tensorflow.TFE_Py_ForwardAccumulatorSetAdd(self._accumulator)
|
pywrap_tfe.TFE_Py_ForwardAccumulatorSetAdd(self._accumulator)
|
||||||
self._recording = True
|
self._recording = True
|
||||||
|
|
||||||
def _pop_accumulator(self):
|
def _pop_accumulator(self):
|
||||||
if not self._recording:
|
if not self._recording:
|
||||||
raise ValueError("Accumulator is not recording.")
|
raise ValueError("Accumulator is not recording.")
|
||||||
pywrap_tensorflow.TFE_Py_ForwardAccumulatorSetRemove(self._accumulator)
|
pywrap_tfe.TFE_Py_ForwardAccumulatorSetRemove(self._accumulator)
|
||||||
self._recording = False
|
self._recording = False
|
||||||
|
|
||||||
def _watch(self, primals, tangents):
|
def _watch(self, primals, tangents):
|
||||||
@ -358,7 +359,7 @@ class ForwardAccumulator(object):
|
|||||||
# Run convert_to_tensor to get the captured handle from whichever
|
# Run convert_to_tensor to get the captured handle from whichever
|
||||||
# function we're running if necessary.
|
# function we're running if necessary.
|
||||||
t = ops.convert_to_tensor(t.handle)
|
t = ops.convert_to_tensor(t.handle)
|
||||||
pywrap_tensorflow.TFE_Py_ForwardAccumulatorWatch(self._accumulator, t, g)
|
pywrap_tfe.TFE_Py_ForwardAccumulatorWatch(self._accumulator, t, g)
|
||||||
|
|
||||||
def jvp(self, primals, unconnected_gradients=UnconnectedGradients.NONE):
|
def jvp(self, primals, unconnected_gradients=UnconnectedGradients.NONE):
|
||||||
"""Fetches the Jacobian-vector product computed for `primals`.
|
"""Fetches the Jacobian-vector product computed for `primals`.
|
||||||
@ -384,8 +385,8 @@ class ForwardAccumulator(object):
|
|||||||
def _fetch_jvp(tensor):
|
def _fetch_jvp(tensor):
|
||||||
if hasattr(tensor, "handle"):
|
if hasattr(tensor, "handle"):
|
||||||
tensor = ops.convert_to_tensor(tensor.handle)
|
tensor = ops.convert_to_tensor(tensor.handle)
|
||||||
result = pywrap_tensorflow.TFE_Py_ForwardAccumulatorJVP(
|
result = pywrap_tfe.TFE_Py_ForwardAccumulatorJVP(self._accumulator,
|
||||||
self._accumulator, tensor)
|
tensor)
|
||||||
if result is None and unconnected_gradients == UnconnectedGradients.ZERO:
|
if result is None and unconnected_gradients == UnconnectedGradients.ZERO:
|
||||||
return array_ops.zeros_like(tensor)
|
return array_ops.zeros_like(tensor)
|
||||||
return result
|
return result
|
||||||
|
@ -24,7 +24,7 @@ import weakref
|
|||||||
from absl.testing import parameterized
|
from absl.testing import parameterized
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from tensorflow.python import pywrap_tensorflow
|
from tensorflow.python import pywrap_tfe
|
||||||
from tensorflow.python.distribute import mirrored_strategy
|
from tensorflow.python.distribute import mirrored_strategy
|
||||||
from tensorflow.python.eager import backprop
|
from tensorflow.python.eager import backprop
|
||||||
from tensorflow.python.eager import def_function
|
from tensorflow.python.eager import def_function
|
||||||
@ -236,13 +236,13 @@ class ForwardpropTest(test.TestCase, parameterized.TestCase):
|
|||||||
x = constant_op.constant(1.)
|
x = constant_op.constant(1.)
|
||||||
with forwardprop.ForwardAccumulator(x, 2.) as acc:
|
with forwardprop.ForwardAccumulator(x, 2.) as acc:
|
||||||
y = x + x
|
y = x + x
|
||||||
pywrap_tensorflow.TFE_Py_RegisterJVPFunction(
|
pywrap_tfe.TFE_Py_RegisterJVPFunction(
|
||||||
lambda *args, **kwargs: [constant_op.constant(-15.)])
|
lambda *args, **kwargs: [constant_op.constant(-15.)])
|
||||||
z = x + x
|
z = x + x
|
||||||
self.assertAllClose(4., acc.jvp(y))
|
self.assertAllClose(4., acc.jvp(y))
|
||||||
self.assertAllClose(-15., acc.jvp(z))
|
self.assertAllClose(-15., acc.jvp(z))
|
||||||
finally:
|
finally:
|
||||||
pywrap_tensorflow.TFE_Py_RegisterJVPFunction(previous_fn)
|
pywrap_tfe.TFE_Py_RegisterJVPFunction(previous_fn)
|
||||||
|
|
||||||
@test_util.assert_no_new_pyobjects_executing_eagerly
|
@test_util.assert_no_new_pyobjects_executing_eagerly
|
||||||
def testFunctionCacheLimited(self):
|
def testFunctionCacheLimited(self):
|
||||||
@ -738,19 +738,19 @@ class ForwardpropTest(test.TestCase, parameterized.TestCase):
|
|||||||
with forwardprop.ForwardAccumulator(c, c_tangent) as acc:
|
with forwardprop.ForwardAccumulator(c, c_tangent) as acc:
|
||||||
with backprop.GradientTape() as tape:
|
with backprop.GradientTape() as tape:
|
||||||
self.assertFalse(tape_lib.should_record_backprop([c]))
|
self.assertFalse(tape_lib.should_record_backprop([c]))
|
||||||
self.assertEqual(
|
self.assertEqual(1,
|
||||||
1, pywrap_tensorflow.TFE_Py_TapeSetPossibleGradientTypes([c]))
|
pywrap_tfe.TFE_Py_TapeSetPossibleGradientTypes([c]))
|
||||||
tape.watch(c)
|
tape.watch(c)
|
||||||
self.assertEqual(
|
self.assertEqual(2,
|
||||||
2, pywrap_tensorflow.TFE_Py_TapeSetPossibleGradientTypes([c]))
|
pywrap_tfe.TFE_Py_TapeSetPossibleGradientTypes([c]))
|
||||||
self.assertTrue(tape_lib.should_record_backprop([c]))
|
self.assertTrue(tape_lib.should_record_backprop([c]))
|
||||||
with tape_lib.stop_recording():
|
with tape_lib.stop_recording():
|
||||||
self.assertEqual(
|
self.assertEqual(0,
|
||||||
0, pywrap_tensorflow.TFE_Py_TapeSetPossibleGradientTypes([c]))
|
pywrap_tfe.TFE_Py_TapeSetPossibleGradientTypes([c]))
|
||||||
self.assertFalse(tape_lib.should_record_backprop([c]))
|
self.assertFalse(tape_lib.should_record_backprop([c]))
|
||||||
d = c * 2.
|
d = c * 2.
|
||||||
self.assertEqual(
|
self.assertEqual(2,
|
||||||
2, pywrap_tensorflow.TFE_Py_TapeSetPossibleGradientTypes([c]))
|
pywrap_tfe.TFE_Py_TapeSetPossibleGradientTypes([c]))
|
||||||
self.assertTrue(tape_lib.should_record_backprop([c]))
|
self.assertTrue(tape_lib.should_record_backprop([c]))
|
||||||
self.assertFalse(tape_lib.should_record_backprop([d]))
|
self.assertFalse(tape_lib.should_record_backprop([d]))
|
||||||
self.assertIsNone(acc.jvp(d))
|
self.assertIsNone(acc.jvp(d))
|
||||||
|
@ -24,7 +24,7 @@ from __future__ import print_function
|
|||||||
import collections
|
import collections
|
||||||
import contextlib
|
import contextlib
|
||||||
|
|
||||||
from tensorflow.python import pywrap_tensorflow
|
from tensorflow.python import pywrap_tfe
|
||||||
|
|
||||||
|
|
||||||
class TangentInfo(
|
class TangentInfo(
|
||||||
@ -54,8 +54,7 @@ def pack_tangents(tensors):
|
|||||||
tangents: A flat list of Tensors. Best interpreted as a sequence to be
|
tangents: A flat list of Tensors. Best interpreted as a sequence to be
|
||||||
appended to `tensors`.
|
appended to `tensors`.
|
||||||
"""
|
"""
|
||||||
return TangentInfo(
|
return TangentInfo(*pywrap_tfe.TFE_Py_PackJVPs(tensors))
|
||||||
*pywrap_tensorflow.TFE_Py_PackJVPs(tensors))
|
|
||||||
|
|
||||||
|
|
||||||
@contextlib.contextmanager
|
@contextlib.contextmanager
|
||||||
@ -73,7 +72,7 @@ def push_forwardprop_state():
|
|||||||
None (used for its side effect).
|
None (used for its side effect).
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
pywrap_tensorflow.TFE_Py_ForwardAccumulatorPushState()
|
pywrap_tfe.TFE_Py_ForwardAccumulatorPushState()
|
||||||
yield
|
yield
|
||||||
finally:
|
finally:
|
||||||
pywrap_tensorflow.TFE_Py_ForwardAccumulatorPopState()
|
pywrap_tfe.TFE_Py_ForwardAccumulatorPopState()
|
||||||
|
@ -32,8 +32,9 @@ from six.moves import map
|
|||||||
|
|
||||||
from tensorflow.core.framework import attr_value_pb2
|
from tensorflow.core.framework import attr_value_pb2
|
||||||
from tensorflow.core.framework import function_pb2
|
from tensorflow.core.framework import function_pb2
|
||||||
from tensorflow.python import pywrap_tensorflow
|
from tensorflow.python import pywrap_tfe
|
||||||
from tensorflow.python import _pywrap_utils
|
from tensorflow.python import _pywrap_utils
|
||||||
|
from tensorflow.python import pywrap_tensorflow
|
||||||
from tensorflow.python.eager import backprop
|
from tensorflow.python.eager import backprop
|
||||||
from tensorflow.python.eager import backprop_util
|
from tensorflow.python.eager import backprop_util
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
@ -1098,7 +1099,7 @@ class _TapeGradientFunctions(object):
|
|||||||
forward_function.signature.name,
|
forward_function.signature.name,
|
||||||
forward_outputs, forward_inputs, py_backward, None)
|
forward_outputs, forward_inputs, py_backward, None)
|
||||||
output_indices, output_tangents = (
|
output_indices, output_tangents = (
|
||||||
pywrap_tensorflow.TFE_Py_PackJVPs(forward_outputs))
|
pywrap_tfe.TFE_Py_PackJVPs(forward_outputs))
|
||||||
output_tangents = [forward_wrapper_graph.capture(t)
|
output_tangents = [forward_wrapper_graph.capture(t)
|
||||||
for t in output_tangents]
|
for t in output_tangents]
|
||||||
return _ForwardWrapper(
|
return _ForwardWrapper(
|
||||||
@ -1732,7 +1733,7 @@ class ConcreteFunction(object):
|
|||||||
"Tensor." % (self._func_graph.name, i, str(arg)))
|
"Tensor." % (self._func_graph.name, i, str(arg)))
|
||||||
args = tensor_inputs + captured_inputs
|
args = tensor_inputs + captured_inputs
|
||||||
possible_gradient_type = (
|
possible_gradient_type = (
|
||||||
pywrap_tensorflow.TFE_Py_TapeSetPossibleGradientTypes(args))
|
pywrap_tfe.TFE_Py_TapeSetPossibleGradientTypes(args))
|
||||||
if (possible_gradient_type == _POSSIBLE_GRADIENT_TYPES_NONE
|
if (possible_gradient_type == _POSSIBLE_GRADIENT_TYPES_NONE
|
||||||
and executing_eagerly):
|
and executing_eagerly):
|
||||||
# No tape is watching; skip to running the function.
|
# No tape is watching; skip to running the function.
|
||||||
@ -2552,8 +2553,8 @@ class Function(object):
|
|||||||
"""Computes the cache key given inputs and execution context."""
|
"""Computes the cache key given inputs and execution context."""
|
||||||
if self.input_signature is None:
|
if self.input_signature is None:
|
||||||
inputs = (args, kwargs) if kwargs else args
|
inputs = (args, kwargs) if kwargs else args
|
||||||
input_signature = pywrap_tensorflow.TFE_Py_EncodeArg(
|
input_signature = pywrap_tfe.TFE_Py_EncodeArg(inputs,
|
||||||
inputs, include_tensor_ranks_only)
|
include_tensor_ranks_only)
|
||||||
else:
|
else:
|
||||||
del args, kwargs
|
del args, kwargs
|
||||||
assert not include_tensor_ranks_only
|
assert not include_tensor_ranks_only
|
||||||
|
@ -20,7 +20,7 @@ from __future__ import print_function
|
|||||||
|
|
||||||
import collections
|
import collections
|
||||||
|
|
||||||
from tensorflow.python import pywrap_tensorflow
|
from tensorflow.python import pywrap_tfe
|
||||||
from tensorflow.python.ops.unconnected_gradients import UnconnectedGradients
|
from tensorflow.python.ops.unconnected_gradients import UnconnectedGradients
|
||||||
from tensorflow.python.util import compat
|
from tensorflow.python.util import compat
|
||||||
|
|
||||||
@ -68,7 +68,7 @@ def imperative_grad(tape,
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Unknown value for unconnected_gradients: %r" % unconnected_gradients)
|
"Unknown value for unconnected_gradients: %r" % unconnected_gradients)
|
||||||
|
|
||||||
return pywrap_tensorflow.TFE_Py_TapeGradient(
|
return pywrap_tfe.TFE_Py_TapeGradient(
|
||||||
tape._tape, # pylint: disable=protected-access
|
tape._tape, # pylint: disable=protected-access
|
||||||
target,
|
target,
|
||||||
sources,
|
sources,
|
||||||
|
@ -21,80 +21,80 @@ from __future__ import print_function
|
|||||||
import collections
|
import collections
|
||||||
|
|
||||||
from tensorflow.core.framework import summary_pb2
|
from tensorflow.core.framework import summary_pb2
|
||||||
from tensorflow.python import pywrap_tensorflow
|
from tensorflow.python import pywrap_tfe
|
||||||
from tensorflow.python.framework import c_api_util
|
from tensorflow.python.eager import eager_util as c_api_util
|
||||||
from tensorflow.python.util import compat
|
from tensorflow.python.util import compat
|
||||||
|
|
||||||
_MetricMethod = collections.namedtuple('MetricMethod', 'create delete get_cell')
|
_MetricMethod = collections.namedtuple('MetricMethod', 'create delete get_cell')
|
||||||
_counter_methods = [
|
_counter_methods = [
|
||||||
_MetricMethod(
|
_MetricMethod(
|
||||||
create=pywrap_tensorflow.TFE_MonitoringNewCounter0,
|
create=pywrap_tfe.TFE_MonitoringNewCounter0,
|
||||||
delete=pywrap_tensorflow.TFE_MonitoringDeleteCounter0,
|
delete=pywrap_tfe.TFE_MonitoringDeleteCounter0,
|
||||||
get_cell=pywrap_tensorflow.TFE_MonitoringGetCellCounter0),
|
get_cell=pywrap_tfe.TFE_MonitoringGetCellCounter0),
|
||||||
_MetricMethod(
|
_MetricMethod(
|
||||||
create=pywrap_tensorflow.TFE_MonitoringNewCounter1,
|
create=pywrap_tfe.TFE_MonitoringNewCounter1,
|
||||||
delete=pywrap_tensorflow.TFE_MonitoringDeleteCounter1,
|
delete=pywrap_tfe.TFE_MonitoringDeleteCounter1,
|
||||||
get_cell=pywrap_tensorflow.TFE_MonitoringGetCellCounter1),
|
get_cell=pywrap_tfe.TFE_MonitoringGetCellCounter1),
|
||||||
_MetricMethod(
|
_MetricMethod(
|
||||||
create=pywrap_tensorflow.TFE_MonitoringNewCounter2,
|
create=pywrap_tfe.TFE_MonitoringNewCounter2,
|
||||||
delete=pywrap_tensorflow.TFE_MonitoringDeleteCounter2,
|
delete=pywrap_tfe.TFE_MonitoringDeleteCounter2,
|
||||||
get_cell=pywrap_tensorflow.TFE_MonitoringGetCellCounter2),
|
get_cell=pywrap_tfe.TFE_MonitoringGetCellCounter2),
|
||||||
]
|
]
|
||||||
_int_gauge_methods = [
|
_int_gauge_methods = [
|
||||||
_MetricMethod(
|
_MetricMethod(
|
||||||
create=pywrap_tensorflow.TFE_MonitoringNewIntGauge0,
|
create=pywrap_tfe.TFE_MonitoringNewIntGauge0,
|
||||||
delete=pywrap_tensorflow.TFE_MonitoringDeleteIntGauge0,
|
delete=pywrap_tfe.TFE_MonitoringDeleteIntGauge0,
|
||||||
get_cell=pywrap_tensorflow.TFE_MonitoringGetCellIntGauge0),
|
get_cell=pywrap_tfe.TFE_MonitoringGetCellIntGauge0),
|
||||||
_MetricMethod(
|
_MetricMethod(
|
||||||
create=pywrap_tensorflow.TFE_MonitoringNewIntGauge1,
|
create=pywrap_tfe.TFE_MonitoringNewIntGauge1,
|
||||||
delete=pywrap_tensorflow.TFE_MonitoringDeleteIntGauge1,
|
delete=pywrap_tfe.TFE_MonitoringDeleteIntGauge1,
|
||||||
get_cell=pywrap_tensorflow.TFE_MonitoringGetCellIntGauge1),
|
get_cell=pywrap_tfe.TFE_MonitoringGetCellIntGauge1),
|
||||||
_MetricMethod(
|
_MetricMethod(
|
||||||
create=pywrap_tensorflow.TFE_MonitoringNewIntGauge2,
|
create=pywrap_tfe.TFE_MonitoringNewIntGauge2,
|
||||||
delete=pywrap_tensorflow.TFE_MonitoringDeleteIntGauge2,
|
delete=pywrap_tfe.TFE_MonitoringDeleteIntGauge2,
|
||||||
get_cell=pywrap_tensorflow.TFE_MonitoringGetCellIntGauge2),
|
get_cell=pywrap_tfe.TFE_MonitoringGetCellIntGauge2),
|
||||||
]
|
]
|
||||||
_string_gauge_methods = [
|
_string_gauge_methods = [
|
||||||
_MetricMethod(
|
_MetricMethod(
|
||||||
create=pywrap_tensorflow.TFE_MonitoringNewStringGauge0,
|
create=pywrap_tfe.TFE_MonitoringNewStringGauge0,
|
||||||
delete=pywrap_tensorflow.TFE_MonitoringDeleteStringGauge0,
|
delete=pywrap_tfe.TFE_MonitoringDeleteStringGauge0,
|
||||||
get_cell=pywrap_tensorflow.TFE_MonitoringGetCellStringGauge0),
|
get_cell=pywrap_tfe.TFE_MonitoringGetCellStringGauge0),
|
||||||
_MetricMethod(
|
_MetricMethod(
|
||||||
create=pywrap_tensorflow.TFE_MonitoringNewStringGauge1,
|
create=pywrap_tfe.TFE_MonitoringNewStringGauge1,
|
||||||
delete=pywrap_tensorflow.TFE_MonitoringDeleteStringGauge1,
|
delete=pywrap_tfe.TFE_MonitoringDeleteStringGauge1,
|
||||||
get_cell=pywrap_tensorflow.TFE_MonitoringGetCellStringGauge1),
|
get_cell=pywrap_tfe.TFE_MonitoringGetCellStringGauge1),
|
||||||
_MetricMethod(
|
_MetricMethod(
|
||||||
create=pywrap_tensorflow.TFE_MonitoringNewStringGauge2,
|
create=pywrap_tfe.TFE_MonitoringNewStringGauge2,
|
||||||
delete=pywrap_tensorflow.TFE_MonitoringDeleteStringGauge2,
|
delete=pywrap_tfe.TFE_MonitoringDeleteStringGauge2,
|
||||||
get_cell=pywrap_tensorflow.TFE_MonitoringGetCellStringGauge2),
|
get_cell=pywrap_tfe.TFE_MonitoringGetCellStringGauge2),
|
||||||
]
|
]
|
||||||
_bool_gauge_methods = [
|
_bool_gauge_methods = [
|
||||||
_MetricMethod(
|
_MetricMethod(
|
||||||
create=pywrap_tensorflow.TFE_MonitoringNewBoolGauge0,
|
create=pywrap_tfe.TFE_MonitoringNewBoolGauge0,
|
||||||
delete=pywrap_tensorflow.TFE_MonitoringDeleteBoolGauge0,
|
delete=pywrap_tfe.TFE_MonitoringDeleteBoolGauge0,
|
||||||
get_cell=pywrap_tensorflow.TFE_MonitoringGetCellBoolGauge0),
|
get_cell=pywrap_tfe.TFE_MonitoringGetCellBoolGauge0),
|
||||||
_MetricMethod(
|
_MetricMethod(
|
||||||
create=pywrap_tensorflow.TFE_MonitoringNewBoolGauge1,
|
create=pywrap_tfe.TFE_MonitoringNewBoolGauge1,
|
||||||
delete=pywrap_tensorflow.TFE_MonitoringDeleteBoolGauge1,
|
delete=pywrap_tfe.TFE_MonitoringDeleteBoolGauge1,
|
||||||
get_cell=pywrap_tensorflow.TFE_MonitoringGetCellBoolGauge1),
|
get_cell=pywrap_tfe.TFE_MonitoringGetCellBoolGauge1),
|
||||||
_MetricMethod(
|
_MetricMethod(
|
||||||
create=pywrap_tensorflow.TFE_MonitoringNewBoolGauge2,
|
create=pywrap_tfe.TFE_MonitoringNewBoolGauge2,
|
||||||
delete=pywrap_tensorflow.TFE_MonitoringDeleteBoolGauge2,
|
delete=pywrap_tfe.TFE_MonitoringDeleteBoolGauge2,
|
||||||
get_cell=pywrap_tensorflow.TFE_MonitoringGetCellBoolGauge2),
|
get_cell=pywrap_tfe.TFE_MonitoringGetCellBoolGauge2),
|
||||||
]
|
]
|
||||||
_sampler_methods = [
|
_sampler_methods = [
|
||||||
_MetricMethod(
|
_MetricMethod(
|
||||||
create=pywrap_tensorflow.TFE_MonitoringNewSampler0,
|
create=pywrap_tfe.TFE_MonitoringNewSampler0,
|
||||||
delete=pywrap_tensorflow.TFE_MonitoringDeleteSampler0,
|
delete=pywrap_tfe.TFE_MonitoringDeleteSampler0,
|
||||||
get_cell=pywrap_tensorflow.TFE_MonitoringGetCellSampler0),
|
get_cell=pywrap_tfe.TFE_MonitoringGetCellSampler0),
|
||||||
_MetricMethod(
|
_MetricMethod(
|
||||||
create=pywrap_tensorflow.TFE_MonitoringNewSampler1,
|
create=pywrap_tfe.TFE_MonitoringNewSampler1,
|
||||||
delete=pywrap_tensorflow.TFE_MonitoringDeleteSampler1,
|
delete=pywrap_tfe.TFE_MonitoringDeleteSampler1,
|
||||||
get_cell=pywrap_tensorflow.TFE_MonitoringGetCellSampler1),
|
get_cell=pywrap_tfe.TFE_MonitoringGetCellSampler1),
|
||||||
_MetricMethod(
|
_MetricMethod(
|
||||||
create=pywrap_tensorflow.TFE_MonitoringNewSampler2,
|
create=pywrap_tfe.TFE_MonitoringNewSampler2,
|
||||||
delete=pywrap_tensorflow.TFE_MonitoringDeleteSampler2,
|
delete=pywrap_tfe.TFE_MonitoringDeleteSampler2,
|
||||||
get_cell=pywrap_tensorflow.TFE_MonitoringGetCellSampler2),
|
get_cell=pywrap_tfe.TFE_MonitoringGetCellSampler2),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@ -156,11 +156,11 @@ class CounterCell(object):
|
|||||||
Args:
|
Args:
|
||||||
value: non-negative value.
|
value: non-negative value.
|
||||||
"""
|
"""
|
||||||
pywrap_tensorflow.TFE_MonitoringCounterCellIncrementBy(self._cell, value)
|
pywrap_tfe.TFE_MonitoringCounterCellIncrementBy(self._cell, value)
|
||||||
|
|
||||||
def value(self):
|
def value(self):
|
||||||
"""Retrieves the current value."""
|
"""Retrieves the current value."""
|
||||||
return pywrap_tensorflow.TFE_MonitoringCounterCellValue(self._cell)
|
return pywrap_tfe.TFE_MonitoringCounterCellValue(self._cell)
|
||||||
|
|
||||||
|
|
||||||
class Counter(Metric):
|
class Counter(Metric):
|
||||||
@ -204,11 +204,11 @@ class IntGaugeCell(object):
|
|||||||
Args:
|
Args:
|
||||||
value: integer value.
|
value: integer value.
|
||||||
"""
|
"""
|
||||||
pywrap_tensorflow.TFE_MonitoringIntGaugeCellSet(self._cell, value)
|
pywrap_tfe.TFE_MonitoringIntGaugeCellSet(self._cell, value)
|
||||||
|
|
||||||
def value(self):
|
def value(self):
|
||||||
"""Retrieves the current value."""
|
"""Retrieves the current value."""
|
||||||
return pywrap_tensorflow.TFE_MonitoringIntGaugeCellValue(self._cell)
|
return pywrap_tfe.TFE_MonitoringIntGaugeCellValue(self._cell)
|
||||||
|
|
||||||
|
|
||||||
class IntGauge(Metric):
|
class IntGauge(Metric):
|
||||||
@ -252,13 +252,13 @@ class StringGaugeCell(object):
|
|||||||
Args:
|
Args:
|
||||||
value: string value.
|
value: string value.
|
||||||
"""
|
"""
|
||||||
pywrap_tensorflow.TFE_MonitoringStringGaugeCellSet(self._cell, value)
|
pywrap_tfe.TFE_MonitoringStringGaugeCellSet(self._cell, value)
|
||||||
|
|
||||||
def value(self):
|
def value(self):
|
||||||
"""Retrieves the current value."""
|
"""Retrieves the current value."""
|
||||||
with c_api_util.tf_buffer() as buffer_:
|
with c_api_util.tf_buffer() as buffer_:
|
||||||
pywrap_tensorflow.TFE_MonitoringStringGaugeCellValue(self._cell, buffer_)
|
pywrap_tfe.TFE_MonitoringStringGaugeCellValue(self._cell, buffer_)
|
||||||
value = pywrap_tensorflow.TF_GetBuffer(buffer_).decode('utf-8')
|
value = pywrap_tfe.TF_GetBuffer(buffer_).decode('utf-8')
|
||||||
return value
|
return value
|
||||||
|
|
||||||
|
|
||||||
@ -303,11 +303,11 @@ class BoolGaugeCell(object):
|
|||||||
Args:
|
Args:
|
||||||
value: bool value.
|
value: bool value.
|
||||||
"""
|
"""
|
||||||
pywrap_tensorflow.TFE_MonitoringBoolGaugeCellSet(self._cell, value)
|
pywrap_tfe.TFE_MonitoringBoolGaugeCellSet(self._cell, value)
|
||||||
|
|
||||||
def value(self):
|
def value(self):
|
||||||
"""Retrieves the current value."""
|
"""Retrieves the current value."""
|
||||||
return pywrap_tensorflow.TFE_MonitoringBoolGaugeCellValue(self._cell)
|
return pywrap_tfe.TFE_MonitoringBoolGaugeCellValue(self._cell)
|
||||||
|
|
||||||
|
|
||||||
class BoolGauge(Metric):
|
class BoolGauge(Metric):
|
||||||
@ -351,7 +351,7 @@ class SamplerCell(object):
|
|||||||
Args:
|
Args:
|
||||||
value: float value.
|
value: float value.
|
||||||
"""
|
"""
|
||||||
pywrap_tensorflow.TFE_MonitoringSamplerCellAdd(self._cell, value)
|
pywrap_tfe.TFE_MonitoringSamplerCellAdd(self._cell, value)
|
||||||
|
|
||||||
def value(self):
|
def value(self):
|
||||||
"""Retrieves the current distribution of samples.
|
"""Retrieves the current distribution of samples.
|
||||||
@ -360,8 +360,8 @@ class SamplerCell(object):
|
|||||||
A HistogramProto describing the distribution of samples.
|
A HistogramProto describing the distribution of samples.
|
||||||
"""
|
"""
|
||||||
with c_api_util.tf_buffer() as buffer_:
|
with c_api_util.tf_buffer() as buffer_:
|
||||||
pywrap_tensorflow.TFE_MonitoringSamplerCellValue(self._cell, buffer_)
|
pywrap_tfe.TFE_MonitoringSamplerCellValue(self._cell, buffer_)
|
||||||
proto_data = pywrap_tensorflow.TF_GetBuffer(buffer_)
|
proto_data = pywrap_tfe.TF_GetBuffer(buffer_)
|
||||||
histogram_proto = summary_pb2.HistogramProto()
|
histogram_proto = summary_pb2.HistogramProto()
|
||||||
histogram_proto.ParseFromString(compat.as_bytes(proto_data))
|
histogram_proto.ParseFromString(compat.as_bytes(proto_data))
|
||||||
return histogram_proto
|
return histogram_proto
|
||||||
@ -379,7 +379,7 @@ class Buckets(object):
|
|||||||
self.buckets = buckets
|
self.buckets = buckets
|
||||||
|
|
||||||
def __del__(self):
|
def __del__(self):
|
||||||
pywrap_tensorflow.TFE_MonitoringDeleteBuckets(self.buckets)
|
pywrap_tfe.TFE_MonitoringDeleteBuckets(self.buckets)
|
||||||
|
|
||||||
|
|
||||||
class ExponentialBuckets(Buckets):
|
class ExponentialBuckets(Buckets):
|
||||||
@ -399,8 +399,8 @@ class ExponentialBuckets(Buckets):
|
|||||||
bucket_count: integer
|
bucket_count: integer
|
||||||
"""
|
"""
|
||||||
super(ExponentialBuckets, self).__init__(
|
super(ExponentialBuckets, self).__init__(
|
||||||
pywrap_tensorflow.TFE_MonitoringNewExponentialBuckets(
|
pywrap_tfe.TFE_MonitoringNewExponentialBuckets(scale, growth_factor,
|
||||||
scale, growth_factor, bucket_count))
|
bucket_count))
|
||||||
|
|
||||||
|
|
||||||
class Sampler(Metric):
|
class Sampler(Metric):
|
||||||
|
@ -39,9 +39,9 @@ import os
|
|||||||
import threading
|
import threading
|
||||||
|
|
||||||
from tensorflow.python import _pywrap_events_writer
|
from tensorflow.python import _pywrap_events_writer
|
||||||
from tensorflow.python import pywrap_tensorflow
|
from tensorflow.python import pywrap_tfe
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.framework import c_api_util
|
from tensorflow.python.eager import eager_util as c_api_util
|
||||||
from tensorflow.python.platform import gfile
|
from tensorflow.python.platform import gfile
|
||||||
from tensorflow.python.platform import tf_logging as logging
|
from tensorflow.python.platform import tf_logging as logging
|
||||||
from tensorflow.python.util import compat
|
from tensorflow.python.util import compat
|
||||||
@ -74,8 +74,8 @@ def start():
|
|||||||
raise ProfilerAlreadyRunningError('Another profiler is running.')
|
raise ProfilerAlreadyRunningError('Another profiler is running.')
|
||||||
if context.default_execution_mode == context.EAGER_MODE:
|
if context.default_execution_mode == context.EAGER_MODE:
|
||||||
context.ensure_initialized()
|
context.ensure_initialized()
|
||||||
_profiler = pywrap_tensorflow.TFE_NewProfiler()
|
_profiler = pywrap_tfe.TFE_NewProfiler()
|
||||||
if not pywrap_tensorflow.TFE_ProfilerIsOk(_profiler):
|
if not pywrap_tfe.TFE_ProfilerIsOk(_profiler):
|
||||||
logging.warning('Another profiler session is running which is probably '
|
logging.warning('Another profiler session is running which is probably '
|
||||||
'created by profiler server. Please avoid using profiler '
|
'created by profiler server. Please avoid using profiler '
|
||||||
'server and profiler APIs at the same time.')
|
'server and profiler APIs at the same time.')
|
||||||
@ -100,11 +100,9 @@ def stop():
|
|||||||
if context.default_execution_mode == context.EAGER_MODE:
|
if context.default_execution_mode == context.EAGER_MODE:
|
||||||
context.context().executor.wait()
|
context.context().executor.wait()
|
||||||
with c_api_util.tf_buffer() as buffer_:
|
with c_api_util.tf_buffer() as buffer_:
|
||||||
pywrap_tensorflow.TFE_ProfilerSerializeToString(
|
pywrap_tfe.TFE_ProfilerSerializeToString(_profiler, buffer_)
|
||||||
_profiler,
|
result = pywrap_tfe.TF_GetBuffer(buffer_)
|
||||||
buffer_)
|
pywrap_tfe.TFE_DeleteProfiler(_profiler)
|
||||||
result = pywrap_tensorflow.TF_GetBuffer(buffer_)
|
|
||||||
pywrap_tensorflow.TFE_DeleteProfiler(_profiler)
|
|
||||||
_profiler = None
|
_profiler = None
|
||||||
_run_num += 1
|
_run_num += 1
|
||||||
return result
|
return result
|
||||||
@ -159,7 +157,7 @@ def start_profiler_server(port):
|
|||||||
"""
|
"""
|
||||||
if context.default_execution_mode == context.EAGER_MODE:
|
if context.default_execution_mode == context.EAGER_MODE:
|
||||||
context.ensure_initialized()
|
context.ensure_initialized()
|
||||||
pywrap_tensorflow.TFE_StartProfilerServer(port)
|
pywrap_tfe.TFE_StartProfilerServer(port)
|
||||||
|
|
||||||
|
|
||||||
class Profiler(object):
|
class Profiler(object):
|
||||||
|
@ -18,8 +18,8 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
from tensorflow.python import pywrap_tensorflow
|
from tensorflow.python import pywrap_tfe
|
||||||
from tensorflow.python.framework import c_api_util
|
from tensorflow.python.eager import eager_util as c_api_util
|
||||||
from tensorflow.python.framework import errors
|
from tensorflow.python.framework import errors
|
||||||
|
|
||||||
|
|
||||||
@ -46,7 +46,7 @@ def start_tracing(service_addr,
|
|||||||
Raises:
|
Raises:
|
||||||
UnavailableError: If no trace event is collected.
|
UnavailableError: If no trace event is collected.
|
||||||
"""
|
"""
|
||||||
if not pywrap_tensorflow.TFE_ProfilerClientStartTracing(
|
if not pywrap_tfe.TFE_ProfilerClientStartTracing(
|
||||||
service_addr, logdir, worker_list, include_dataset_ops, duration_ms,
|
service_addr, logdir, worker_list, include_dataset_ops, duration_ms,
|
||||||
num_tracing_attempts):
|
num_tracing_attempts):
|
||||||
raise errors.UnavailableError(None, None, 'No trace event is collected.')
|
raise errors.UnavailableError(None, None, 'No trace event is collected.')
|
||||||
@ -71,7 +71,7 @@ def monitor(service_addr,
|
|||||||
A string of monitoring output.
|
A string of monitoring output.
|
||||||
"""
|
"""
|
||||||
with c_api_util.tf_buffer() as buffer_:
|
with c_api_util.tf_buffer() as buffer_:
|
||||||
pywrap_tensorflow.TFE_ProfilerClientMonitor(service_addr, duration_ms,
|
pywrap_tfe.TFE_ProfilerClientMonitor(service_addr, duration_ms,
|
||||||
monitoring_level,
|
monitoring_level, display_timestamp,
|
||||||
display_timestamp, buffer_)
|
buffer_)
|
||||||
return pywrap_tensorflow.TF_GetBuffer(buffer_)
|
return pywrap_tfe.TF_GetBuffer(buffer_)
|
||||||
|
@ -20,7 +20,7 @@ from __future__ import print_function
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from tensorflow.python import pywrap_tensorflow
|
from tensorflow.python import pywrap_tfe
|
||||||
from tensorflow.python.eager import backprop
|
from tensorflow.python.eager import backprop
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.eager import core
|
from tensorflow.python.eager import core
|
||||||
@ -54,14 +54,16 @@ class Tests(test.TestCase):
|
|||||||
|
|
||||||
self.assertAllClose(
|
self.assertAllClose(
|
||||||
math_ops.matmul(a_2_by_2, b_2_by_2),
|
math_ops.matmul(a_2_by_2, b_2_by_2),
|
||||||
pywrap_tensorflow.TFE_Py_FastPathExecute(
|
pywrap_tfe.TFE_Py_FastPathExecute(ctx._handle, ctx.device_name,
|
||||||
ctx._handle, ctx.device_name, "MatMul", None, None, a_2_by_2,
|
"MatMul", None, None, a_2_by_2,
|
||||||
b_2_by_2, "transpose_a", False, "transpose_b", False))
|
b_2_by_2, "transpose_a", False,
|
||||||
|
"transpose_b", False))
|
||||||
self.assertAllClose(
|
self.assertAllClose(
|
||||||
math_ops.matmul(a_100_by_784, b_100_by_784, transpose_b=True),
|
math_ops.matmul(a_100_by_784, b_100_by_784, transpose_b=True),
|
||||||
pywrap_tensorflow.TFE_Py_FastPathExecute(
|
pywrap_tfe.TFE_Py_FastPathExecute(ctx._handle, ctx.device_name,
|
||||||
ctx._handle, ctx.device_name, "MatMul", None, None, a_100_by_784,
|
"MatMul", None, None, a_100_by_784,
|
||||||
b_100_by_784, "transpose_a", False, "transpose_b", True))
|
b_100_by_784, "transpose_a", False,
|
||||||
|
"transpose_b", True))
|
||||||
|
|
||||||
@test_util.assert_no_new_tensors
|
@test_util.assert_no_new_tensors
|
||||||
@test_util.assert_no_garbage_created
|
@test_util.assert_no_garbage_created
|
||||||
@ -71,12 +73,14 @@ class Tests(test.TestCase):
|
|||||||
|
|
||||||
a_2_by_2 = constant_op.constant(1.0, shape=[2, 2])
|
a_2_by_2 = constant_op.constant(1.0, shape=[2, 2])
|
||||||
m = resource_variable_ops.ResourceVariable(a_2_by_2)
|
m = resource_variable_ops.ResourceVariable(a_2_by_2)
|
||||||
x = pywrap_tensorflow.TFE_Py_FastPathExecute(
|
x = pywrap_tfe.TFE_Py_FastPathExecute(ctx._handle, ctx.device_name,
|
||||||
ctx._handle, ctx.device_name, "MatMul", None, None, m, m, "transpose_a",
|
"MatMul", None, None, m, m,
|
||||||
False, "transpose_b", False)
|
"transpose_a", False, "transpose_b",
|
||||||
y = pywrap_tensorflow.TFE_Py_FastPathExecute(
|
False)
|
||||||
ctx._handle, ctx.device_name, "MatMul", None, None, a_2_by_2, a_2_by_2,
|
y = pywrap_tfe.TFE_Py_FastPathExecute(ctx._handle, ctx.device_name,
|
||||||
"transpose_a", False, "transpose_b", False)
|
"MatMul", None, None, a_2_by_2,
|
||||||
|
a_2_by_2, "transpose_a", False,
|
||||||
|
"transpose_b", False)
|
||||||
|
|
||||||
self.assertAllEqual(x, y)
|
self.assertAllEqual(x, y)
|
||||||
|
|
||||||
@ -89,9 +93,10 @@ class Tests(test.TestCase):
|
|||||||
with backprop.GradientTape(persistent=True) as tape:
|
with backprop.GradientTape(persistent=True) as tape:
|
||||||
a_2_by_2 = constant_op.constant(1.0, shape=[2, 2])
|
a_2_by_2 = constant_op.constant(1.0, shape=[2, 2])
|
||||||
tape.watch(a_2_by_2)
|
tape.watch(a_2_by_2)
|
||||||
z = pywrap_tensorflow.TFE_Py_FastPathExecute(
|
z = pywrap_tfe.TFE_Py_FastPathExecute(ctx._handle, ctx.device_name,
|
||||||
ctx._handle, ctx.device_name, "MatMul", None, None, a_2_by_2,
|
"MatMul", None, None, a_2_by_2,
|
||||||
a_2_by_2, "transpose_a", False, "transpose_b", False)
|
a_2_by_2, "transpose_a", False,
|
||||||
|
"transpose_b", False)
|
||||||
dz_dy = tape.gradient(z, [a_2_by_2])[0]
|
dz_dy = tape.gradient(z, [a_2_by_2])[0]
|
||||||
self.assertAllEqual(dz_dy.numpy(),
|
self.assertAllEqual(dz_dy.numpy(),
|
||||||
constant_op.constant(4.0, shape=[2, 2]).numpy())
|
constant_op.constant(4.0, shape=[2, 2]).numpy())
|
||||||
@ -106,9 +111,10 @@ class Tests(test.TestCase):
|
|||||||
a_2_by_2 = constant_op.constant(1.0, shape=[2, 2])
|
a_2_by_2 = constant_op.constant(1.0, shape=[2, 2])
|
||||||
m = resource_variable_ops.ResourceVariable(a_2_by_2)
|
m = resource_variable_ops.ResourceVariable(a_2_by_2)
|
||||||
tape.watch(m)
|
tape.watch(m)
|
||||||
z = pywrap_tensorflow.TFE_Py_FastPathExecute(
|
z = pywrap_tfe.TFE_Py_FastPathExecute(ctx._handle, ctx.device_name,
|
||||||
ctx._handle, ctx.device_name, "MatMul", None, None, m, m,
|
"MatMul", None, None, m, m,
|
||||||
"transpose_a", False, "transpose_b", False)
|
"transpose_a", False, "transpose_b",
|
||||||
|
False)
|
||||||
dz_dy = tape.gradient(z, [m])[0]
|
dz_dy = tape.gradient(z, [m])[0]
|
||||||
self.assertAllEqual(dz_dy.numpy(),
|
self.assertAllEqual(dz_dy.numpy(),
|
||||||
constant_op.constant(4.0, shape=[2, 2]).numpy())
|
constant_op.constant(4.0, shape=[2, 2]).numpy())
|
||||||
@ -125,9 +131,8 @@ class Tests(test.TestCase):
|
|||||||
|
|
||||||
self.assertAllClose(
|
self.assertAllClose(
|
||||||
math_ops.add_n([a_2_by_2, b_2_by_2]),
|
math_ops.add_n([a_2_by_2, b_2_by_2]),
|
||||||
pywrap_tensorflow.TFE_Py_FastPathExecute(ctx._handle, ctx.device_name,
|
pywrap_tfe.TFE_Py_FastPathExecute(ctx._handle, ctx.device_name, "AddN",
|
||||||
"AddN", None, None,
|
None, None, [a_2_by_2, b_2_by_2]))
|
||||||
[a_2_by_2, b_2_by_2]))
|
|
||||||
|
|
||||||
# Tests homogeneous list op
|
# Tests homogeneous list op
|
||||||
@test_util.assert_no_new_tensors
|
@test_util.assert_no_new_tensors
|
||||||
@ -142,9 +147,9 @@ class Tests(test.TestCase):
|
|||||||
with backprop.GradientTape(persistent=True) as tape:
|
with backprop.GradientTape(persistent=True) as tape:
|
||||||
tape.watch(a_2_by_2)
|
tape.watch(a_2_by_2)
|
||||||
tape.watch(b_2_by_2)
|
tape.watch(b_2_by_2)
|
||||||
z1 = pywrap_tensorflow.TFE_Py_FastPathExecute(
|
z1 = pywrap_tfe.TFE_Py_FastPathExecute(ctx._handle, ctx.device_name,
|
||||||
ctx._handle, ctx.device_name, "AddN", None, None,
|
"AddN", None, None,
|
||||||
[a_2_by_2, b_2_by_2])
|
[a_2_by_2, b_2_by_2])
|
||||||
z2 = math_ops.add_n([a_2_by_2, b_2_by_2])
|
z2 = math_ops.add_n([a_2_by_2, b_2_by_2])
|
||||||
dz1_dy = tape.gradient(z1, [a_2_by_2])[0]
|
dz1_dy = tape.gradient(z1, [a_2_by_2])[0]
|
||||||
dz2_dy = tape.gradient(z2, [a_2_by_2])[0]
|
dz2_dy = tape.gradient(z2, [a_2_by_2])[0]
|
||||||
@ -162,9 +167,9 @@ class Tests(test.TestCase):
|
|||||||
|
|
||||||
self.assertAllClose(
|
self.assertAllClose(
|
||||||
array_ops.identity_n([a_2_by_2, b_2_by_2]),
|
array_ops.identity_n([a_2_by_2, b_2_by_2]),
|
||||||
pywrap_tensorflow.TFE_Py_FastPathExecute(ctx._handle, ctx.device_name,
|
pywrap_tfe.TFE_Py_FastPathExecute(ctx._handle, ctx.device_name,
|
||||||
"IdentityN", None, None,
|
"IdentityN", None, None,
|
||||||
[a_2_by_2, b_2_by_2]))
|
[a_2_by_2, b_2_by_2]))
|
||||||
|
|
||||||
# Tests heterogeneous list op
|
# Tests heterogeneous list op
|
||||||
@test_util.assert_no_new_tensors
|
@test_util.assert_no_new_tensors
|
||||||
@ -179,9 +184,9 @@ class Tests(test.TestCase):
|
|||||||
with backprop.GradientTape(persistent=True) as tape:
|
with backprop.GradientTape(persistent=True) as tape:
|
||||||
tape.watch(a_2_by_2)
|
tape.watch(a_2_by_2)
|
||||||
tape.watch(b_2_by_2)
|
tape.watch(b_2_by_2)
|
||||||
z1 = pywrap_tensorflow.TFE_Py_FastPathExecute(
|
z1 = pywrap_tfe.TFE_Py_FastPathExecute(ctx._handle, ctx.device_name,
|
||||||
ctx._handle, ctx.device_name, "IdentityN", None, None,
|
"IdentityN", None, None,
|
||||||
[a_2_by_2, b_2_by_2])
|
[a_2_by_2, b_2_by_2])
|
||||||
z2 = array_ops.identity_n([a_2_by_2, b_2_by_2])
|
z2 = array_ops.identity_n([a_2_by_2, b_2_by_2])
|
||||||
dz1_dy = tape.gradient(z1[0], [a_2_by_2])[0]
|
dz1_dy = tape.gradient(z1[0], [a_2_by_2])[0]
|
||||||
dz2_dy = tape.gradient(z2[0], [a_2_by_2])[0]
|
dz2_dy = tape.gradient(z2[0], [a_2_by_2])[0]
|
||||||
@ -201,19 +206,18 @@ class Tests(test.TestCase):
|
|||||||
# Not enough base params
|
# Not enough base params
|
||||||
with self.assertRaisesRegexp(ValueError,
|
with self.assertRaisesRegexp(ValueError,
|
||||||
"at least 5 items in the input tuple"):
|
"at least 5 items in the input tuple"):
|
||||||
pywrap_tensorflow.TFE_Py_FastPathExecute(ctx_handle, ctx.device_name,
|
pywrap_tfe.TFE_Py_FastPathExecute(ctx_handle, ctx.device_name, "Identity")
|
||||||
"Identity")
|
|
||||||
|
|
||||||
# Not enough inputs
|
# Not enough inputs
|
||||||
with self.assertRaisesRegexp(ValueError,
|
with self.assertRaisesRegexp(ValueError,
|
||||||
"Expected to be at least 6, was 5"):
|
"Expected to be at least 6, was 5"):
|
||||||
pywrap_tensorflow.TFE_Py_FastPathExecute(ctx_handle, ctx_handle,
|
pywrap_tfe.TFE_Py_FastPathExecute(ctx_handle, ctx_handle, "Identity",
|
||||||
"Identity", None, [])
|
None, [])
|
||||||
|
|
||||||
# Bad type
|
# Bad type
|
||||||
with self.assertRaisesRegexp(TypeError, "expected a string for op_name"):
|
with self.assertRaisesRegexp(TypeError, "expected a string for op_name"):
|
||||||
pywrap_tensorflow.TFE_Py_FastPathExecute(ctx_handle, ctx.device_name,
|
pywrap_tfe.TFE_Py_FastPathExecute(ctx_handle, ctx.device_name, ctx_handle,
|
||||||
ctx_handle, None, [], a_2_by_2)
|
None, [], a_2_by_2)
|
||||||
|
|
||||||
@test_util.assert_no_new_tensors
|
@test_util.assert_no_new_tensors
|
||||||
@test_util.assert_no_garbage_created
|
@test_util.assert_no_garbage_created
|
||||||
@ -225,9 +229,9 @@ class Tests(test.TestCase):
|
|||||||
|
|
||||||
ctx_handle = ctx._handle
|
ctx_handle = ctx._handle
|
||||||
with self.assertRaises(core._FallbackException):
|
with self.assertRaises(core._FallbackException):
|
||||||
pywrap_tensorflow.TFE_Py_FastPathExecute(ctx_handle, ctx.device_name,
|
pywrap_tfe.TFE_Py_FastPathExecute(ctx_handle, ctx.device_name, "Split",
|
||||||
"Split", None, None, split_dim,
|
None, None, split_dim, value,
|
||||||
value, "num_split", -1)
|
"num_split", -1)
|
||||||
|
|
||||||
@test_util.assert_no_new_tensors
|
@test_util.assert_no_new_tensors
|
||||||
@test_util.assert_no_garbage_created
|
@test_util.assert_no_garbage_created
|
||||||
@ -266,10 +270,9 @@ class Tests(test.TestCase):
|
|||||||
ctx = context.context()
|
ctx = context.context()
|
||||||
ctx.ensure_initialized()
|
ctx.ensure_initialized()
|
||||||
with self.assertRaises(core._FallbackException):
|
with self.assertRaises(core._FallbackException):
|
||||||
pywrap_tensorflow.TFE_Py_FastPathExecute(ctx._handle, ctx.device_name,
|
pywrap_tfe.TFE_Py_FastPathExecute(ctx._handle, ctx.device_name, "MatMul",
|
||||||
"MatMul", None, None, m, m,
|
None, None, m, m, "transpose_a", False,
|
||||||
"transpose_a", False,
|
"transpose_b", False)
|
||||||
"transpose_b", False)
|
|
||||||
|
|
||||||
def testOpDefDefaultType(self):
|
def testOpDefDefaultType(self):
|
||||||
im = np.random.randint(
|
im = np.random.randint(
|
||||||
|
@ -22,7 +22,7 @@ import copy
|
|||||||
from absl import logging
|
from absl import logging
|
||||||
|
|
||||||
from tensorflow.core.protobuf.tensorflow_server_pb2 import ServerDef
|
from tensorflow.core.protobuf.tensorflow_server_pb2 import ServerDef
|
||||||
from tensorflow.python import pywrap_tensorflow
|
from tensorflow.python import pywrap_tfe
|
||||||
from tensorflow.python.distribute import device_util
|
from tensorflow.python.distribute import device_util
|
||||||
from tensorflow.python.distribute.cluster_resolver import cluster_resolver
|
from tensorflow.python.distribute.cluster_resolver import cluster_resolver
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
@ -127,7 +127,7 @@ def connect_to_cluster(cluster_spec_or_resolver,
|
|||||||
|
|
||||||
# Automatically add local job, if not part of the cluster spec.
|
# Automatically add local job, if not part of the cluster spec.
|
||||||
if job_name not in cluster_spec.jobs:
|
if job_name not in cluster_spec.jobs:
|
||||||
local_port = pywrap_tensorflow.TF_PickUnusedPortOrDie()
|
local_port = pywrap_tfe.TF_PickUnusedPortOrDie()
|
||||||
job_def = cluster_def.job.add()
|
job_def = cluster_def.job.add()
|
||||||
job_def.name = job_name
|
job_def.name = job_name
|
||||||
# TODO(fishx): Update this to make sure remote worker has valid ip address
|
# TODO(fishx): Update this to make sure remote worker has valid ip address
|
||||||
|
@ -20,7 +20,7 @@ from __future__ import print_function
|
|||||||
|
|
||||||
import contextlib
|
import contextlib
|
||||||
|
|
||||||
from tensorflow.python import pywrap_tensorflow
|
from tensorflow.python import pywrap_tfe
|
||||||
from tensorflow.python.util.lazy_loader import LazyLoader
|
from tensorflow.python.util.lazy_loader import LazyLoader
|
||||||
|
|
||||||
# There is a circular dependency between this, ops.py, and
|
# There is a circular dependency between this, ops.py, and
|
||||||
@ -39,24 +39,23 @@ class Tape(object):
|
|||||||
self._tape = tape
|
self._tape = tape
|
||||||
|
|
||||||
def watched_variables(self):
|
def watched_variables(self):
|
||||||
return pywrap_tensorflow.TFE_Py_TapeWatchedVariables(self._tape)
|
return pywrap_tfe.TFE_Py_TapeWatchedVariables(self._tape)
|
||||||
|
|
||||||
|
|
||||||
def push_new_tape(persistent=False, watch_accessed_variables=True):
|
def push_new_tape(persistent=False, watch_accessed_variables=True):
|
||||||
"""Pushes a new tape onto the tape stack."""
|
"""Pushes a new tape onto the tape stack."""
|
||||||
tape = pywrap_tensorflow.TFE_Py_TapeSetNew(persistent,
|
tape = pywrap_tfe.TFE_Py_TapeSetNew(persistent, watch_accessed_variables)
|
||||||
watch_accessed_variables)
|
|
||||||
return Tape(tape)
|
return Tape(tape)
|
||||||
|
|
||||||
|
|
||||||
def push_tape(tape):
|
def push_tape(tape):
|
||||||
"""Pushes an existing tape onto the tape stack."""
|
"""Pushes an existing tape onto the tape stack."""
|
||||||
pywrap_tensorflow.TFE_Py_TapeSetAdd(tape._tape) # pylint: disable=protected-access
|
pywrap_tfe.TFE_Py_TapeSetAdd(tape._tape) # pylint: disable=protected-access
|
||||||
|
|
||||||
|
|
||||||
def watch(tape, tensor):
|
def watch(tape, tensor):
|
||||||
"""Marks this tensor to be watched by the given tape."""
|
"""Marks this tensor to be watched by the given tape."""
|
||||||
pywrap_tensorflow.TFE_Py_TapeWatch(tape._tape, tensor) # pylint: disable=protected-access
|
pywrap_tfe.TFE_Py_TapeWatch(tape._tape, tensor) # pylint: disable=protected-access
|
||||||
|
|
||||||
|
|
||||||
def watch_variable(tape, variable):
|
def watch_variable(tape, variable):
|
||||||
@ -68,7 +67,7 @@ def watch_variable(tape, variable):
|
|||||||
else:
|
else:
|
||||||
variables = strategy.experimental_local_results(variable)
|
variables = strategy.experimental_local_results(variable)
|
||||||
for var in variables:
|
for var in variables:
|
||||||
pywrap_tensorflow.TFE_Py_TapeWatchVariable(tape._tape, var) # pylint: disable=protected-access
|
pywrap_tfe.TFE_Py_TapeWatchVariable(tape._tape, var) # pylint: disable=protected-access
|
||||||
|
|
||||||
|
|
||||||
def variable_accessed(variable):
|
def variable_accessed(variable):
|
||||||
@ -84,7 +83,7 @@ def variable_accessed(variable):
|
|||||||
else:
|
else:
|
||||||
variables = strategy.experimental_local_results(variable)
|
variables = strategy.experimental_local_results(variable)
|
||||||
for var in variables:
|
for var in variables:
|
||||||
pywrap_tensorflow.TFE_Py_TapeVariableAccessed(var)
|
pywrap_tfe.TFE_Py_TapeVariableAccessed(var)
|
||||||
|
|
||||||
|
|
||||||
def variables_accessed(variables):
|
def variables_accessed(variables):
|
||||||
@ -107,25 +106,25 @@ def variables_accessed(variables):
|
|||||||
accessed.extend(strategy.experimental_local_results(variable))
|
accessed.extend(strategy.experimental_local_results(variable))
|
||||||
|
|
||||||
for var in accessed:
|
for var in accessed:
|
||||||
pywrap_tensorflow.TFE_Py_TapeVariableAccessed(var)
|
pywrap_tfe.TFE_Py_TapeVariableAccessed(var)
|
||||||
|
|
||||||
|
|
||||||
def pop_tape(tape):
|
def pop_tape(tape):
|
||||||
"""Pops the given tape in the stack."""
|
"""Pops the given tape in the stack."""
|
||||||
pywrap_tensorflow.TFE_Py_TapeSetRemove(tape._tape) # pylint: disable=protected-access
|
pywrap_tfe.TFE_Py_TapeSetRemove(tape._tape) # pylint: disable=protected-access
|
||||||
|
|
||||||
|
|
||||||
@contextlib.contextmanager
|
@contextlib.contextmanager
|
||||||
def stop_recording():
|
def stop_recording():
|
||||||
"""Stop all gradient recording (backprop and forwardprop)."""
|
"""Stop all gradient recording (backprop and forwardprop)."""
|
||||||
is_stopped = pywrap_tensorflow.TFE_Py_TapeSetIsStopped()
|
is_stopped = pywrap_tfe.TFE_Py_TapeSetIsStopped()
|
||||||
try:
|
try:
|
||||||
if not is_stopped:
|
if not is_stopped:
|
||||||
pywrap_tensorflow.TFE_Py_TapeSetStopOnThread()
|
pywrap_tfe.TFE_Py_TapeSetStopOnThread()
|
||||||
yield
|
yield
|
||||||
finally:
|
finally:
|
||||||
if not is_stopped:
|
if not is_stopped:
|
||||||
pywrap_tensorflow.TFE_Py_TapeSetRestartOnThread()
|
pywrap_tfe.TFE_Py_TapeSetRestartOnThread()
|
||||||
|
|
||||||
|
|
||||||
def should_record_backprop(tensors):
|
def should_record_backprop(tensors):
|
||||||
@ -139,22 +138,23 @@ def should_record_backprop(tensors):
|
|||||||
Returns:
|
Returns:
|
||||||
Boolean, whether any tape watches any of `tensors`.
|
Boolean, whether any tape watches any of `tensors`.
|
||||||
"""
|
"""
|
||||||
return pywrap_tensorflow.TFE_Py_TapeSetShouldRecordBackprop(tensors)
|
return pywrap_tfe.TFE_Py_TapeSetShouldRecordBackprop(tensors)
|
||||||
|
|
||||||
|
|
||||||
def record_operation(op_type, output_tensors, input_tensors, backward_function,
|
def record_operation(op_type, output_tensors, input_tensors, backward_function,
|
||||||
forward_function=None):
|
forward_function=None):
|
||||||
"""Records the operation on all tapes in the stack."""
|
"""Records the operation on all tapes in the stack."""
|
||||||
pywrap_tensorflow.TFE_Py_TapeSetRecordOperation(
|
pywrap_tfe.TFE_Py_TapeSetRecordOperation(op_type, output_tensors,
|
||||||
op_type, output_tensors, input_tensors, backward_function,
|
input_tensors, backward_function,
|
||||||
forward_function)
|
forward_function)
|
||||||
|
|
||||||
|
|
||||||
def record_operation_backprop_only(op_type, output_tensors, input_tensors,
|
def record_operation_backprop_only(op_type, output_tensors, input_tensors,
|
||||||
backward_function):
|
backward_function):
|
||||||
"""Records the operation on all backward tapes in the stack."""
|
"""Records the operation on all backward tapes in the stack."""
|
||||||
pywrap_tensorflow.TFE_Py_TapeSetRecordOperationBackprop(
|
pywrap_tfe.TFE_Py_TapeSetRecordOperationBackprop(op_type, output_tensors,
|
||||||
op_type, output_tensors, input_tensors, backward_function)
|
input_tensors,
|
||||||
|
backward_function)
|
||||||
|
|
||||||
|
|
||||||
def record_operation_forwardprop_only(op_type, output_tensors, input_tensors,
|
def record_operation_forwardprop_only(op_type, output_tensors, input_tensors,
|
||||||
@ -174,16 +174,16 @@ def record_operation_forwardprop_only(op_type, output_tensors, input_tensors,
|
|||||||
Typically these will have come from TFE_Py_PackForwardGradients. May be
|
Typically these will have come from TFE_Py_PackForwardGradients. May be
|
||||||
None or an empty sequence if there are no JVP outputs from the operation.
|
None or an empty sequence if there are no JVP outputs from the operation.
|
||||||
"""
|
"""
|
||||||
pywrap_tensorflow.TFE_Py_TapeSetRecordOperationForwardprop(
|
pywrap_tfe.TFE_Py_TapeSetRecordOperationForwardprop(
|
||||||
op_type, output_tensors, input_tensors, backward_function,
|
op_type, output_tensors, input_tensors, backward_function,
|
||||||
forwardprop_output_indices)
|
forwardprop_output_indices)
|
||||||
|
|
||||||
|
|
||||||
def delete_trace(tensor_id):
|
def delete_trace(tensor_id):
|
||||||
"""Deletes traces for this Tensor from all tapes in the stack."""
|
"""Deletes traces for this Tensor from all tapes in the stack."""
|
||||||
pywrap_tensorflow.TFE_Py_TapeSetDeleteTrace(tensor_id)
|
pywrap_tfe.TFE_Py_TapeSetDeleteTrace(tensor_id)
|
||||||
|
|
||||||
|
|
||||||
def could_possibly_record():
|
def could_possibly_record():
|
||||||
"""Returns True if any tape is active."""
|
"""Returns True if any tape is active."""
|
||||||
return not pywrap_tensorflow.TFE_Py_TapeSetIsEmpty()
|
return not pywrap_tfe.TFE_Py_TapeSetIsEmpty()
|
||||||
|
@ -26,7 +26,7 @@ import unittest
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import six
|
import six
|
||||||
|
|
||||||
from tensorflow.python import pywrap_tensorflow
|
from tensorflow.python import pywrap_tfe
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.eager import core
|
from tensorflow.python.eager import core
|
||||||
from tensorflow.python.eager import test
|
from tensorflow.python.eager import test
|
||||||
@ -435,14 +435,14 @@ class TFETensorUtilTest(test_util.TensorFlowTestCase):
|
|||||||
t2 = _create_tensor([[1, 2, 5], [3, 4, 5]], dtype=dtypes.int32)
|
t2 = _create_tensor([[1, 2, 5], [3, 4, 5]], dtype=dtypes.int32)
|
||||||
t3 = _create_tensor([[1], [3], [5], [6]], dtype=dtypes.int32)
|
t3 = _create_tensor([[1], [3], [5], [6]], dtype=dtypes.int32)
|
||||||
|
|
||||||
r = pywrap_tensorflow.TFE_Py_TensorShapeSlice([t1, t2, t3], 0)
|
r = pywrap_tfe.TFE_Py_TensorShapeSlice([t1, t2, t3], 0)
|
||||||
self.assertAllEqual(np.array([3, 2, 4]), r.numpy())
|
self.assertAllEqual(np.array([3, 2, 4]), r.numpy())
|
||||||
|
|
||||||
r = pywrap_tensorflow.TFE_Py_TensorShapeSlice([t1, t2, t3], 1)
|
r = pywrap_tfe.TFE_Py_TensorShapeSlice([t1, t2, t3], 1)
|
||||||
self.assertAllEqual(np.array([2, 3, 1]), r.numpy())
|
self.assertAllEqual(np.array([2, 3, 1]), r.numpy())
|
||||||
|
|
||||||
def testEmptyTensorList(self):
|
def testEmptyTensorList(self):
|
||||||
a = pywrap_tensorflow.TFE_Py_TensorShapeSlice([], 0)
|
a = pywrap_tfe.TFE_Py_TensorShapeSlice([], 0)
|
||||||
self.assertTrue(isinstance(a, ops.EagerTensor))
|
self.assertTrue(isinstance(a, ops.EagerTensor))
|
||||||
self.assertEqual(0, a.numpy().size)
|
self.assertEqual(0, a.numpy().size)
|
||||||
|
|
||||||
@ -452,12 +452,12 @@ class TFETensorUtilTest(test_util.TensorFlowTestCase):
|
|||||||
with self.assertRaisesRegexp(
|
with self.assertRaisesRegexp(
|
||||||
TypeError,
|
TypeError,
|
||||||
r"Expected a list of EagerTensors but element 1 has type \"str\""):
|
r"Expected a list of EagerTensors but element 1 has type \"str\""):
|
||||||
pywrap_tensorflow.TFE_Py_TensorShapeSlice([t1, "abc"], 0)
|
pywrap_tfe.TFE_Py_TensorShapeSlice([t1, "abc"], 0)
|
||||||
|
|
||||||
with self.assertRaisesRegexp(
|
with self.assertRaisesRegexp(
|
||||||
TypeError,
|
TypeError,
|
||||||
r"Expected a list of EagerTensors but element 0 has type \"int\""):
|
r"Expected a list of EagerTensors but element 0 has type \"int\""):
|
||||||
pywrap_tensorflow.TFE_Py_TensorShapeSlice([2, t1], 0)
|
pywrap_tfe.TFE_Py_TensorShapeSlice([2, t1], 0)
|
||||||
|
|
||||||
def testTensorListNotList(self):
|
def testTensorListNotList(self):
|
||||||
t1 = _create_tensor([1, 2], dtype=dtypes.int32)
|
t1 = _create_tensor([1, 2], dtype=dtypes.int32)
|
||||||
@ -465,7 +465,7 @@ class TFETensorUtilTest(test_util.TensorFlowTestCase):
|
|||||||
with self.assertRaisesRegexp(
|
with self.assertRaisesRegexp(
|
||||||
TypeError,
|
TypeError,
|
||||||
r"tensors argument must be a list or a tuple. Got.*EagerTensor"):
|
r"tensors argument must be a list or a tuple. Got.*EagerTensor"):
|
||||||
pywrap_tensorflow.TFE_Py_TensorShapeSlice(t1, -2)
|
pywrap_tfe.TFE_Py_TensorShapeSlice(t1, -2)
|
||||||
|
|
||||||
def testNegativeSliceDim(self):
|
def testNegativeSliceDim(self):
|
||||||
t1 = _create_tensor([1, 2], dtype=dtypes.int32)
|
t1 = _create_tensor([1, 2], dtype=dtypes.int32)
|
||||||
@ -473,7 +473,7 @@ class TFETensorUtilTest(test_util.TensorFlowTestCase):
|
|||||||
with self.assertRaisesRegexp(
|
with self.assertRaisesRegexp(
|
||||||
ValueError,
|
ValueError,
|
||||||
r"Slice dimension must be non-negative. Got -2"):
|
r"Slice dimension must be non-negative. Got -2"):
|
||||||
pywrap_tensorflow.TFE_Py_TensorShapeSlice([t1], -2)
|
pywrap_tfe.TFE_Py_TensorShapeSlice([t1], -2)
|
||||||
|
|
||||||
def testUnicode(self):
|
def testUnicode(self):
|
||||||
self.assertEqual(constant_op.constant(u"asdf").numpy(), b"asdf")
|
self.assertEqual(constant_op.constant(u"asdf").numpy(), b"asdf")
|
||||||
@ -493,31 +493,31 @@ class TFETensorUtilTest(test_util.TensorFlowTestCase):
|
|||||||
IndexError,
|
IndexError,
|
||||||
r"Slice dimension \(2\) must be smaller than rank of all tensors, "
|
r"Slice dimension \(2\) must be smaller than rank of all tensors, "
|
||||||
"but tensor at index 0 has rank 2"):
|
"but tensor at index 0 has rank 2"):
|
||||||
pywrap_tensorflow.TFE_Py_TensorShapeSlice([t1], 2)
|
pywrap_tfe.TFE_Py_TensorShapeSlice([t1], 2)
|
||||||
|
|
||||||
with self.assertRaisesRegexp(
|
with self.assertRaisesRegexp(
|
||||||
IndexError,
|
IndexError,
|
||||||
r"Slice dimension \(1\) must be smaller than rank of all tensors, "
|
r"Slice dimension \(1\) must be smaller than rank of all tensors, "
|
||||||
"but tensor at index 0 has rank 1"):
|
"but tensor at index 0 has rank 1"):
|
||||||
pywrap_tensorflow.TFE_Py_TensorShapeSlice([t2], 1)
|
pywrap_tfe.TFE_Py_TensorShapeSlice([t2], 1)
|
||||||
|
|
||||||
with self.assertRaisesRegexp(
|
with self.assertRaisesRegexp(
|
||||||
IndexError,
|
IndexError,
|
||||||
r"Slice dimension \(1\) must be smaller than rank of all tensors, "
|
r"Slice dimension \(1\) must be smaller than rank of all tensors, "
|
||||||
"but tensor at index 1 has rank 1"):
|
"but tensor at index 1 has rank 1"):
|
||||||
pywrap_tensorflow.TFE_Py_TensorShapeSlice([t1, t2], 1)
|
pywrap_tfe.TFE_Py_TensorShapeSlice([t1, t2], 1)
|
||||||
|
|
||||||
with self.assertRaisesRegexp(
|
with self.assertRaisesRegexp(
|
||||||
IndexError,
|
IndexError,
|
||||||
r"Slice dimension \(0\) must be smaller than rank of all tensors, "
|
r"Slice dimension \(0\) must be smaller than rank of all tensors, "
|
||||||
"but tensor at index 0 has rank 0"):
|
"but tensor at index 0 has rank 0"):
|
||||||
pywrap_tensorflow.TFE_Py_TensorShapeSlice([t3], 0)
|
pywrap_tfe.TFE_Py_TensorShapeSlice([t3], 0)
|
||||||
|
|
||||||
with self.assertRaisesRegexp(
|
with self.assertRaisesRegexp(
|
||||||
IndexError,
|
IndexError,
|
||||||
r"Slice dimension \(0\) must be smaller than rank of all tensors, "
|
r"Slice dimension \(0\) must be smaller than rank of all tensors, "
|
||||||
"but tensor at index 2 has rank 0"):
|
"but tensor at index 2 has rank 0"):
|
||||||
pywrap_tensorflow.TFE_Py_TensorShapeSlice([t2, t1, t3], 0)
|
pywrap_tfe.TFE_Py_TensorShapeSlice([t2, t1, t3], 0)
|
||||||
|
|
||||||
@test_util.assert_no_new_pyobjects_executing_eagerly
|
@test_util.assert_no_new_pyobjects_executing_eagerly
|
||||||
def testTensorDir(self):
|
def testTensorDir(self):
|
||||||
|
@ -36,7 +36,12 @@ from tensorflow.core.framework import node_def_pb2
|
|||||||
from tensorflow.core.framework import op_def_pb2
|
from tensorflow.core.framework import op_def_pb2
|
||||||
from tensorflow.core.framework import versions_pb2
|
from tensorflow.core.framework import versions_pb2
|
||||||
from tensorflow.core.protobuf import config_pb2
|
from tensorflow.core.protobuf import config_pb2
|
||||||
|
# pywrap_tensorflow must be imported first to avoid profobuf issues.
|
||||||
|
# (b/143110113)
|
||||||
|
# pylint: disable=invalid-import-order,g-bad-import-order
|
||||||
from tensorflow.python import pywrap_tensorflow as c_api
|
from tensorflow.python import pywrap_tensorflow as c_api
|
||||||
|
from tensorflow.python import pywrap_tfe as c_api_new
|
||||||
|
# pylint: enable=invalid-import-order,g-bad-import-order
|
||||||
from tensorflow.python import tf2
|
from tensorflow.python import tf2
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.eager import core
|
from tensorflow.python.eager import core
|
||||||
@ -249,7 +254,7 @@ def register_dense_tensor_like_type(tensor_type):
|
|||||||
|
|
||||||
def uid():
|
def uid():
|
||||||
"""A unique (within this program execution) integer."""
|
"""A unique (within this program execution) integer."""
|
||||||
return c_api.TFE_Py_UID()
|
return c_api_new.TFE_Py_UID()
|
||||||
|
|
||||||
|
|
||||||
def numpy_text(tensor, is_repr=False):
|
def numpy_text(tensor, is_repr=False):
|
||||||
@ -1135,7 +1140,7 @@ class _EagerTensorBase(Tensor):
|
|||||||
|
|
||||||
# This call creates an EagerTensor class, as a subclass of _EagerTensorBase, and
|
# This call creates an EagerTensor class, as a subclass of _EagerTensorBase, and
|
||||||
# registers it with the current module.
|
# registers it with the current module.
|
||||||
EagerTensor = c_api.TFE_Py_InitEagerTensor(_EagerTensorBase)
|
EagerTensor = c_api_new.TFE_Py_InitEagerTensor(_EagerTensorBase)
|
||||||
|
|
||||||
|
|
||||||
register_dense_tensor_like_type(Tensor)
|
register_dense_tensor_like_type(Tensor)
|
||||||
|
@ -789,8 +789,7 @@ void GenEagerPythonOp::AddEagerFastPathExecute() {
|
|||||||
|
|
||||||
strings::StrAppend(&result_, " try:\n");
|
strings::StrAppend(&result_, " try:\n");
|
||||||
strings::StrAppend(
|
strings::StrAppend(
|
||||||
&result_, " ",
|
&result_, " ", "_result = pywrap_tfe.TFE_Py_FastPathExecute(\n",
|
||||||
"_result = _pywrap_tensorflow.TFE_Py_FastPathExecute(\n",
|
|
||||||
WordWrap(strings::StrCat(" "),
|
WordWrap(strings::StrCat(" "),
|
||||||
strings::StrCat(fastpath_execute_params, ")"), kRightMargin),
|
strings::StrCat(fastpath_execute_params, ")"), kRightMargin),
|
||||||
"\n");
|
"\n");
|
||||||
@ -1000,7 +999,7 @@ This file is MACHINE GENERATED! Do not edit.
|
|||||||
|
|
||||||
import collections
|
import collections
|
||||||
|
|
||||||
from tensorflow.python import pywrap_tensorflow as _pywrap_tensorflow
|
from tensorflow.python import pywrap_tfe as pywrap_tfe
|
||||||
from tensorflow.python.eager import context as _context
|
from tensorflow.python.eager import context as _context
|
||||||
from tensorflow.python.eager import core as _core
|
from tensorflow.python.eager import core as _core
|
||||||
from tensorflow.python.eager import execute as _execute
|
from tensorflow.python.eager import execute as _execute
|
||||||
|
@ -19,6 +19,7 @@ from __future__ import division
|
|||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
from tensorflow.python import pywrap_tensorflow
|
from tensorflow.python import pywrap_tensorflow
|
||||||
|
from tensorflow.python import pywrap_tfe
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
@ -115,8 +116,8 @@ def _ConcatGradHelper(op, grad, start_value_index, end_value_index, dim_index):
|
|||||||
non_neg_concat_dim = (
|
non_neg_concat_dim = (
|
||||||
concat_dim._numpy().item(0) % input_values[0]._rank()) # pylint: disable=protected-access
|
concat_dim._numpy().item(0) % input_values[0]._rank()) # pylint: disable=protected-access
|
||||||
# All inputs are guaranteed to be EagerTensors in eager mode
|
# All inputs are guaranteed to be EagerTensors in eager mode
|
||||||
sizes = pywrap_tensorflow.TFE_Py_TensorShapeSlice(input_values,
|
sizes = pywrap_tfe.TFE_Py_TensorShapeSlice(input_values,
|
||||||
non_neg_concat_dim)
|
non_neg_concat_dim)
|
||||||
out_grads = array_ops.split(grad, sizes, non_neg_concat_dim)
|
out_grads = array_ops.split(grad, sizes, non_neg_concat_dim)
|
||||||
else:
|
else:
|
||||||
if constant_op.is_constant(concat_dim):
|
if constant_op.is_constant(concat_dim):
|
||||||
|
@ -26,7 +26,7 @@ import sys
|
|||||||
from absl import logging
|
from absl import logging
|
||||||
import six
|
import six
|
||||||
|
|
||||||
from tensorflow.python import pywrap_tensorflow
|
from tensorflow.python import pywrap_tfe
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import sparse_tensor
|
from tensorflow.python.framework import sparse_tensor
|
||||||
@ -45,7 +45,7 @@ from tensorflow.python.util.tf_export import tf_export
|
|||||||
# Register printing to the cell output if we are in a Colab or Jupyter Notebook.
|
# Register printing to the cell output if we are in a Colab or Jupyter Notebook.
|
||||||
try:
|
try:
|
||||||
get_ipython() # Exists in an ipython env like Jupyter or Colab
|
get_ipython() # Exists in an ipython env like Jupyter or Colab
|
||||||
pywrap_tensorflow.TFE_Py_EnableInteractivePythonLogging()
|
pywrap_tfe.TFE_Py_EnableInteractivePythonLogging()
|
||||||
except NameError:
|
except NameError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
|||||||
#include <vector>
|
#include <vector>
|
||||||
#include "tensorflow/c/tf_status.h"
|
#include "tensorflow/c/tf_status.h"
|
||||||
#include "tensorflow/core/platform/types.h"
|
#include "tensorflow/core/platform/types.h"
|
||||||
|
#include "tensorflow/c/tf_datatype.h"
|
||||||
#include "tensorflow/python/lib/core/py_exception_registry.h"
|
#include "tensorflow/python/lib/core/py_exception_registry.h"
|
||||||
|
|
||||||
using tensorflow::uint64;
|
using tensorflow::uint64;
|
||||||
@ -233,7 +234,50 @@ _COPY_TYPEMAPS(unsigned int, mode_t);
|
|||||||
%define override %enddef
|
%define override %enddef
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
|
||||||
|
// This was originally included in pywrap_tfe.i, but is used by tf_session.i
|
||||||
%include "tensorflow/c/tf_status.h"
|
%include "tensorflow/c/tf_status.h"
|
||||||
|
%include "tensorflow/c/tf_datatype.h"
|
||||||
|
|
||||||
|
%typemap(in) (const void* proto) {
|
||||||
|
char* c_string;
|
||||||
|
Py_ssize_t py_size;
|
||||||
|
// PyBytes_AsStringAndSize() does not copy but simply interprets the input
|
||||||
|
if (PyBytes_AsStringAndSize($input, &c_string, &py_size) == -1) {
|
||||||
|
// Python has raised an error (likely TypeError or UnicodeEncodeError).
|
||||||
|
SWIG_fail;
|
||||||
|
}
|
||||||
|
$1 = static_cast<void*>(c_string);
|
||||||
|
}
|
||||||
|
|
||||||
|
%typemap(in) int64_t {
|
||||||
|
$1 = PyLong_AsLongLong($input);
|
||||||
|
}
|
||||||
|
|
||||||
|
%typemap(out) TF_DataType {
|
||||||
|
$result = PyInt_FromLong($1);
|
||||||
|
}
|
||||||
|
|
||||||
|
%typemap(out) int64_t {
|
||||||
|
$result = PyInt_FromLong($1);
|
||||||
|
}
|
||||||
|
|
||||||
|
%typemap(out) TF_AttrType {
|
||||||
|
$result = PyInt_FromLong($1);
|
||||||
|
}
|
||||||
|
|
||||||
|
%typemap(in, numinputs=0) unsigned char* is_list (unsigned char tmp) {
|
||||||
|
tmp = 0;
|
||||||
|
$1 = &tmp;
|
||||||
|
}
|
||||||
|
|
||||||
|
%typemap(argout) unsigned char* is_list {
|
||||||
|
if (*$1 == 1) {
|
||||||
|
PyObject* list = PyList_New(1);
|
||||||
|
PyList_SetItem(list, 0, $result);
|
||||||
|
$result = list;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Typemaps to automatically raise a Python exception from bad output TF_Status.
|
// Typemaps to automatically raise a Python exception from bad output TF_Status.
|
||||||
// TODO(b/77295559): expand this to all TF_Status* output params and deprecate
|
// TODO(b/77295559): expand this to all TF_Status* output params and deprecate
|
||||||
|
@ -1,515 +0,0 @@
|
|||||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
==============================================================================*/
|
|
||||||
|
|
||||||
%include "tensorflow/python/lib/core/strings.i"
|
|
||||||
%include "tensorflow/python/platform/base.i"
|
|
||||||
|
|
||||||
%{
|
|
||||||
#include "tensorflow/core/lib/core/status.h"
|
|
||||||
#include "tensorflow/python/lib/core/ndarray_tensor.h"
|
|
||||||
#include "tensorflow/python/lib/core/safe_ptr.h"
|
|
||||||
%}
|
|
||||||
|
|
||||||
|
|
||||||
%include "tensorflow/c/tf_datatype.h"
|
|
||||||
%include "tensorflow/c/tf_status.h"
|
|
||||||
|
|
||||||
%ignoreall;
|
|
||||||
|
|
||||||
%rename("%s") TF_SetXlaEnableLazyCompilation;
|
|
||||||
%rename("%s") TF_SetTfXlaCpuGlobalJit;
|
|
||||||
%rename("%s") TF_SetXlaAutoJitMode;
|
|
||||||
%rename("%s") TF_SetXlaConstantFoldingDisabled;
|
|
||||||
%rename("%s") TF_GetXlaConstantFoldingDisabled;
|
|
||||||
%rename("%s") TF_SetXlaMinClusterSize;
|
|
||||||
%rename("%s") TFE_NewContext;
|
|
||||||
%rename("%s") TFE_DeleteContext;
|
|
||||||
%rename("%s") TFE_ContextListDevices;
|
|
||||||
%rename("%s") TFE_ContextAddFunction;
|
|
||||||
%rename("%s") TFE_ContextAddFunctionDef;
|
|
||||||
%rename("%s") TFE_ContextRemoveFunction;
|
|
||||||
%rename("%s") TFE_ContextHasFunction;
|
|
||||||
%rename("%s") TFE_ContextEnableRunMetadata;
|
|
||||||
%rename("%s") TFE_ContextDisableRunMetadata;
|
|
||||||
%rename("%s") TFE_ContextEnableGraphCollection;
|
|
||||||
%rename("%s") TFE_ContextDisableGraphCollection;
|
|
||||||
%rename("%s") TFE_ContextExportRunMetadata;
|
|
||||||
%rename("%s") TFE_ContextClearCaches;
|
|
||||||
%rename("%s") TFE_ContextGetDevicePlacementPolicy;
|
|
||||||
%rename("%s") TFE_ContextGetMirroringPolicy;
|
|
||||||
%rename("%s") TFE_ContextSetThreadLocalDevicePlacementPolicy;
|
|
||||||
%rename("%s") TFE_ContextSetThreadLocalMirroringPolicy;
|
|
||||||
%rename("%s") TFE_ContextSetServerDef;
|
|
||||||
%rename("%s") TFE_ContextUpdateServerDef;
|
|
||||||
%rename("%s") TFE_ContextCheckAlive;
|
|
||||||
%rename("%s") TFE_NewExecutor;
|
|
||||||
%rename("%s") TFE_DeleteExecutor;
|
|
||||||
%rename("%s") TFE_ExecutorIsAsync;
|
|
||||||
%rename("%s") TFE_ExecutorWaitForAllPendingNodes;
|
|
||||||
%rename("%s") TFE_ExecutorClearError;
|
|
||||||
%rename("%s") TFE_ContextSetExecutorForThread;
|
|
||||||
%rename("%s") TFE_ContextGetExecutorForThread;
|
|
||||||
%rename("%s") TFE_NewProfiler;
|
|
||||||
%rename("%s") TFE_ProfilerIsOk;
|
|
||||||
%rename("%s") TFE_DeleteProfiler;
|
|
||||||
%rename("%s") TFE_ProfilerSerializeToString;
|
|
||||||
%rename("%s") TFE_StartProfilerServer;
|
|
||||||
%rename("%s") TFE_ProfilerClientStartTracing;
|
|
||||||
%rename("%s") TFE_ProfilerClientMonitor;
|
|
||||||
%rename("%s") TFE_OpNameGetAttrType;
|
|
||||||
%rename("%s") TFE_Py_InitEagerTensor;
|
|
||||||
%rename("%s") TFE_Py_SetEagerTensorProfiler;
|
|
||||||
%rename("%s") TFE_Py_RegisterExceptionClass;
|
|
||||||
%rename("%s") TFE_Py_RegisterJVPFunction;
|
|
||||||
%rename("%s") TFE_Py_RegisterGradientFunction;
|
|
||||||
%rename("%s") TFE_Py_RegisterFallbackExceptionClass;
|
|
||||||
%rename("%s") TFE_Py_Execute;
|
|
||||||
%rename("%s") TFE_Py_ExecuteCancelable;
|
|
||||||
%rename("%s") TFE_Py_FastPathExecute;
|
|
||||||
%rename("%s") TFE_Py_RecordGradient;
|
|
||||||
%rename("%s") TFE_Py_UID;
|
|
||||||
%rename("%s") TFE_Py_TapeSetNew;
|
|
||||||
%rename("%s") TFE_Py_TapeSetAdd;
|
|
||||||
%rename("%s") TFE_Py_TapeSetRemove;
|
|
||||||
%rename("%s") TFE_Py_TapeSetStopOnThread;
|
|
||||||
%rename("%s") TFE_Py_TapeSetRestartOnThread;
|
|
||||||
%rename("%s") TFE_Py_TapeSetIsStopped;
|
|
||||||
%rename("%s") TFE_Py_TapeSetIsEmpty;
|
|
||||||
%rename("%s") TFE_Py_TapeSetShouldRecordBackprop;
|
|
||||||
%rename("%s") TFE_Py_TapeSetPossibleGradientTypes;
|
|
||||||
%rename("%s") TFE_Py_TapeSetDeleteTrace;
|
|
||||||
%rename("%s") TFE_Py_TapeSetRecordOperation;
|
|
||||||
%rename("%s") TFE_Py_TapeSetRecordOperationBackprop;
|
|
||||||
%rename("%s") TFE_Py_TapeSetRecordOperationForwardprop;
|
|
||||||
%rename("%s") TFE_Py_TapeGradient;
|
|
||||||
%rename("%s") TFE_Py_TapeVariableAccessed;
|
|
||||||
%rename("%s") TFE_Py_TapeWatch;
|
|
||||||
%rename("%s") TFE_Py_TapeWatchVariable;
|
|
||||||
%rename("%s") TFE_Py_TapeWatchedVariables;
|
|
||||||
%rename("%s") TFE_Py_ForwardAccumulatorNew;
|
|
||||||
%rename("%s") TFE_Py_ForwardAccumulatorSetAdd;
|
|
||||||
%rename("%s") TFE_Py_ForwardAccumulatorSetRemove;
|
|
||||||
%rename("%s") TFE_Py_ForwardAccumulatorWatch;
|
|
||||||
%rename("%s") TFE_Py_ForwardAccumulatorJVP;
|
|
||||||
%rename("%s") TFE_Py_ForwardAccumulatorPushState;
|
|
||||||
%rename("%s") TFE_Py_ForwardAccumulatorPopState;
|
|
||||||
%rename("%s") TFE_Py_PackJVPs;
|
|
||||||
%rename("%s") TFE_NewContextOptions;
|
|
||||||
%rename("%s") TFE_ContextOptionsSetConfig;
|
|
||||||
%rename("%s") TFE_ContextOptionsSetDevicePlacementPolicy;
|
|
||||||
%rename("%s") TFE_ContextOptionsSetMirroringPolicy;
|
|
||||||
%rename("%s") TFE_ContextOptionsSetAsync;
|
|
||||||
%rename("%s") TFE_ContextOptionsSetLazyRemoteInputsCopy;
|
|
||||||
%rename("%s") TFE_DeleteContextOptions;
|
|
||||||
%rename("%s") TFE_Py_TensorShapeSlice;
|
|
||||||
%rename("%s") TFE_Py_TensorShapeOnDevice;
|
|
||||||
%rename("%s") TFE_Py_EnableInteractivePythonLogging;
|
|
||||||
%rename("%s") TFE_Py_SetEagerContext;
|
|
||||||
%rename("%s") TFE_ContextStartStep;
|
|
||||||
%rename("%s") TFE_ContextEndStep;
|
|
||||||
%rename("%s") TFE_Py_RegisterVSpace;
|
|
||||||
%rename("%s") TFE_Py_EncodeArg;
|
|
||||||
%rename("%s") TFE_EnableCollectiveOps;
|
|
||||||
%rename("%s") TF_ListPhysicalDevices;
|
|
||||||
%rename("%s") TF_PickUnusedPortOrDie;
|
|
||||||
%rename("%s") TFE_MonitoringCounterCellIncrementBy;
|
|
||||||
%rename("%s") TFE_MonitoringCounterCellValue;
|
|
||||||
%rename("%s") TFE_MonitoringNewCounter0;
|
|
||||||
%rename("%s") TFE_MonitoringDeleteCounter0;
|
|
||||||
%rename("%s") TFE_MonitoringGetCellCounter0;
|
|
||||||
%rename("%s") TFE_MonitoringNewCounter1;
|
|
||||||
%rename("%s") TFE_MonitoringDeleteCounter1;
|
|
||||||
%rename("%s") TFE_MonitoringGetCellCounter1;
|
|
||||||
%rename("%s") TFE_MonitoringNewCounter2;
|
|
||||||
%rename("%s") TFE_MonitoringDeleteCounter2;
|
|
||||||
%rename("%s") TFE_MonitoringGetCellCounter2;
|
|
||||||
%rename("%s") TFE_MonitoringIntGaugeCellSet;
|
|
||||||
%rename("%s") TFE_MonitoringIntGaugeCellValue;
|
|
||||||
%rename("%s") TFE_MonitoringNewIntGauge0;
|
|
||||||
%rename("%s") TFE_MonitoringDeleteIntGauge0;
|
|
||||||
%rename("%s") TFE_MonitoringGetCellIntGauge0;
|
|
||||||
%rename("%s") TFE_MonitoringNewIntGauge1;
|
|
||||||
%rename("%s") TFE_MonitoringDeleteIntGauge1;
|
|
||||||
%rename("%s") TFE_MonitoringGetCellIntGauge1;
|
|
||||||
%rename("%s") TFE_MonitoringNewIntGauge2;
|
|
||||||
%rename("%s") TFE_MonitoringDeleteIntGauge2;
|
|
||||||
%rename("%s") TFE_MonitoringGetCellIntGauge2;
|
|
||||||
%rename("%s") TFE_MonitoringStringGaugeCellSet;
|
|
||||||
%rename("%s") TFE_MonitoringStringGaugeCellValue;
|
|
||||||
%rename("%s") TFE_MonitoringNewStringGauge0;
|
|
||||||
%rename("%s") TFE_MonitoringDeleteStringGauge0;
|
|
||||||
%rename("%s") TFE_MonitoringGetCellStringGauge0;
|
|
||||||
%rename("%s") TFE_MonitoringNewStringGauge1;
|
|
||||||
%rename("%s") TFE_MonitoringDeleteStringGauge1;
|
|
||||||
%rename("%s") TFE_MonitoringGetCellStringGauge1;
|
|
||||||
%rename("%s") TFE_MonitoringNewStringGauge2;
|
|
||||||
%rename("%s") TFE_MonitoringDeleteStringGauge2;
|
|
||||||
%rename("%s") TFE_MonitoringGetCellStringGauge2;
|
|
||||||
%rename("%s") TFE_MonitoringBoolGaugeCellSet;
|
|
||||||
%rename("%s") TFE_MonitoringBoolGaugeCellValue;
|
|
||||||
%rename("%s") TFE_MonitoringNewBoolGauge0;
|
|
||||||
%rename("%s") TFE_MonitoringDeleteBoolGauge0;
|
|
||||||
%rename("%s") TFE_MonitoringGetCellBoolGauge0;
|
|
||||||
%rename("%s") TFE_MonitoringNewBoolGauge1;
|
|
||||||
%rename("%s") TFE_MonitoringDeleteBoolGauge1;
|
|
||||||
%rename("%s") TFE_MonitoringGetCellBoolGauge1;
|
|
||||||
%rename("%s") TFE_MonitoringNewBoolGauge2;
|
|
||||||
%rename("%s") TFE_MonitoringDeleteBoolGauge2;
|
|
||||||
%rename("%s") TFE_MonitoringGetCellBoolGauge2;
|
|
||||||
%rename("%s") TFE_MonitoringSamplerCellAdd;
|
|
||||||
%rename("%s") TFE_MonitoringSamplerCellValue;
|
|
||||||
%rename("%s") TFE_MonitoringNewExponentialBuckets;
|
|
||||||
%rename("%s") TFE_MonitoringDeleteBuckets;
|
|
||||||
%rename("%s") TFE_MonitoringNewSampler0;
|
|
||||||
%rename("%s") TFE_MonitoringDeleteSampler0;
|
|
||||||
%rename("%s") TFE_MonitoringGetCellSampler0;
|
|
||||||
%rename("%s") TFE_MonitoringNewSampler1;
|
|
||||||
%rename("%s") TFE_MonitoringDeleteSampler1;
|
|
||||||
%rename("%s") TFE_MonitoringGetCellSampler1;
|
|
||||||
%rename("%s") TFE_MonitoringNewSampler2;
|
|
||||||
%rename("%s") TFE_MonitoringDeleteSampler2;
|
|
||||||
%rename("%s") TFE_MonitoringGetCellSampler2;
|
|
||||||
%rename("%s") TFE_NewCancellationManager;
|
|
||||||
%rename("%s") TFE_CancellationManagerIsCancelled;
|
|
||||||
%rename("%s") TFE_CancellationManagerStartCancel;
|
|
||||||
%rename("%s") TFE_DeleteCancellationManager;
|
|
||||||
%rename("%s") TF_ImportGraphDefOptionsSetValidateColocationConstraints;
|
|
||||||
%rename("%s") TFE_ClearScalarCache;
|
|
||||||
|
|
||||||
%{
|
|
||||||
#include "tensorflow/python/eager/pywrap_tfe.h"
|
|
||||||
#include "tensorflow/python/util/util.h"
|
|
||||||
#include "tensorflow/c/c_api_experimental.h"
|
|
||||||
#include "tensorflow/c/tf_status_helper.h"
|
|
||||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
|
||||||
#include "tensorflow/core/common_runtime/device_factory.h"
|
|
||||||
|
|
||||||
static PyObject* TF_ListPhysicalDevices(TF_Status* status) {
|
|
||||||
std::vector<string> devices;
|
|
||||||
tensorflow::Status s = tensorflow::DeviceFactory::ListAllPhysicalDevices(&devices);
|
|
||||||
tensorflow::Set_TF_Status_from_Status(status, s);
|
|
||||||
if (!s.ok()) {
|
|
||||||
Py_RETURN_NONE;
|
|
||||||
};
|
|
||||||
PyObject* result = PyList_New(devices.size());
|
|
||||||
int i = 0;
|
|
||||||
for (auto& dev : devices) {
|
|
||||||
PyObject* dev_obj = PyBytes_FromStringAndSize(dev.data(), dev.size());
|
|
||||||
PyList_SetItem(result, i, dev_obj);
|
|
||||||
++i;
|
|
||||||
}
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
%}
|
|
||||||
static PyObject* TF_ListPhysicalDevices(TF_Status* status);
|
|
||||||
|
|
||||||
%{
|
|
||||||
#include "tensorflow/python/eager/pywrap_tensor_conversion.h"
|
|
||||||
|
|
||||||
static PyObject* TFE_ClearScalarCache() {
|
|
||||||
tensorflow::TFE_TensorHandleCache::Get()->Clear();
|
|
||||||
Py_RETURN_NONE;
|
|
||||||
}
|
|
||||||
%}
|
|
||||||
static PyObject* TFE_ClearScalarCache();
|
|
||||||
|
|
||||||
%typemap(in) (const void* proto) {
|
|
||||||
char* c_string;
|
|
||||||
Py_ssize_t py_size;
|
|
||||||
// PyBytes_AsStringAndSize() does not copy but simply interprets the input
|
|
||||||
if (PyBytes_AsStringAndSize($input, &c_string, &py_size) == -1) {
|
|
||||||
// Python has raised an error (likely TypeError or UnicodeEncodeError).
|
|
||||||
SWIG_fail;
|
|
||||||
}
|
|
||||||
$1 = static_cast<void*>(c_string);
|
|
||||||
}
|
|
||||||
|
|
||||||
%typemap(in) int64_t {
|
|
||||||
$1 = PyLong_AsLongLong($input);
|
|
||||||
}
|
|
||||||
|
|
||||||
%typemap(out) TF_DataType {
|
|
||||||
$result = PyInt_FromLong($1);
|
|
||||||
}
|
|
||||||
|
|
||||||
%typemap(out) int64_t {
|
|
||||||
$result = PyInt_FromLong($1);
|
|
||||||
}
|
|
||||||
|
|
||||||
%typemap(out) TF_AttrType {
|
|
||||||
$result = PyInt_FromLong($1);
|
|
||||||
}
|
|
||||||
|
|
||||||
%typemap(in, numinputs=0) unsigned char* is_list (unsigned char tmp) {
|
|
||||||
tmp = 0;
|
|
||||||
$1 = &tmp;
|
|
||||||
}
|
|
||||||
|
|
||||||
%typemap(argout) unsigned char* is_list {
|
|
||||||
if (*$1 == 1) {
|
|
||||||
PyObject* list = PyList_New(1);
|
|
||||||
PyList_SetItem(list, 0, $result);
|
|
||||||
$result = list;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// For const parameters in a function, SWIG pretty much ignores the const.
|
|
||||||
// See: http://www.swig.org/Doc2.0/SWIG.html#SWIG_nn13
|
|
||||||
// Hence the 'const_cast'.
|
|
||||||
%typemap(in) const char* serialized_function_def {
|
|
||||||
$1 = const_cast<char*>(TFE_GetPythonString($input));
|
|
||||||
}
|
|
||||||
|
|
||||||
// For const parameters in a function, SWIG pretty much ignores the const.
|
|
||||||
// See: http://www.swig.org/Doc2.0/SWIG.html#SWIG_nn13
|
|
||||||
// Hence the 'const_cast'.
|
|
||||||
%typemap(in) const char* device_name {
|
|
||||||
if ($input == Py_None) {
|
|
||||||
$1 = nullptr;
|
|
||||||
} else {
|
|
||||||
$1 = const_cast<char*>(TFE_GetPythonString($input));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// For const parameters in a function, SWIG pretty much ignores the const.
|
|
||||||
// See: http://www.swig.org/Doc2.0/SWIG.html#SWIG_nn13
|
|
||||||
// Hence the 'const_cast'.
|
|
||||||
%typemap(in) const char* op_name {
|
|
||||||
$1 = const_cast<char*>(TFE_GetPythonString($input));
|
|
||||||
}
|
|
||||||
|
|
||||||
// For const parameters in a function, SWIG pretty much ignores the const.
|
|
||||||
// See: http://www.swig.org/Doc2.0/SWIG.html#SWIG_nn13
|
|
||||||
// Hence the 'const_cast'.
|
|
||||||
%typemap(in) const char* name {
|
|
||||||
$1 = const_cast<char*>(TFE_GetPythonString($input));
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
// For const parameters in a function, SWIG pretty much ignores the const.
|
|
||||||
// See: http://www.swig.org/Doc2.0/SWIG.html#SWIG_nn13
|
|
||||||
// Hence the 'const_cast'.
|
|
||||||
%typemap(in) const char* description {
|
|
||||||
$1 = const_cast<char*>(TFE_GetPythonString($input));
|
|
||||||
}
|
|
||||||
|
|
||||||
// For const parameters in a function, SWIG pretty much ignores the const.
|
|
||||||
// See: http://www.swig.org/Doc2.0/SWIG.html#SWIG_nn13
|
|
||||||
// Hence the 'const_cast'.
|
|
||||||
%typemap(in) const char* label {
|
|
||||||
$1 = const_cast<char*>(TFE_GetPythonString($input));
|
|
||||||
}
|
|
||||||
|
|
||||||
// For const parameters in a function, SWIG pretty much ignores the const.
|
|
||||||
// See: http://www.swig.org/Doc2.0/SWIG.html#SWIG_nn13
|
|
||||||
// Hence the 'const_cast'.
|
|
||||||
%typemap(in) const char* label1 {
|
|
||||||
$1 = const_cast<char*>(TFE_GetPythonString($input));
|
|
||||||
}
|
|
||||||
|
|
||||||
// For const parameters in a function, SWIG pretty much ignores the const.
|
|
||||||
// See: http://www.swig.org/Doc2.0/SWIG.html#SWIG_nn13
|
|
||||||
// Hence the 'const_cast'.
|
|
||||||
%typemap(in) const char* label2 {
|
|
||||||
$1 = const_cast<char*>(TFE_GetPythonString($input));
|
|
||||||
}
|
|
||||||
|
|
||||||
// For const parameters in a function, SWIG pretty much ignores the const.
|
|
||||||
// See: http://www.swig.org/Doc2.0/SWIG.html#SWIG_nn13
|
|
||||||
// Hence the 'const_cast'.
|
|
||||||
%typemap(in) const char* service_addr {
|
|
||||||
$1 = const_cast<char*>(TFE_GetPythonString($input));
|
|
||||||
}
|
|
||||||
|
|
||||||
// For const parameters in a function, SWIG pretty much ignores the const.
|
|
||||||
// See: http://www.swig.org/Doc2.0/SWIG.html#SWIG_nn13
|
|
||||||
// Hence the 'const_cast'.
|
|
||||||
%typemap(in) const char* logdir {
|
|
||||||
$1 = const_cast<char*>(TFE_GetPythonString($input));
|
|
||||||
}
|
|
||||||
|
|
||||||
// For const parameters in a function, SWIG pretty much ignores the const.
|
|
||||||
// See: http://www.swig.org/Doc2.0/SWIG.html#SWIG_nn13
|
|
||||||
// Hence the 'const_cast'.
|
|
||||||
%typemap(in) const char* worker_list {
|
|
||||||
$1 = const_cast<char*>(TFE_GetPythonString($input));
|
|
||||||
}
|
|
||||||
|
|
||||||
%typemap(in) (TFE_Context*) {
|
|
||||||
$1 = (TFE_Context*)PyCapsule_GetPointer($input, nullptr);
|
|
||||||
|
|
||||||
}
|
|
||||||
%typemap(out) (TFE_Context*) {
|
|
||||||
// When the TFE_Context* returned is a nullptr, we expect the status is not
|
|
||||||
// OK. This will raise an error (happens in another typemap).
|
|
||||||
if ($1 != nullptr) {
|
|
||||||
$result = PyCapsule_New($1, nullptr, TFE_DeleteContextCapsule);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
%rename("%s") TFE_ContextDevicePlacementPolicy;
|
|
||||||
%rename("%s") TFE_DEVICE_PLACEMENT_EXPLICIT;
|
|
||||||
%rename("%s") TFE_DEVICE_PLACEMENT_WARN;
|
|
||||||
%rename("%s") TFE_DEVICE_PLACEMENT_SILENT;
|
|
||||||
%rename("%s") TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32;
|
|
||||||
|
|
||||||
%rename("%s") TFE_ContextMirroringPolicy;
|
|
||||||
%rename("%s") TFE_MIRRORING_NONE;
|
|
||||||
%rename("%s") TFE_MIRRORING_ALL;
|
|
||||||
|
|
||||||
%include "tensorflow/c/eager/c_api.h"
|
|
||||||
|
|
||||||
%typemap(in) TFE_InputTensorHandles* inputs (TFE_InputTensorHandles temp) {
|
|
||||||
$1 = &temp;
|
|
||||||
if ($input != Py_None) {
|
|
||||||
if (!PyList_Check($input)) {
|
|
||||||
SWIG_exception_fail(SWIG_TypeError,
|
|
||||||
"must provide a list of Tensors as inputs");
|
|
||||||
}
|
|
||||||
Py_ssize_t len = PyList_Size($input);
|
|
||||||
$1->resize(len);
|
|
||||||
for (Py_ssize_t i = 0; i < len; ++i) {
|
|
||||||
PyObject* elem = PyList_GetItem($input, i);
|
|
||||||
if (!elem) {
|
|
||||||
SWIG_fail;
|
|
||||||
}
|
|
||||||
if (EagerTensor_CheckExact(elem)) {
|
|
||||||
(*$1)[i] = EagerTensor_Handle(elem);
|
|
||||||
} else if (tensorflow::swig::IsEagerTensorSlow(elem)) {
|
|
||||||
// Use equivalent of object.__getattribute__ to get the underlying
|
|
||||||
// tf wrapped EagerTensor (if there is one).
|
|
||||||
tensorflow::Safe_PyObjectPtr tf_should_use_attr(
|
|
||||||
#if PY_MAJOR_VERSION < 3
|
|
||||||
PyString_InternFromString("_tf_should_use_wrapped_value")
|
|
||||||
#else
|
|
||||||
PyUnicode_InternFromString("_tf_should_use_wrapped_value")
|
|
||||||
#endif
|
|
||||||
);
|
|
||||||
tensorflow::Safe_PyObjectPtr value_attr(
|
|
||||||
PyObject_GenericGetAttr(elem, tf_should_use_attr.get()));
|
|
||||||
if (value_attr) {
|
|
||||||
// This is an EagerTensor wrapped inside a TFShouldUse wrapped object.
|
|
||||||
(*$1)[i] = EagerTensor_Handle(value_attr.get());
|
|
||||||
} else {
|
|
||||||
// This is a subclass of EagerTensor that we don't support.
|
|
||||||
PyErr_Clear();
|
|
||||||
SWIG_exception_fail(
|
|
||||||
SWIG_TypeError,
|
|
||||||
tensorflow::strings::StrCat(
|
|
||||||
"Saw an object that is an instance of a strict subclass of "
|
|
||||||
"EagerTensor, which is not supported. Item ",
|
|
||||||
i, " is type: ", elem->ob_type->tp_name)
|
|
||||||
.c_str());
|
|
||||||
}
|
|
||||||
} else if (tensorflow::swig::IsTensor(elem)) {
|
|
||||||
// If it isnt an EagerTensor, but is still a Tensor, it must be a graph
|
|
||||||
// tensor.
|
|
||||||
tensorflow::Safe_PyObjectPtr name_attr(
|
|
||||||
PyObject_GetAttrString(elem, "name"));
|
|
||||||
SWIG_exception_fail(
|
|
||||||
SWIG_TypeError,
|
|
||||||
tensorflow::strings::StrCat(
|
|
||||||
"An op outside of the function building code is being passed\n"
|
|
||||||
"a \"Graph\" tensor. It is possible to have Graph tensors\n"
|
|
||||||
"leak out of the function building context by including a\n"
|
|
||||||
"tf.init_scope in your function building code.\n"
|
|
||||||
"For example, the following function will fail:\n",
|
|
||||||
" @tf.function\n",
|
|
||||||
" def has_init_scope():\n",
|
|
||||||
" my_constant = tf.constant(1.)\n",
|
|
||||||
" with tf.init_scope():\n",
|
|
||||||
" added = my_constant * 2\n",
|
|
||||||
"The graph tensor has name: ",
|
|
||||||
name_attr ? TFE_GetPythonString(name_attr.get()) : "<unknown>"
|
|
||||||
).c_str());
|
|
||||||
} else {
|
|
||||||
SWIG_exception_fail(
|
|
||||||
SWIG_TypeError,
|
|
||||||
tensorflow::strings::StrCat(
|
|
||||||
"provided list of inputs contains objects other "
|
|
||||||
"than 'EagerTensor'. Item ",
|
|
||||||
i, " is type: ", elem->ob_type->tp_name).c_str());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Temporary for the argout
|
|
||||||
%typemap(in) TFE_OutputTensorHandles* outputs (TFE_OutputTensorHandles temp) {
|
|
||||||
if (!PyInt_Check($input)) {
|
|
||||||
SWIG_exception_fail(SWIG_TypeError,
|
|
||||||
"expected an integer value (size of the number of "
|
|
||||||
"outputs of the operation)");
|
|
||||||
}
|
|
||||||
$1 = &temp;
|
|
||||||
long sz = PyInt_AsLong($input);
|
|
||||||
if (sz > 0) {
|
|
||||||
$1->resize(PyInt_AsLong($input), nullptr);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create new Status object.
|
|
||||||
%typemap(in, numinputs=0) TF_Status *out_status {
|
|
||||||
$1 = GetStatus();
|
|
||||||
}
|
|
||||||
|
|
||||||
%typemap(freearg) (TF_Status* out_status) {
|
|
||||||
ReturnStatus($1);
|
|
||||||
}
|
|
||||||
|
|
||||||
%typemap(argout) (TFE_OutputTensorHandles* outputs, TF_Status* out_status) {
|
|
||||||
if (MaybeRaiseExceptionFromTFStatus($2, nullptr)) {
|
|
||||||
SWIG_fail;
|
|
||||||
} else {
|
|
||||||
int num_outputs = $1->size();
|
|
||||||
Py_CLEAR($result);
|
|
||||||
$result = PyList_New(num_outputs);
|
|
||||||
for (int i = 0; i < num_outputs; ++i) {
|
|
||||||
PyObject *output;
|
|
||||||
output = EagerTensorFromHandle($1->at(i));
|
|
||||||
PyList_SetItem($result, i, output);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// SWIG usually unwraps the tuple that the native Python/C interface generates.
|
|
||||||
// Since we wanted to have a function with a variable length of arguments, we
|
|
||||||
// used the native Python/C interface directly (which by default supports
|
|
||||||
// passing all arguments as a tuple).
|
|
||||||
%native(TFE_Py_FastPathExecute) TFE_Py_FastPathExecute_C;
|
|
||||||
|
|
||||||
%include "tensorflow/python/eager/pywrap_tfe.h"
|
|
||||||
%include "tensorflow/c/c_api_experimental.h"
|
|
||||||
%include "tensorflow/c/eager/c_api_experimental.h"
|
|
||||||
|
|
||||||
// Clear all typemaps.
|
|
||||||
%typemap(out) TF_DataType;
|
|
||||||
%typemap(in) int64_t;
|
|
||||||
%typemap(out) int64_t;
|
|
||||||
%typemap(out) TF_AttrType;
|
|
||||||
%typemap(in, numinputs=0) TF_Status *out_status;
|
|
||||||
%typemap(argout) unsigned char* is_list;
|
|
||||||
%typemap(in) const char* description;
|
|
||||||
%typemap(in) const char* label1;
|
|
||||||
%typemap(in) const char* label2;
|
|
||||||
%typemap(in) (TFE_Context*);
|
|
||||||
%typemap(out) (TFE_Context*);
|
|
||||||
%typemap(in) TFE_OutputTensorHandles* outputs (TFE_OutputTensorHandles temp);
|
|
||||||
%typemap(in, numinputs=0) TF_Status *out_status;
|
|
||||||
%typemap(freearg) (TF_Status* out_status);
|
|
||||||
%typemap(argout) (TFE_OutputTensorHandles* outputs, TF_Status* out_status);
|
|
||||||
%typemap(in) (const void* proto);
|
|
||||||
|
|
||||||
%unignoreall
|
|
29
tensorflow/python/pywrap_tfe.py
Normal file
29
tensorflow/python/pywrap_tfe.py
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Python module for TFE ops and functions exported by pybind11.
|
||||||
|
|
||||||
|
This module is created because we are splitting out eager bindings from
|
||||||
|
pywrap_tensorflow. This is causing some issues where Graphs are not properly
|
||||||
|
initialized when running eager code. Once the graph architecture has been
|
||||||
|
removed from pywrap_tensorflow as well, we can remove this file.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
# pylint: disable=invalid-import-order,g-bad-import-order, wildcard-import, unused-import
|
||||||
|
from tensorflow.python import pywrap_tensorflow
|
||||||
|
from tensorflow.python._pywrap_tfe import *
|
@ -17,8 +17,6 @@ limitations under the License.
|
|||||||
* The includes are intentionally not alphabetically sorted, as the order of
|
* The includes are intentionally not alphabetically sorted, as the order of
|
||||||
* includes follows dependency order */
|
* includes follows dependency order */
|
||||||
|
|
||||||
%include "tensorflow/python/pywrap_tfe.i"
|
|
||||||
|
|
||||||
%include "tensorflow/python/client/tf_session.i"
|
%include "tensorflow/python/client/tf_session.i"
|
||||||
|
|
||||||
%include "tensorflow/python/lib/io/file_io.i"
|
%include "tensorflow/python/lib/io/file_io.i"
|
||||||
|
1099
tensorflow/python/tfe_wrapper.cc
Normal file
1099
tensorflow/python/tfe_wrapper.cc
Normal file
File diff suppressed because it is too large
Load Diff
@ -3,6 +3,7 @@
|
|||||||
*perftools*gputools*
|
*perftools*gputools*
|
||||||
*tf_*
|
*tf_*
|
||||||
*TF_*
|
*TF_*
|
||||||
|
*Eager*
|
||||||
*TFE_*
|
*TFE_*
|
||||||
*nsync_*
|
*nsync_*
|
||||||
*stream_executor*
|
*stream_executor*
|
||||||
|
@ -4,6 +4,7 @@ tensorflow {
|
|||||||
*toco*;
|
*toco*;
|
||||||
*perftools*gputools*;
|
*perftools*gputools*;
|
||||||
*TF_*;
|
*TF_*;
|
||||||
|
*Eager*;
|
||||||
*TFE_*;
|
*TFE_*;
|
||||||
*nsync_*;
|
*nsync_*;
|
||||||
*stream_executor*;
|
*stream_executor*;
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
[cpp_python_util] # util
|
[cpp_python_util] # util tfe
|
||||||
tensorflow::swig::IsSequence
|
tensorflow::swig::IsSequence
|
||||||
tensorflow::swig::IsSequenceOrComposite
|
tensorflow::swig::IsSequenceOrComposite
|
||||||
tensorflow::swig::IsCompositeTensor
|
tensorflow::swig::IsCompositeTensor
|
||||||
@ -17,6 +17,7 @@ tensorflow::swig::IsSequenceForData
|
|||||||
tensorflow::swig::FlattenForData
|
tensorflow::swig::FlattenForData
|
||||||
tensorflow::swig::AssertSameStructureForData
|
tensorflow::swig::AssertSameStructureForData
|
||||||
tensorflow::swig::RegisterType
|
tensorflow::swig::RegisterType
|
||||||
|
tensorflow::swig::IsEagerTensorSlow
|
||||||
|
|
||||||
[util_port] # util_port
|
[util_port] # util_port
|
||||||
tensorflow::IsGoogleCudaEnabled
|
tensorflow::IsGoogleCudaEnabled
|
||||||
@ -74,11 +75,12 @@ tensorflow::Status::code
|
|||||||
tensorflow::Status::error_message
|
tensorflow::Status::error_message
|
||||||
tensorflow::Status::ok()
|
tensorflow::Status::ok()
|
||||||
|
|
||||||
[core_cpu_impl] # device_lib
|
[core_cpu_impl] # device_lib tfe
|
||||||
tensorflow::Device::attributes
|
tensorflow::Device::attributes
|
||||||
tensorflow::DeviceFactory::AddDevices
|
tensorflow::DeviceFactory::AddDevices
|
||||||
tensorflow::SessionOptions::SessionOptions
|
tensorflow::SessionOptions::SessionOptions
|
||||||
tensorflow::DoQuantizeTrainingOnSerializedGraphDef
|
tensorflow::DoQuantizeTrainingOnSerializedGraphDef
|
||||||
|
tensorflow::DeviceFactory::ListAllPhysicalDevices
|
||||||
|
|
||||||
[protos_all] # device_lib, dtypes
|
[protos_all] # device_lib, dtypes
|
||||||
tensorflow::DataType_IsValid
|
tensorflow::DataType_IsValid
|
||||||
@ -123,3 +125,67 @@ tensorflow::make_safe
|
|||||||
|
|
||||||
[python_op_gen] # python_op_gen
|
[python_op_gen] # python_op_gen
|
||||||
tensorflow::GetPythonWrappers
|
tensorflow::GetPythonWrappers
|
||||||
|
|
||||||
|
[pywrap_tfe_lib] # tfe
|
||||||
|
tensorflow::TFE_TensorHandleCache
|
||||||
|
tensorflow::TFE_TensorHandleCache::Clear
|
||||||
|
EagerTensor_CheckExact
|
||||||
|
EagerTensorFromHandle
|
||||||
|
EagerTensor_Handle
|
||||||
|
TFE_Py_ExecuteCancelable
|
||||||
|
TFE_Py_RegisterExceptionClass
|
||||||
|
TFE_Py_RegisterVSpace
|
||||||
|
TFE_Py_RegisterFallbackExceptionClass
|
||||||
|
TFE_Py_RegisterGradientFunction
|
||||||
|
TFE_Py_RegisterJVPFunction
|
||||||
|
TFE_GetPythonString
|
||||||
|
TFE_Py_UID
|
||||||
|
TFE_DeleteContextCapsule
|
||||||
|
TFE_Py_InitEagerTensor
|
||||||
|
TFE_Py_SetEagerTensorProfiler
|
||||||
|
TFE_Py_TapeSetNew
|
||||||
|
TFE_Py_TapeSetRemove
|
||||||
|
TFE_Py_TapeSetAdd
|
||||||
|
TFE_Py_TapeSetIsEmpty
|
||||||
|
TFE_Py_TapeSetShouldRecordBackprop
|
||||||
|
TFE_Py_TapeSetPossibleGradientTypes
|
||||||
|
TFE_Py_TapeWatch
|
||||||
|
TFE_Py_TapeSetDeleteTrace
|
||||||
|
TFE_Py_TapeSetStopOnThread
|
||||||
|
TFE_Py_TapeSetRestartOnThread
|
||||||
|
TFE_Py_TapeSetIsStopped
|
||||||
|
TFE_Py_TapeSetRecordOperation
|
||||||
|
TFE_Py_TapeSetRecordOperationBackprop
|
||||||
|
TFE_Py_TapeSetRecordOperationForwardprop
|
||||||
|
TFE_Py_TapeVariableAccessed
|
||||||
|
TFE_Py_TapeWatchVariable
|
||||||
|
TFE_Py_TapeGradient
|
||||||
|
TFE_Py_FastPathExecute_C
|
||||||
|
TFE_Py_RecordGradient
|
||||||
|
TFE_Py_TapeWatchedVariables
|
||||||
|
TFE_Py_ForwardAccumulatorNew
|
||||||
|
TFE_Py_ForwardAccumulatorSetAdd
|
||||||
|
TFE_Py_ForwardAccumulatorSetRemove
|
||||||
|
TFE_Py_ForwardAccumulatorWatch
|
||||||
|
TFE_Py_ForwardAccumulatorJVP
|
||||||
|
TFE_Py_ForwardAccumulatorPushState
|
||||||
|
TFE_Py_ForwardAccumulatorPopState
|
||||||
|
TFE_Py_PackJVPs
|
||||||
|
TFE_Py_TensorShapeSlice
|
||||||
|
TFE_Py_TensorShapeOnDevice
|
||||||
|
TFE_Py_EncodeArg
|
||||||
|
TFE_Py_EnableInteractivePythonLogging
|
||||||
|
TFE_Py_SetEagerContext
|
||||||
|
|
||||||
|
[eager_executor] # tfe
|
||||||
|
tensorflow::EagerExecutor::~EagerExecutor
|
||||||
|
tensorflow::EagerContext::WaitForAndCloseRemoteContexts
|
||||||
|
|
||||||
|
[profiler_session] # tfe
|
||||||
|
tensorflow::ProfilerSession::~ProfilerSession
|
||||||
|
|
||||||
|
[tf_status_helper] # tfe
|
||||||
|
tensorflow::Set_TF_Status_from_Status
|
||||||
|
|
||||||
|
[context] # tfe
|
||||||
|
tensorflow::EagerContext::WaitForAndCloseRemoteContexts
|
||||||
|
Loading…
Reference in New Issue
Block a user