Export the Eager classes and functions from C++ to Python with pybind11 instead of swig. This is part of a larger effort to deprecate swig and eventually with modularization break pywrap_tensorflow into smaller components. It will also make exporting C++ ops to Python significantly easier. XLA is using the pybind11 macros already. Please refer to https://github.com/tensorflow/community/blob/master/rfcs/20190208-pybind11.md for more information.
PiperOrigin-RevId: 286110711 Change-Id: I7bf6f6f4ce1d6bf3e8a3e40ef4a83f82333800f6
This commit is contained in:
parent
03341c4342
commit
7bd345bcbb
@ -53,6 +53,20 @@ filegroup(
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "pywrap_eager_hdrs",
|
||||
srcs = [
|
||||
"c_api_internal.h",
|
||||
"tf_status_helper.h",
|
||||
"tf_status_internal.h",
|
||||
"tf_tensor_internal.h",
|
||||
],
|
||||
visibility = [
|
||||
"//tensorflow/core:__pkg__",
|
||||
"//tensorflow/python:__pkg__",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cuda_library(
|
||||
name = "c_api_internal",
|
||||
hdrs = [
|
||||
|
@ -88,6 +88,18 @@ tf_cuda_library(
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "pywrap_eager_hdrs",
|
||||
srcs = [
|
||||
"c_api_experimental.h",
|
||||
"c_api_internal.h",
|
||||
],
|
||||
visibility = [
|
||||
"//tensorflow/core:__pkg__",
|
||||
"//tensorflow/python:__pkg__",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cuda_library(
|
||||
name = "c_api_internal",
|
||||
srcs = ["c_api_experimental.h"],
|
||||
|
@ -439,6 +439,23 @@ tf_cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "pywrap_eager_hdrs",
|
||||
srcs = [
|
||||
"attr_builder.h",
|
||||
"context.h",
|
||||
"eager_executor.h",
|
||||
"eager_operation.h",
|
||||
"kernel_and_device.h",
|
||||
"tensor_handle.h",
|
||||
"tensor_handle_data.h",
|
||||
],
|
||||
visibility = [
|
||||
"//tensorflow/core:__pkg__",
|
||||
"//tensorflow/python:__pkg__",
|
||||
],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "srcs",
|
||||
srcs = glob(
|
||||
|
@ -783,3 +783,20 @@ tf_cc_test(
|
||||
"//tensorflow/core:worker_proto_cc",
|
||||
],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "pywrap_eager_hdrs",
|
||||
srcs = [
|
||||
"call_options.h",
|
||||
"message_wrappers.h",
|
||||
"rendezvous_mgr_interface.h",
|
||||
"server_lib.h",
|
||||
"worker_cache.h",
|
||||
"worker_env.h",
|
||||
"worker_interface.h",
|
||||
],
|
||||
visibility = [
|
||||
"//tensorflow/core:__pkg__",
|
||||
"//tensorflow/python:__pkg__",
|
||||
],
|
||||
)
|
||||
|
@ -216,3 +216,16 @@ cc_library(
|
||||
"@com_google_absl//absl/types:optional",
|
||||
],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "pywrap_eager_hdrs",
|
||||
srcs = [
|
||||
"eager_client.h",
|
||||
"remote_tensor_handle.h",
|
||||
"remote_tensor_handle_data.h",
|
||||
],
|
||||
visibility = [
|
||||
"//tensorflow/core:__pkg__",
|
||||
"//tensorflow/python:__pkg__",
|
||||
],
|
||||
)
|
||||
|
@ -853,6 +853,18 @@ tf_cc_tests(
|
||||
],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "pywrap_eager_hdrs",
|
||||
srcs = [
|
||||
"op_gen_lib.h",
|
||||
"rendezvous.h",
|
||||
],
|
||||
visibility = [
|
||||
"//tensorflow/core:__pkg__",
|
||||
"//tensorflow/python:__pkg__",
|
||||
],
|
||||
)
|
||||
|
||||
# All framewrok protos are self-contained, i.e. they only import other
|
||||
# protos from the same package, so we can build the protos here and then
|
||||
# link them from core:protos_all without circular dependencies.
|
||||
|
@ -523,3 +523,14 @@ tf_cc_test(
|
||||
"//tensorflow/core:testlib",
|
||||
],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "pywrap_eager_hdrs",
|
||||
srcs = [
|
||||
"profiler_interface.h",
|
||||
],
|
||||
visibility = [
|
||||
"//tensorflow/core:__pkg__",
|
||||
"//tensorflow/python:__pkg__",
|
||||
],
|
||||
)
|
||||
|
@ -43,6 +43,17 @@ tf_cuda_library(
|
||||
alwayslink = True,
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "pywrap_eager_hdrs",
|
||||
srcs = [
|
||||
"profiler_session.h",
|
||||
],
|
||||
visibility = [
|
||||
"//tensorflow/core:__pkg__",
|
||||
"//tensorflow/python:__pkg__",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "traceme",
|
||||
hdrs = ["traceme.h"],
|
||||
|
@ -171,6 +171,7 @@ py_library(
|
||||
":platform",
|
||||
":proto_ops",
|
||||
":pywrap_tensorflow",
|
||||
":pywrap_tfe",
|
||||
":rnn_ops_gen",
|
||||
":saver_test_utils",
|
||||
":script_ops",
|
||||
@ -251,6 +252,7 @@ py_library(
|
||||
deps = [
|
||||
":_pywrap_util_port",
|
||||
":lib",
|
||||
":pywrap_tfe",
|
||||
":util",
|
||||
"//tensorflow/core:protos_all_py",
|
||||
"@absl_py//absl:app",
|
||||
@ -477,13 +479,13 @@ cc_library(
|
||||
cc_library(
|
||||
name = "pybind11_status",
|
||||
hdrs = [
|
||||
"lib/core/py_exception_registry.h",
|
||||
"lib/core/pybind11_status.h",
|
||||
"//tensorflow/c:headers",
|
||||
],
|
||||
features = ["-parse_headers"],
|
||||
visibility = tf_external_workspace_visible(visibility),
|
||||
deps = [
|
||||
":py_exception_registry",
|
||||
"//tensorflow/c:tf_status_headers",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
@ -1110,6 +1112,7 @@ py_library(
|
||||
":lib",
|
||||
":platform",
|
||||
":pywrap_tensorflow",
|
||||
":pywrap_tfe",
|
||||
":random_seed",
|
||||
":sparse_tensor",
|
||||
":tensor_spec",
|
||||
@ -5492,7 +5495,6 @@ tf_py_wrap_cc(
|
||||
"lib/io/py_record_reader.i",
|
||||
"lib/io/py_record_writer.i",
|
||||
"platform/base.i",
|
||||
"pywrap_tfe.i",
|
||||
"//tensorflow/compiler/mlir/python:mlir.i",
|
||||
],
|
||||
# add win_def_file for pywrap_tensorflow
|
||||
@ -5573,7 +5575,12 @@ WIN_LIB_FILES_FOR_EXPORTED_SYMBOLS = [
|
||||
":safe_ptr", # checkpoint_reader
|
||||
":python_op_gen", # python_op_gen
|
||||
":bfloat16_lib", # bfloat16
|
||||
"//tensorflow/python/eager:pywrap_tfe_lib", # pywrap_tfe_lib
|
||||
"//tensorflow/core/util/tensor_bundle", # checkpoint_reader
|
||||
"//tensorflow/core/common_runtime/eager:eager_executor", # tfe
|
||||
"//tensorflow/core/common_runtime/eager:context", # tfe
|
||||
"//tensorflow/core/profiler/lib:profiler_session", # tfe
|
||||
"//tensorflow/c:tf_status_helper", # tfe
|
||||
]
|
||||
|
||||
# Filter the DEF file to reduce the number of symbols to 64K or less.
|
||||
@ -7555,6 +7562,67 @@ py_library(
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "pywrap_tfe",
|
||||
srcs = ["pywrap_tfe.py"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":_pywrap_tfe",
|
||||
":pywrap_tensorflow",
|
||||
],
|
||||
)
|
||||
|
||||
tf_python_pybind_extension(
|
||||
name = "_pywrap_tfe",
|
||||
srcs = ["tfe_wrapper.cc"],
|
||||
hdrs = [
|
||||
"lib/core/safe_ptr.h",
|
||||
"util/util.h",
|
||||
":py_exception_registry_hdr",
|
||||
"//tensorflow/c:headers",
|
||||
"//tensorflow/c:pywrap_eager_hdrs",
|
||||
"//tensorflow/c/eager:headers",
|
||||
"//tensorflow/c/eager:pywrap_eager_hdrs",
|
||||
"//tensorflow/core/common_runtime/eager:pywrap_eager_hdrs",
|
||||
"//tensorflow/core/distributed_runtime:pywrap_eager_hdrs",
|
||||
"//tensorflow/core/distributed_runtime/eager:pywrap_eager_hdrs",
|
||||
"//tensorflow/core/framework:pywrap_eager_hdrs",
|
||||
"//tensorflow/core/profiler/internal:pywrap_eager_hdrs",
|
||||
"//tensorflow/core/profiler/lib:pywrap_eager_hdrs",
|
||||
"//tensorflow/python/eager:pywrap_eager_hdrs",
|
||||
],
|
||||
module_name = "_pywrap_tfe",
|
||||
deps = [
|
||||
":pybind11_lib",
|
||||
":pybind11_status",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/hash",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
"@pybind11",
|
||||
"//third_party/python_runtime:headers",
|
||||
"//tensorflow/core:core_cpu_headers_lib",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/platform:platform",
|
||||
"//tensorflow/core/profiler/protobuf:xplane_proto_cc",
|
||||
] + if_static(
|
||||
extra_deps = [
|
||||
"//tensorflow/core:eager_service_proto_cc",
|
||||
"//tensorflow/core:master_proto_cc",
|
||||
"//tensorflow/core:worker_proto_cc",
|
||||
],
|
||||
otherwise = [
|
||||
"//tensorflow/core:eager_service_proto_cc_headers_only",
|
||||
"//tensorflow/core:master_proto_cc_headers_only",
|
||||
"//tensorflow/core:worker_proto_cc_headers_only",
|
||||
],
|
||||
),
|
||||
)
|
||||
|
||||
tf_python_pybind_extension(
|
||||
name = "_pywrap_graph_analyzer",
|
||||
srcs = ["grappler/graph_analyzer_tool_wrapper.cc"],
|
||||
|
@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
%include "tensorflow/python/lib/core/strings.i"
|
||||
%include "tensorflow/python/platform/base.i"
|
||||
|
||||
%{
|
||||
@ -23,6 +24,13 @@ limitations under the License.
|
||||
#include "tensorflow/core/lib/strings/stringprintf.h"
|
||||
#include "tensorflow/core/public/version.h"
|
||||
#include "tensorflow/python/client/tf_session_helper.h"
|
||||
#include "tensorflow/c/c_api_experimental.h"
|
||||
#include "tensorflow/python/lib/core/safe_ptr.h"
|
||||
#include "tensorflow/python/eager/pywrap_tfe.h"
|
||||
// We were getting lucky on imports with safe_ptr.h being placed prior to
|
||||
// tf_session which imported safe_ptr. We also need pywrap_tfe.h to cast
|
||||
// one of the inputs to a graph function from a Python string to const char*.
|
||||
|
||||
|
||||
// Helper function to convert a Python list of Tensors to a C++ vector of
|
||||
// TF_Outputs.
|
||||
@ -78,6 +86,9 @@ void PyInt64ListToVector(PyObject* py_int_seq, std::vector<int64_t>* vec) {
|
||||
|
||||
%}
|
||||
|
||||
%include "tensorflow/c/tf_datatype.h"
|
||||
%include "tensorflow/c/tf_status.h"
|
||||
|
||||
%include "tensorflow/python/client/tf_sessionrun_wrapper.i"
|
||||
|
||||
// Required to use PyArray_* functions.
|
||||
@ -85,6 +96,14 @@ void PyInt64ListToVector(PyObject* py_int_seq, std::vector<int64_t>* vec) {
|
||||
tensorflow::ImportNumpy();
|
||||
%}
|
||||
|
||||
// For const parameters in a function, SWIG pretty much ignores the const.
|
||||
// See: http://www.swig.org/Doc2.0/SWIG.html#SWIG_nn13
|
||||
// Hence the 'const_cast'.
|
||||
%typemap(in) const char* op_name {
|
||||
$1 = const_cast<char*>(TFE_GetPythonString($input));
|
||||
}
|
||||
|
||||
|
||||
// TensorFlow version and GraphDef versions
|
||||
%constant const char* __version__ = TF_VERSION_STRING;
|
||||
%constant int GRAPH_DEF_VERSION = TF_GRAPH_DEF_VERSION;
|
||||
@ -174,6 +193,12 @@ tensorflow::ImportNumpy();
|
||||
// See comment for "%noexception TF_SessionRun_wrapper;"
|
||||
%noexception TF_OperationGetControlInputs_wrapper;
|
||||
|
||||
|
||||
// Migrate one function from pywrap_tfe.i
|
||||
%include "tensorflow/c/c_api_experimental.h"
|
||||
%unignore TF_ImportGraphDefOptionsSetValidateColocationConstraints;
|
||||
%noexception TF_ImportGraphDefOptionsSetValidateColocationConstraints;
|
||||
|
||||
// Build a Python list of TF_Operation* and return it.
|
||||
%typemap(out) std::vector<TF_Operation*> tensorflow::TF_OperationGetControlInputs_wrapper {
|
||||
$result = PyList_New($1.size());
|
||||
|
@ -268,7 +268,7 @@ py_library(
|
||||
"//tensorflow/python:device",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:pywrap_tensorflow",
|
||||
"//tensorflow/python:pywrap_tfe",
|
||||
"//tensorflow/python:summary_ops_v2",
|
||||
"//tensorflow/python:tensor_util",
|
||||
"//tensorflow/python:training",
|
||||
|
@ -24,7 +24,7 @@ import functools
|
||||
import threading
|
||||
import weakref
|
||||
|
||||
from tensorflow.python import pywrap_tensorflow
|
||||
from tensorflow.python import pywrap_tfe
|
||||
from tensorflow.python.autograph.core import ag_ctx
|
||||
from tensorflow.python.autograph.impl import api as autograph
|
||||
from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
|
||||
@ -944,7 +944,7 @@ class _MirroredReplicaThread(threading.Thread):
|
||||
self.record_thread_local_summary_state()
|
||||
self.record_thread_local_eager_context_state()
|
||||
self.context_device_policy = (
|
||||
pywrap_tensorflow.TFE_ContextGetDevicePlacementPolicy(
|
||||
pywrap_tfe.TFE_ContextGetDevicePlacementPolicy(
|
||||
ctx._context_handle)) # pylint: disable=protected-access
|
||||
self.graph = ops.get_default_graph()
|
||||
with ops.init_scope():
|
||||
|
@ -56,6 +56,18 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "pywrap_eager_hdrs",
|
||||
srcs = [
|
||||
"pywrap_tensor_conversion.h",
|
||||
"pywrap_tfe.h",
|
||||
],
|
||||
visibility = [
|
||||
"//tensorflow/core:__pkg__",
|
||||
"//tensorflow/python:__pkg__",
|
||||
],
|
||||
)
|
||||
|
||||
# Transitive dependencies of this target will be included in the pip package.
|
||||
py_library(
|
||||
name = "eager_pip",
|
||||
@ -90,7 +102,7 @@ py_library(
|
||||
deps = [
|
||||
":context",
|
||||
"//tensorflow/python:errors",
|
||||
"//tensorflow/python:pywrap_tensorflow",
|
||||
"//tensorflow/python:pywrap_tfe",
|
||||
],
|
||||
)
|
||||
|
||||
@ -100,7 +112,7 @@ py_library(
|
||||
srcs_version = "PY2AND3",
|
||||
visibility = ["//tensorflow:internal"],
|
||||
deps = [
|
||||
"//tensorflow/python:pywrap_tensorflow",
|
||||
"//tensorflow/python:pywrap_tfe",
|
||||
],
|
||||
)
|
||||
|
||||
@ -121,7 +133,7 @@ py_library(
|
||||
srcs_version = "PY2AND3",
|
||||
visibility = ["//tensorflow:internal"],
|
||||
deps = [
|
||||
"//tensorflow/python:pywrap_tensorflow",
|
||||
"//tensorflow/python:pywrap_tfe",
|
||||
],
|
||||
)
|
||||
|
||||
@ -131,13 +143,14 @@ py_library(
|
||||
srcs_version = "PY2AND3",
|
||||
visibility = ["//tensorflow:internal"],
|
||||
deps = [
|
||||
":eager_util",
|
||||
":executor",
|
||||
":monitoring",
|
||||
"//tensorflow/python:device",
|
||||
"//tensorflow/python:device_spec",
|
||||
"//tensorflow/python:errors",
|
||||
"//tensorflow/python:platform",
|
||||
"//tensorflow/python:pywrap_tensorflow",
|
||||
"//tensorflow/python:pywrap_tfe",
|
||||
"//tensorflow/python:tf2",
|
||||
"//tensorflow/python:util",
|
||||
"//third_party/py/numpy",
|
||||
@ -164,8 +177,8 @@ py_library(
|
||||
"//third_party/py/tf_agents:__subpackages__",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/python:c_api_util",
|
||||
"//tensorflow/python:pywrap_tensorflow",
|
||||
":eager_util",
|
||||
"//tensorflow/python:pywrap_tfe",
|
||||
"//tensorflow/python:util",
|
||||
],
|
||||
)
|
||||
@ -187,7 +200,8 @@ py_library(
|
||||
visibility = ["//tensorflow:internal"],
|
||||
deps = [
|
||||
":context",
|
||||
"//tensorflow/python:pywrap_tensorflow",
|
||||
":eager_util",
|
||||
"//tensorflow/python:pywrap_tfe",
|
||||
"//tensorflow/python:util",
|
||||
],
|
||||
)
|
||||
@ -209,7 +223,8 @@ py_library(
|
||||
srcs_version = "PY2AND3",
|
||||
visibility = ["//tensorflow:internal"],
|
||||
deps = [
|
||||
"//tensorflow/python:pywrap_tensorflow",
|
||||
":eager_util",
|
||||
"//tensorflow/python:pywrap_tfe",
|
||||
],
|
||||
)
|
||||
|
||||
@ -298,7 +313,7 @@ cuda_py_test(
|
||||
"//tensorflow/python:errors",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:pywrap_tensorflow",
|
||||
"//tensorflow/python:pywrap_tfe",
|
||||
"//third_party/py/numpy",
|
||||
],
|
||||
)
|
||||
@ -410,7 +425,7 @@ py_library(
|
||||
"//tensorflow/core:protos_all_py",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:lib",
|
||||
"//tensorflow/python:pywrap_tensorflow",
|
||||
"//tensorflow/python:pywrap_tfe",
|
||||
"//tensorflow/python:tensor_shape",
|
||||
"//tensorflow/python:util",
|
||||
"@six_archive//:six",
|
||||
@ -496,7 +511,7 @@ py_library(
|
||||
"//tensorflow/python:errors",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:pywrap_tensorflow",
|
||||
"//tensorflow/python:pywrap_tfe",
|
||||
"//tensorflow/python:tensor_shape",
|
||||
"//tensorflow/python:unconnected_gradients",
|
||||
"//tensorflow/python:util",
|
||||
@ -524,7 +539,7 @@ py_library(
|
||||
deps = [
|
||||
":forwardprop_util",
|
||||
"//tensorflow/python:platform",
|
||||
"//tensorflow/python:pywrap_tensorflow",
|
||||
"//tensorflow/python:pywrap_tfe",
|
||||
"//tensorflow/python:util",
|
||||
],
|
||||
)
|
||||
@ -535,7 +550,18 @@ py_library(
|
||||
srcs_version = "PY2AND3",
|
||||
visibility = ["//tensorflow:internal"],
|
||||
deps = [
|
||||
"//tensorflow/python:pywrap_tensorflow",
|
||||
"//tensorflow/python:pywrap_tfe",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "eager_util",
|
||||
srcs = ["eager_util.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
visibility = ["//tensorflow:internal"],
|
||||
deps = [
|
||||
"//tensorflow/python:pywrap_tfe",
|
||||
"//tensorflow/python:util",
|
||||
],
|
||||
)
|
||||
|
||||
@ -552,7 +578,7 @@ cuda_py_test(
|
||||
":remote",
|
||||
":test",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:pywrap_tensorflow",
|
||||
"//tensorflow/python:pywrap_tfe",
|
||||
"//tensorflow/python:random_ops",
|
||||
"//tensorflow/python/keras",
|
||||
"//third_party/py/numpy",
|
||||
@ -637,7 +663,7 @@ tf_py_test(
|
||||
":test",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:pywrap_tensorflow",
|
||||
"//tensorflow/python:pywrap_tfe",
|
||||
"//tensorflow/python:random_ops",
|
||||
"//tensorflow/python:test_ops",
|
||||
"//third_party/py/numpy",
|
||||
@ -649,7 +675,7 @@ py_library(
|
||||
srcs = ["imperative_grad.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/python:pywrap_tensorflow",
|
||||
"//tensorflow/python:pywrap_tfe",
|
||||
"//tensorflow/python:unconnected_gradients",
|
||||
"//tensorflow/python:util",
|
||||
],
|
||||
|
@ -24,7 +24,7 @@ import sys
|
||||
|
||||
import six
|
||||
|
||||
from tensorflow.python import pywrap_tensorflow
|
||||
from tensorflow.python import pywrap_tfe
|
||||
from tensorflow.python import _pywrap_utils
|
||||
from tensorflow.python.eager import backprop_util
|
||||
from tensorflow.python.eager import context
|
||||
@ -71,19 +71,25 @@ def op_attr_type(op_type, attr_name):
|
||||
except KeyError:
|
||||
context.ensure_initialized()
|
||||
h = context.context()._handle # pylint: disable=protected-access
|
||||
attr_type = pywrap_tensorflow.TFE_OpNameGetAttrType(h, op_type, attr_name)
|
||||
attr_type = pywrap_tfe.TFE_OpNameGetAttrType(h, op_type, attr_name)
|
||||
_op_attr_type_cache[(op_type, attr_name)] = attr_type
|
||||
return attr_type
|
||||
|
||||
|
||||
def make_attr(attr_type, value):
|
||||
if attr_type == pywrap_tensorflow.TF_ATTR_TYPE:
|
||||
# pybind11 enums do not return the raw value like SWIG enums do. They are
|
||||
# useful when comparing amongst each other but not direct integers as we are
|
||||
# doing in most tests.
|
||||
# https://pybind11.readthedocs.io/en/stable/classes.html#enumerations-and-internal-types
|
||||
# TODO(amitpatankar): After all SWIG transitions, convert the enum comparisons
|
||||
# from integer value to class.
|
||||
if attr_type == int(pywrap_tfe.TF_ATTR_TYPE):
|
||||
return dtypes.as_dtype(value)
|
||||
elif attr_type == [pywrap_tensorflow.TF_ATTR_TYPE]:
|
||||
elif attr_type == [int(pywrap_tfe.TF_ATTR_TYPE)]:
|
||||
return [dtypes.as_dtype(v) for v in value]
|
||||
elif attr_type == pywrap_tensorflow.TF_ATTR_SHAPE:
|
||||
elif attr_type == int(pywrap_tfe.TF_ATTR_SHAPE):
|
||||
return tensor_shape.as_shape(value).as_proto()
|
||||
elif attr_type == [pywrap_tensorflow.TF_ATTR_SHAPE]:
|
||||
elif attr_type == [int(pywrap_tfe.TF_ATTR_SHAPE)]:
|
||||
return [tensor_shape.as_shape(v).as_proto() for v in value]
|
||||
elif isinstance(value, str):
|
||||
return value.encode()
|
||||
@ -141,16 +147,15 @@ def _gradient_function(op_name, attr_tuple, num_inputs, inputs, outputs,
|
||||
return grad_fn(mock_op, *out_grads)
|
||||
|
||||
|
||||
pywrap_tensorflow.TFE_Py_RegisterGradientFunction(_gradient_function)
|
||||
pywrap_tfe.TFE_Py_RegisterGradientFunction(_gradient_function)
|
||||
|
||||
|
||||
def _must_record_gradient():
|
||||
return not pywrap_tensorflow.TFE_Py_TapeSetIsEmpty()
|
||||
return not pywrap_tfe.TFE_Py_TapeSetIsEmpty()
|
||||
|
||||
|
||||
def _record_gradient(op_name, inputs, attrs, results):
|
||||
return pywrap_tensorflow.TFE_Py_RecordGradient(op_name, inputs, attrs,
|
||||
results)
|
||||
return pywrap_tfe.TFE_Py_RecordGradient(op_name, inputs, attrs, results)
|
||||
|
||||
|
||||
execute.must_record_gradient = _must_record_gradient
|
||||
@ -688,7 +693,7 @@ _default_vspace = imperative_grad.VSpace(
|
||||
zeros_like_fn=default_gradient.zeros_like,
|
||||
ones_like_fn=default_gradient.ones_like,
|
||||
graph_shape_fn=gen_array_ops.shape)
|
||||
pywrap_tensorflow.TFE_Py_RegisterVSpace(_default_vspace)
|
||||
pywrap_tfe.TFE_Py_RegisterVSpace(_default_vspace)
|
||||
|
||||
|
||||
def _handle_or_self(x):
|
||||
|
@ -21,7 +21,7 @@ import functools
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python import pywrap_tensorflow
|
||||
from tensorflow.python import pywrap_tfe
|
||||
from tensorflow.python.eager import backprop
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import def_function
|
||||
@ -1014,19 +1014,19 @@ class BackpropTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
def testGetAttrType(self):
|
||||
typ = backprop.op_attr_type('Add', 'T')
|
||||
self.assertEqual(typ, pywrap_tensorflow.TF_ATTR_TYPE)
|
||||
self.assertEqual(typ, int(pywrap_tfe.TF_ATTR_TYPE))
|
||||
|
||||
def testGetAttrList(self):
|
||||
typ = backprop.op_attr_type('MaxPool', 'ksize')
|
||||
self.assertEqual(typ, [pywrap_tensorflow.TF_ATTR_INT])
|
||||
self.assertEqual(typ, [int(pywrap_tfe.TF_ATTR_INT)])
|
||||
|
||||
def testMakeAttrType(self):
|
||||
self.assertEqual(dtypes.float32,
|
||||
backprop.make_attr(pywrap_tensorflow.TF_ATTR_TYPE, 1))
|
||||
backprop.make_attr(int(pywrap_tfe.TF_ATTR_TYPE), 1))
|
||||
|
||||
def testMakeAttrTypeList(self):
|
||||
self.assertEqual([dtypes.float32],
|
||||
backprop.make_attr([pywrap_tensorflow.TF_ATTR_TYPE], [1]))
|
||||
backprop.make_attr([int(pywrap_tfe.TF_ATTR_TYPE)], [1]))
|
||||
|
||||
def testMulType(self):
|
||||
|
||||
@ -1040,7 +1040,7 @@ class BackpropTest(test.TestCase, parameterized.TestCase):
|
||||
def testMakeAttrShape(self):
|
||||
for s in ([], None, [1, 2, 3], [None, None], [1, None, 3]):
|
||||
expected = tensor_shape.TensorShape(s).as_proto()
|
||||
actual = backprop.make_attr(pywrap_tensorflow.TF_ATTR_SHAPE, s)
|
||||
actual = backprop.make_attr(int(pywrap_tfe.TF_ATTR_SHAPE), s)
|
||||
self.assertEqual(
|
||||
expected,
|
||||
actual,
|
||||
@ -1051,7 +1051,7 @@ class BackpropTest(test.TestCase, parameterized.TestCase):
|
||||
shape_list = [[], None, [1, 2, 3], [None, None], [1, None, 3]]
|
||||
self.assertEqual(
|
||||
[tensor_shape.TensorShape(s).as_proto() for s in shape_list],
|
||||
backprop.make_attr([pywrap_tensorflow.TF_ATTR_SHAPE], shape_list))
|
||||
backprop.make_attr([int(pywrap_tfe.TF_ATTR_SHAPE)], shape_list))
|
||||
|
||||
def testArgsGradientFunction(self):
|
||||
|
||||
|
@ -39,7 +39,7 @@ import six
|
||||
from six.moves import xrange # pylint: disable=redefined-builtin
|
||||
|
||||
from tensorflow.python import keras
|
||||
from tensorflow.python import pywrap_tensorflow
|
||||
from tensorflow.python import pywrap_tfe
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.eager import backprop # pylint: disable=unused-import
|
||||
from tensorflow.python.eager import context
|
||||
@ -76,10 +76,10 @@ def c_tfe_py_fastpath_execute(a,
|
||||
assert ctx.executing_eagerly(
|
||||
), "The prototype doesn't contain C code for graph construction"
|
||||
try:
|
||||
return pywrap_tensorflow.TFE_Py_FastPathExecute(
|
||||
ctx._handle, ctx.device_name, "MatMul", name,
|
||||
ctx.op_callbacks, a, b, "transpose_a", transpose_a,
|
||||
"transpose_b", transpose_b)
|
||||
return pywrap_tfe.TFE_Py_FastPathExecute(ctx._handle, ctx.device_name,
|
||||
"MatMul", name, ctx.op_callbacks,
|
||||
a, b, "transpose_a", transpose_a,
|
||||
"transpose_b", transpose_b)
|
||||
except core._NotOkStatusException as e:
|
||||
if name is not None:
|
||||
message = e.message + " name: " + name
|
||||
@ -339,8 +339,7 @@ class MicroBenchmarks(test.Benchmark):
|
||||
inputs = [m]
|
||||
|
||||
def f():
|
||||
pywrap_tensorflow.TFE_Py_Execute(ctx_handle, None, "Identity", inputs,
|
||||
attrs, 1)
|
||||
pywrap_tfe.TFE_Py_Execute(ctx_handle, None, "Identity", inputs, attrs, 1)
|
||||
|
||||
self._run(f, 30000)
|
||||
|
||||
@ -406,8 +405,7 @@ class MicroBenchmarks(test.Benchmark):
|
||||
m.dtype.as_datatype_enum)
|
||||
|
||||
def func():
|
||||
pywrap_tensorflow.TFE_Py_Execute(ctx_handle, device, "MatMul", inputs,
|
||||
attrs, 1)
|
||||
pywrap_tfe.TFE_Py_Execute(ctx_handle, device, "MatMul", inputs, attrs, 1)
|
||||
|
||||
self._run(func, num_iters)
|
||||
|
||||
|
@ -18,27 +18,27 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python import pywrap_tensorflow
|
||||
from tensorflow.python import pywrap_tfe
|
||||
|
||||
|
||||
class CancellationManager(object):
|
||||
"""A mechanism for cancelling blocking computation."""
|
||||
|
||||
def __init__(self):
|
||||
self._impl = pywrap_tensorflow.TFE_NewCancellationManager()
|
||||
self._impl = pywrap_tfe.TFE_NewCancellationManager()
|
||||
|
||||
@property
|
||||
def is_cancelled(self):
|
||||
"""Returns `True` if `CancellationManager.start_cancel` has been called."""
|
||||
return pywrap_tensorflow.TFE_CancellationManagerIsCancelled(self._impl)
|
||||
return pywrap_tfe.TFE_CancellationManagerIsCancelled(self._impl)
|
||||
|
||||
def start_cancel(self):
|
||||
"""Cancels blocking operations that have been registered with this object."""
|
||||
pywrap_tensorflow.TFE_CancellationManagerStartCancel(self._impl)
|
||||
pywrap_tfe.TFE_CancellationManagerStartCancel(self._impl)
|
||||
|
||||
def get_cancelable_function(self, concrete_function):
|
||||
# pylint: disable=protected-access
|
||||
return concrete_function._experimental_with_cancellation_manager(self)
|
||||
|
||||
def __del__(self):
|
||||
pywrap_tensorflow.TFE_DeleteCancellationManager(self._impl)
|
||||
pywrap_tfe.TFE_DeleteCancellationManager(self._impl)
|
||||
|
@ -29,11 +29,11 @@ import six
|
||||
|
||||
from tensorflow.core.protobuf import config_pb2
|
||||
from tensorflow.core.protobuf import rewriter_config_pb2
|
||||
from tensorflow.python import pywrap_tensorflow
|
||||
from tensorflow.python import pywrap_tfe
|
||||
from tensorflow.python import tf2
|
||||
from tensorflow.python.eager import eager_util as c_api_util
|
||||
from tensorflow.python.eager import executor
|
||||
from tensorflow.python.eager import monitoring
|
||||
from tensorflow.python.framework import c_api_util
|
||||
from tensorflow.python.framework import device as pydev
|
||||
from tensorflow.python.util import compat
|
||||
from tensorflow.python.util import is_in_graph_mode
|
||||
@ -54,17 +54,17 @@ _starting_device_spec = pydev.DeviceSpec.from_string("")
|
||||
|
||||
_MAXINT32 = 2**31 - 1
|
||||
|
||||
DEVICE_PLACEMENT_EXPLICIT = pywrap_tensorflow.TFE_DEVICE_PLACEMENT_EXPLICIT
|
||||
DEVICE_PLACEMENT_WARN = pywrap_tensorflow.TFE_DEVICE_PLACEMENT_WARN
|
||||
DEVICE_PLACEMENT_SILENT = pywrap_tensorflow.TFE_DEVICE_PLACEMENT_SILENT
|
||||
DEVICE_PLACEMENT_EXPLICIT = pywrap_tfe.TFE_DEVICE_PLACEMENT_EXPLICIT
|
||||
DEVICE_PLACEMENT_WARN = pywrap_tfe.TFE_DEVICE_PLACEMENT_WARN
|
||||
DEVICE_PLACEMENT_SILENT = pywrap_tfe.TFE_DEVICE_PLACEMENT_SILENT
|
||||
DEVICE_PLACEMENT_SILENT_FOR_INT32 = (
|
||||
pywrap_tensorflow.TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32)
|
||||
pywrap_tfe.TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32)
|
||||
|
||||
SYNC = 0
|
||||
ASYNC = 1
|
||||
|
||||
MIRRORING_NONE = pywrap_tensorflow.TFE_MIRRORING_NONE
|
||||
MIRRORING_ALL = pywrap_tensorflow.TFE_MIRRORING_ALL
|
||||
MIRRORING_NONE = pywrap_tfe.TFE_MIRRORING_NONE
|
||||
MIRRORING_ALL = pywrap_tfe.TFE_MIRRORING_ALL
|
||||
|
||||
_KEEP_ALIVE_SECS = 600
|
||||
|
||||
@ -444,7 +444,7 @@ class Context(object):
|
||||
self._rng = random.Random(seed)
|
||||
# Also clear the kernel cache, to reset any existing seeds
|
||||
if self._context_handle is not None:
|
||||
pywrap_tensorflow.TFE_ContextClearCaches(self._context_handle)
|
||||
pywrap_tfe.TFE_ContextClearCaches(self._context_handle)
|
||||
|
||||
def _internal_operation_seed(self):
|
||||
"""Returns a fake operation seed.
|
||||
@ -463,12 +463,11 @@ class Context(object):
|
||||
# Store list of devices
|
||||
logical_devices = []
|
||||
context_devices = []
|
||||
device_list = pywrap_tensorflow.TFE_ContextListDevices(
|
||||
self._context_handle)
|
||||
device_list = pywrap_tfe.TFE_ContextListDevices(self._context_handle)
|
||||
try:
|
||||
self._num_gpus = 0
|
||||
for i in range(pywrap_tensorflow.TF_DeviceListCount(device_list)):
|
||||
dev_name = pywrap_tensorflow.TF_DeviceListName(device_list, i)
|
||||
for i in range(pywrap_tfe.TF_DeviceListCount(device_list)):
|
||||
dev_name = pywrap_tfe.TF_DeviceListName(device_list, i)
|
||||
context_devices.append(pydev.canonical_name(dev_name))
|
||||
spec = pydev.DeviceSpec.from_string(dev_name)
|
||||
# If the job is localhost, we assume that the cluster has not yet been
|
||||
@ -477,14 +476,14 @@ class Context(object):
|
||||
spec = spec.replace(job=None, replica=None, task=None)
|
||||
logical_devices.append(
|
||||
LogicalDevice(name=spec.to_string(), device_type=spec.device_type))
|
||||
dev_type = pywrap_tensorflow.TF_DeviceListType(device_list, i)
|
||||
dev_type = pywrap_tfe.TF_DeviceListType(device_list, i)
|
||||
if dev_type == "GPU":
|
||||
self._num_gpus += 1
|
||||
|
||||
finally:
|
||||
self._logical_devices = logical_devices
|
||||
self._context_devices = context_devices
|
||||
pywrap_tensorflow.TF_DeleteDeviceList(device_list)
|
||||
pywrap_tfe.TF_DeleteDeviceList(device_list)
|
||||
|
||||
def ensure_initialized(self):
|
||||
"""Initialize handle and devices if not already done so."""
|
||||
@ -494,36 +493,34 @@ class Context(object):
|
||||
if self._initialized:
|
||||
return
|
||||
assert self._context_devices is None
|
||||
opts = pywrap_tensorflow.TFE_NewContextOptions()
|
||||
opts = pywrap_tfe.TFE_NewContextOptions()
|
||||
try:
|
||||
config_str = self.config.SerializeToString()
|
||||
pywrap_tensorflow.TFE_ContextOptionsSetConfig(opts, config_str)
|
||||
pywrap_tfe.TFE_ContextOptionsSetConfig(opts, config_str)
|
||||
if self._device_policy is not None:
|
||||
pywrap_tensorflow.TFE_ContextOptionsSetDevicePlacementPolicy(
|
||||
pywrap_tfe.TFE_ContextOptionsSetDevicePlacementPolicy(
|
||||
opts, self._device_policy)
|
||||
if self._mirroring_policy is not None:
|
||||
pywrap_tensorflow.TFE_ContextOptionsSetMirroringPolicy(
|
||||
pywrap_tfe.TFE_ContextOptionsSetMirroringPolicy(
|
||||
opts, self._mirroring_policy)
|
||||
if self._default_is_async == ASYNC:
|
||||
pywrap_tensorflow.TFE_ContextOptionsSetAsync(opts, True)
|
||||
pywrap_tfe.TFE_ContextOptionsSetAsync(opts, True)
|
||||
if self._lazy_remote_inputs_copy is not None:
|
||||
pywrap_tensorflow.TFE_ContextOptionsSetLazyRemoteInputsCopy(
|
||||
pywrap_tfe.TFE_ContextOptionsSetLazyRemoteInputsCopy(
|
||||
opts, self._lazy_remote_inputs_copy)
|
||||
context_handle = pywrap_tensorflow.TFE_NewContext(opts)
|
||||
context_handle = pywrap_tfe.TFE_NewContext(opts)
|
||||
finally:
|
||||
pywrap_tensorflow.TFE_DeleteContextOptions(opts)
|
||||
pywrap_tfe.TFE_DeleteContextOptions(opts)
|
||||
assert not (self._server_def and self._collective_ops_server_def), (
|
||||
"Cannot enable remote execution as well as collective ops at the "
|
||||
"moment. If this is important to you, please file an issue.")
|
||||
if self._server_def is not None:
|
||||
server_def_str = self._server_def.SerializeToString()
|
||||
pywrap_tensorflow.TFE_ContextSetServerDef(context_handle,
|
||||
_KEEP_ALIVE_SECS,
|
||||
server_def_str)
|
||||
pywrap_tfe.TFE_ContextSetServerDef(context_handle, _KEEP_ALIVE_SECS,
|
||||
server_def_str)
|
||||
elif self._collective_ops_server_def is not None:
|
||||
server_def_str = self._collective_ops_server_def.SerializeToString()
|
||||
pywrap_tensorflow.TFE_EnableCollectiveOps(context_handle,
|
||||
server_def_str)
|
||||
pywrap_tfe.TFE_EnableCollectiveOps(context_handle, server_def_str)
|
||||
|
||||
self._context_handle = context_handle
|
||||
self._initialize_logical_devices()
|
||||
@ -532,7 +529,7 @@ class Context(object):
|
||||
def _clear_caches(self):
|
||||
self.ones_rank_cache().flush()
|
||||
self.zeros_cache().flush()
|
||||
pywrap_tensorflow.TFE_ClearScalarCache()
|
||||
pywrap_tfe.TFE_ClearScalarCache()
|
||||
|
||||
def get_server_def(self):
|
||||
return self._server_def
|
||||
@ -563,8 +560,8 @@ class Context(object):
|
||||
|
||||
if self._context_handle:
|
||||
server_def_str = server_def.SerializeToString()
|
||||
pywrap_tensorflow.TFE_ContextSetServerDef(self._context_handle,
|
||||
keep_alive_secs, server_def_str)
|
||||
pywrap_tfe.TFE_ContextSetServerDef(self._context_handle, keep_alive_secs,
|
||||
server_def_str)
|
||||
self._initialize_logical_devices()
|
||||
|
||||
# Clear all the caches in case there are remote tensors in them.
|
||||
@ -592,9 +589,8 @@ class Context(object):
|
||||
|
||||
if self._context_handle:
|
||||
server_def_str = server_def.SerializeToString()
|
||||
pywrap_tensorflow.TFE_ContextUpdateServerDef(self._context_handle,
|
||||
keep_alive_secs,
|
||||
server_def_str)
|
||||
pywrap_tfe.TFE_ContextUpdateServerDef(self._context_handle,
|
||||
keep_alive_secs, server_def_str)
|
||||
self._initialize_logical_devices()
|
||||
|
||||
self._clear_caches()
|
||||
@ -614,8 +610,7 @@ class Context(object):
|
||||
"""
|
||||
# TODO(yuefengz): support checking multiple workers.
|
||||
if self._context_handle:
|
||||
return pywrap_tensorflow.TFE_ContextCheckAlive(self._context_handle,
|
||||
worker_name)
|
||||
return pywrap_tfe.TFE_ContextCheckAlive(self._context_handle, worker_name)
|
||||
else:
|
||||
raise ValueError("Context is not initialized.")
|
||||
|
||||
@ -808,8 +803,8 @@ class Context(object):
|
||||
self.executor.wait()
|
||||
executor_new = executor.new_executor(enable_async)
|
||||
self._thread_local_data.executor = executor_new
|
||||
pywrap_tensorflow.TFE_ContextSetExecutorForThread(
|
||||
self._context_handle, executor_new.handle())
|
||||
pywrap_tfe.TFE_ContextSetExecutorForThread(self._context_handle,
|
||||
executor_new.handle())
|
||||
else:
|
||||
self._default_is_async = enable_async
|
||||
|
||||
@ -823,13 +818,12 @@ class Context(object):
|
||||
def executor(self):
|
||||
ensure_initialized()
|
||||
return executor.Executor(
|
||||
pywrap_tensorflow.TFE_ContextGetExecutorForThread(self._context_handle))
|
||||
pywrap_tfe.TFE_ContextGetExecutorForThread(self._context_handle))
|
||||
|
||||
@executor.setter
|
||||
def executor(self, e):
|
||||
ensure_initialized()
|
||||
pywrap_tensorflow.TFE_ContextSetExecutorForThread(self._context_handle,
|
||||
e.handle())
|
||||
pywrap_tfe.TFE_ContextSetExecutorForThread(self._context_handle, e.handle())
|
||||
|
||||
@property
|
||||
def config(self):
|
||||
@ -1015,7 +1009,7 @@ class Context(object):
|
||||
fn: A wrapped TF_Function (returned from TF_GraphToFunction_wrapper).
|
||||
"""
|
||||
self.ensure_initialized()
|
||||
pywrap_tensorflow.TFE_ContextAddFunction(self._handle, fn)
|
||||
pywrap_tfe.TFE_ContextAddFunction(self._handle, fn)
|
||||
|
||||
def add_function_def(self, fdef):
|
||||
"""Add a function definition to the context.
|
||||
@ -1028,8 +1022,8 @@ class Context(object):
|
||||
"""
|
||||
self.ensure_initialized()
|
||||
fdef_string = fdef.SerializeToString()
|
||||
pywrap_tensorflow.TFE_ContextAddFunctionDef(
|
||||
self._handle, fdef_string, len(fdef_string))
|
||||
pywrap_tfe.TFE_ContextAddFunctionDef(self._handle, fdef_string,
|
||||
len(fdef_string))
|
||||
|
||||
def remove_function(self, name):
|
||||
"""Remove a function from the context.
|
||||
@ -1040,12 +1034,12 @@ class Context(object):
|
||||
name: function signature name.
|
||||
"""
|
||||
self.ensure_initialized()
|
||||
pywrap_tensorflow.TFE_ContextRemoveFunction(self._handle, name)
|
||||
pywrap_tfe.TFE_ContextRemoveFunction(self._handle, name)
|
||||
|
||||
def has_function(self, name):
|
||||
"""Check if a function `name` is registered."""
|
||||
self.ensure_initialized()
|
||||
return bool(pywrap_tensorflow.TFE_ContextHasFunction(self._handle, name))
|
||||
return bool(pywrap_tfe.TFE_ContextHasFunction(self._handle, name))
|
||||
|
||||
def add_op_callback(self, callback):
|
||||
"""Add a post-op callback to the context.
|
||||
@ -1101,7 +1095,7 @@ class Context(object):
|
||||
if self._physical_devices is not None:
|
||||
return
|
||||
|
||||
devs = pywrap_tensorflow.TF_ListPhysicalDevices()
|
||||
devs = pywrap_tfe.TF_ListPhysicalDevices()
|
||||
self._physical_devices = [
|
||||
PhysicalDevice(name=d.decode(),
|
||||
device_type=d.decode().split(":")[1]) for d in devs]
|
||||
@ -1434,7 +1428,7 @@ class Context(object):
|
||||
def device_policy(self):
|
||||
# Only get the policy from the context if it has already been initialized
|
||||
if self._context_handle is not None:
|
||||
return pywrap_tensorflow.TFE_ContextGetDevicePlacementPolicy(self._handle)
|
||||
return pywrap_tfe.TFE_ContextGetDevicePlacementPolicy(self._handle)
|
||||
|
||||
return self._device_policy
|
||||
|
||||
@ -1448,14 +1442,14 @@ class Context(object):
|
||||
|
||||
# Only set the policy if the context has already been initialized
|
||||
if self._context_handle is not None:
|
||||
pywrap_tensorflow.TFE_ContextSetThreadLocalDevicePlacementPolicy(
|
||||
pywrap_tfe.TFE_ContextSetThreadLocalDevicePlacementPolicy(
|
||||
self._handle, self._device_policy)
|
||||
|
||||
@property
|
||||
def mirroring_policy(self):
|
||||
# Only get the policy from the context if it has already been initialized
|
||||
if self._context_handle is not None:
|
||||
return pywrap_tensorflow.TFE_ContextGetMirroringPolicy(self._handle)
|
||||
return pywrap_tfe.TFE_ContextGetMirroringPolicy(self._handle)
|
||||
|
||||
return self._mirroring_policy
|
||||
|
||||
@ -1469,7 +1463,7 @@ class Context(object):
|
||||
|
||||
# Only set the policy if the context has already been initialized
|
||||
if self._context_handle is not None:
|
||||
pywrap_tensorflow.TFE_ContextSetThreadLocalMirroringPolicy(
|
||||
pywrap_tfe.TFE_ContextSetThreadLocalMirroringPolicy(
|
||||
self._handle, self._mirroring_policy)
|
||||
|
||||
@property
|
||||
@ -1495,13 +1489,13 @@ class Context(object):
|
||||
and to stop tracing call context.disable_run_metadata().
|
||||
"""
|
||||
self.ensure_initialized()
|
||||
pywrap_tensorflow.TFE_ContextEnableRunMetadata(self._handle)
|
||||
pywrap_tfe.TFE_ContextEnableRunMetadata(self._handle)
|
||||
|
||||
def disable_run_metadata(self):
|
||||
"""Disables tracing of op execution via RunMetadata."""
|
||||
if not self._context_handle:
|
||||
return
|
||||
pywrap_tensorflow.TFE_ContextDisableRunMetadata(self._context_handle)
|
||||
pywrap_tfe.TFE_ContextDisableRunMetadata(self._context_handle)
|
||||
|
||||
def enable_graph_collection(self):
|
||||
"""Enables graph collection of executed functions.
|
||||
@ -1510,13 +1504,13 @@ class Context(object):
|
||||
and to stop collecting graphs call context.disable_graph_collection().
|
||||
"""
|
||||
self.ensure_initialized()
|
||||
pywrap_tensorflow.TFE_ContextEnableGraphCollection(self._handle)
|
||||
pywrap_tfe.TFE_ContextEnableGraphCollection(self._handle)
|
||||
|
||||
def disable_graph_collection(self):
|
||||
"""Disables graph collection of executed functions."""
|
||||
if not self._context_handle:
|
||||
return
|
||||
pywrap_tensorflow.TFE_ContextDisableGraphCollection(self._context_handle)
|
||||
pywrap_tfe.TFE_ContextDisableGraphCollection(self._context_handle)
|
||||
|
||||
def export_run_metadata(self):
|
||||
"""Returns a RunMetadata proto with accumulated information.
|
||||
@ -1530,9 +1524,8 @@ class Context(object):
|
||||
if not self._context_handle:
|
||||
return None
|
||||
with c_api_util.tf_buffer() as buffer_:
|
||||
pywrap_tensorflow.TFE_ContextExportRunMetadata(
|
||||
self._context_handle, buffer_)
|
||||
proto_data = pywrap_tensorflow.TF_GetBuffer(buffer_)
|
||||
pywrap_tfe.TFE_ContextExportRunMetadata(self._context_handle, buffer_)
|
||||
proto_data = pywrap_tfe.TF_GetBuffer(buffer_)
|
||||
run_metadata = config_pb2.RunMetadata()
|
||||
run_metadata.ParseFromString(compat.as_bytes(proto_data))
|
||||
return run_metadata
|
||||
@ -1543,10 +1536,10 @@ class Context(object):
|
||||
return self._context_switches
|
||||
|
||||
def start_step(self):
|
||||
pywrap_tensorflow.TFE_ContextStartStep(self._handle)
|
||||
pywrap_tfe.TFE_ContextStartStep(self._handle)
|
||||
|
||||
def end_step(self):
|
||||
pywrap_tensorflow.TFE_ContextEndStep(self._handle)
|
||||
pywrap_tfe.TFE_ContextEndStep(self._handle)
|
||||
|
||||
|
||||
class _EagerDeviceContext(object):
|
||||
@ -1608,7 +1601,7 @@ _context_lock = threading.Lock()
|
||||
|
||||
def _set_context_locked(ctx):
|
||||
global _context
|
||||
pywrap_tensorflow.TFE_Py_SetEagerContext(ctx)
|
||||
pywrap_tfe.TFE_Py_SetEagerContext(ctx)
|
||||
_context = ctx
|
||||
|
||||
|
||||
|
@ -18,7 +18,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python import pywrap_tensorflow
|
||||
from tensorflow.python import pywrap_tfe
|
||||
from tensorflow.python.framework import errors
|
||||
|
||||
# Trace of execution and memory usage.
|
||||
@ -46,7 +46,7 @@ class _NotOkStatusException(Exception):
|
||||
return "%s: %s" % (e.__class__.__name__, e)
|
||||
|
||||
|
||||
pywrap_tensorflow.TFE_Py_RegisterExceptionClass(_NotOkStatusException)
|
||||
pywrap_tfe.TFE_Py_RegisterExceptionClass(_NotOkStatusException)
|
||||
|
||||
|
||||
class _FallbackException(Exception):
|
||||
@ -71,4 +71,4 @@ class _SymbolicException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
pywrap_tensorflow.TFE_Py_RegisterFallbackExceptionClass(_FallbackException)
|
||||
pywrap_tfe.TFE_Py_RegisterFallbackExceptionClass(_FallbackException)
|
||||
|
@ -26,7 +26,7 @@ import threading
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.core.protobuf import config_pb2
|
||||
from tensorflow.python import pywrap_tensorflow
|
||||
from tensorflow.python import pywrap_tfe
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import core
|
||||
from tensorflow.python.eager import def_function
|
||||
@ -602,8 +602,8 @@ class TFETest(test_util.TensorFlowTestCase):
|
||||
|
||||
def testRegisterExceptionClass(self):
|
||||
with self.assertRaises(TypeError):
|
||||
pywrap_tensorflow.TFE_Py_RegisterExceptionClass(str)
|
||||
pywrap_tensorflow.TFE_Py_RegisterExceptionClass(core._NotOkStatusException) # pylint: disable=protected-access
|
||||
pywrap_tfe.TFE_Py_RegisterExceptionClass(str)
|
||||
pywrap_tfe.TFE_Py_RegisterExceptionClass(core._NotOkStatusException) # pylint: disable=protected-access
|
||||
|
||||
# TODO(agarwal): add tests passing incorrect typed values to attrs.
|
||||
def testExecuteBasic(self):
|
||||
|
61
tensorflow/python/eager/eager_util.py
Normal file
61
tensorflow/python/eager/eager_util.py
Normal file
@ -0,0 +1,61 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Utilities for using the TensorFlow Eager using the C API."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python import pywrap_tfe as c_api
|
||||
from tensorflow.python.util import compat
|
||||
from tensorflow.python.util import tf_contextlib
|
||||
|
||||
|
||||
# We temporarily need a duplicate tf_buffer function in eager_util. The
|
||||
# c_api_util is still relying on SWIG and is thus incompatible until
|
||||
# we migrate over. We can delete this once we migrate tf_session.i
|
||||
|
||||
|
||||
@tf_contextlib.contextmanager
|
||||
def tf_buffer(data=None):
|
||||
"""Context manager that creates and deletes TF_Buffer.
|
||||
|
||||
Example usage:
|
||||
with tf_buffer() as buf:
|
||||
# get serialized graph def into buf
|
||||
...
|
||||
proto_data = c_api.TF_GetBuffer(buf)
|
||||
graph_def.ParseFromString(compat.as_bytes(proto_data))
|
||||
# buf has been deleted
|
||||
|
||||
with tf_buffer(some_string) as buf:
|
||||
c_api.TF_SomeFunction(buf)
|
||||
# buf has been deleted
|
||||
|
||||
Args:
|
||||
data: An optional `bytes`, `str`, or `unicode` object. If not None, the
|
||||
yielded buffer will contain this data.
|
||||
|
||||
Yields:
|
||||
Created TF_Buffer
|
||||
"""
|
||||
if data:
|
||||
buf = c_api.TF_NewBufferFromString(compat.as_bytes(data))
|
||||
else:
|
||||
buf = c_api.TF_NewBuffer()
|
||||
try:
|
||||
yield buf
|
||||
finally:
|
||||
c_api.TF_DeleteBuffer(buf)
|
@ -22,7 +22,7 @@ import six
|
||||
|
||||
from google.protobuf import text_format
|
||||
from tensorflow.core.framework import tensor_pb2
|
||||
from tensorflow.python import pywrap_tensorflow
|
||||
from tensorflow.python import pywrap_tfe
|
||||
from tensorflow.python.eager import core
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
@ -56,9 +56,8 @@ def quick_execute(op_name, num_outputs, inputs, attrs, ctx, name=None):
|
||||
# pylint: disable=protected-access
|
||||
try:
|
||||
ctx.ensure_initialized()
|
||||
tensors = pywrap_tensorflow.TFE_Py_Execute(ctx._handle, device_name,
|
||||
op_name, inputs, attrs,
|
||||
num_outputs)
|
||||
tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
|
||||
inputs, attrs, num_outputs)
|
||||
except core._NotOkStatusException as e:
|
||||
if name is not None:
|
||||
message = e.message + " name: " + name
|
||||
@ -111,9 +110,10 @@ def execute_with_cancellation(op_name,
|
||||
# pylint: disable=protected-access
|
||||
try:
|
||||
ctx.ensure_initialized()
|
||||
tensors = pywrap_tensorflow.TFE_Py_ExecuteCancelable(
|
||||
ctx._handle, device_name, op_name, inputs, attrs,
|
||||
cancellation_manager._impl, num_outputs)
|
||||
tensors = pywrap_tfe.TFE_Py_ExecuteCancelable(ctx._handle, device_name,
|
||||
op_name, inputs, attrs,
|
||||
cancellation_manager._impl,
|
||||
num_outputs)
|
||||
except core._NotOkStatusException as e:
|
||||
if name is not None:
|
||||
message = e.message + " name: " + name
|
||||
|
@ -18,7 +18,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python import pywrap_tensorflow
|
||||
from tensorflow.python import pywrap_tfe
|
||||
|
||||
|
||||
class Executor(object):
|
||||
@ -45,8 +45,8 @@ class Executor(object):
|
||||
|
||||
def __del__(self):
|
||||
try:
|
||||
# pywrap_tensorflow.TFE_ExecutorWaitForAllPendingNodes(self._handle)
|
||||
pywrap_tensorflow.TFE_DeleteExecutor(self._handle)
|
||||
# pywrap_tfe.TFE_ExecutorWaitForAllPendingNodes(self._handle)
|
||||
pywrap_tfe.TFE_DeleteExecutor(self._handle)
|
||||
except TypeError:
|
||||
# Suppress some exceptions, mainly for the case when we're running on
|
||||
# module deletion. Things that can go wrong include the pywrap module
|
||||
@ -57,20 +57,20 @@ class Executor(object):
|
||||
# partially unloaded.
|
||||
|
||||
def is_async(self):
|
||||
return pywrap_tensorflow.TFE_ExecutorIsAsync(self._handle)
|
||||
return pywrap_tfe.TFE_ExecutorIsAsync(self._handle)
|
||||
|
||||
def handle(self):
|
||||
return self._handle
|
||||
|
||||
def wait(self):
|
||||
"""Waits for ops dispatched in this executor to finish."""
|
||||
pywrap_tensorflow.TFE_ExecutorWaitForAllPendingNodes(self._handle)
|
||||
pywrap_tfe.TFE_ExecutorWaitForAllPendingNodes(self._handle)
|
||||
|
||||
def clear_error(self):
|
||||
"""Clears errors raised in this executor during execution."""
|
||||
pywrap_tensorflow.TFE_ExecutorClearError(self._handle)
|
||||
pywrap_tfe.TFE_ExecutorClearError(self._handle)
|
||||
|
||||
|
||||
def new_executor(enable_async):
|
||||
handle = pywrap_tensorflow.TFE_NewExecutor(enable_async)
|
||||
handle = pywrap_tfe.TFE_NewExecutor(enable_async)
|
||||
return Executor(handle)
|
||||
|
@ -20,7 +20,7 @@ from __future__ import print_function
|
||||
|
||||
import threading
|
||||
|
||||
from tensorflow.python import pywrap_tensorflow
|
||||
from tensorflow.python import pywrap_tfe
|
||||
from tensorflow.python.eager import backprop
|
||||
from tensorflow.python.eager import backprop_util
|
||||
from tensorflow.python.eager import def_function
|
||||
@ -166,7 +166,8 @@ def _jvp_dispatch(op_name, attr_tuple, inputs, outputs, tangents):
|
||||
return _jvp_relaxed_shapes(
|
||||
op_name, attr_tuple, inputs, outputs, tangents)
|
||||
|
||||
pywrap_tensorflow.TFE_Py_RegisterJVPFunction(_jvp_dispatch)
|
||||
|
||||
pywrap_tfe.TFE_Py_RegisterJVPFunction(_jvp_dispatch)
|
||||
|
||||
|
||||
@tf_export("autodiff.ForwardAccumulator", v1=[])
|
||||
@ -300,7 +301,7 @@ class ForwardAccumulator(object):
|
||||
ValueError: If the same tensor or variable is specified multiple times in
|
||||
`primals`.
|
||||
"""
|
||||
self._accumulator = pywrap_tensorflow.TFE_Py_ForwardAccumulatorNew()
|
||||
self._accumulator = pywrap_tfe.TFE_Py_ForwardAccumulatorNew()
|
||||
self._recording = False
|
||||
primal_ids = set()
|
||||
for primal in nest.flatten(primals):
|
||||
@ -323,13 +324,13 @@ class ForwardAccumulator(object):
|
||||
def _push_accumulator(self):
|
||||
if self._recording:
|
||||
raise ValueError("Accumulator is already recording.")
|
||||
pywrap_tensorflow.TFE_Py_ForwardAccumulatorSetAdd(self._accumulator)
|
||||
pywrap_tfe.TFE_Py_ForwardAccumulatorSetAdd(self._accumulator)
|
||||
self._recording = True
|
||||
|
||||
def _pop_accumulator(self):
|
||||
if not self._recording:
|
||||
raise ValueError("Accumulator is not recording.")
|
||||
pywrap_tensorflow.TFE_Py_ForwardAccumulatorSetRemove(self._accumulator)
|
||||
pywrap_tfe.TFE_Py_ForwardAccumulatorSetRemove(self._accumulator)
|
||||
self._recording = False
|
||||
|
||||
def _watch(self, primals, tangents):
|
||||
@ -358,7 +359,7 @@ class ForwardAccumulator(object):
|
||||
# Run convert_to_tensor to get the captured handle from whichever
|
||||
# function we're running if necessary.
|
||||
t = ops.convert_to_tensor(t.handle)
|
||||
pywrap_tensorflow.TFE_Py_ForwardAccumulatorWatch(self._accumulator, t, g)
|
||||
pywrap_tfe.TFE_Py_ForwardAccumulatorWatch(self._accumulator, t, g)
|
||||
|
||||
def jvp(self, primals, unconnected_gradients=UnconnectedGradients.NONE):
|
||||
"""Fetches the Jacobian-vector product computed for `primals`.
|
||||
@ -384,8 +385,8 @@ class ForwardAccumulator(object):
|
||||
def _fetch_jvp(tensor):
|
||||
if hasattr(tensor, "handle"):
|
||||
tensor = ops.convert_to_tensor(tensor.handle)
|
||||
result = pywrap_tensorflow.TFE_Py_ForwardAccumulatorJVP(
|
||||
self._accumulator, tensor)
|
||||
result = pywrap_tfe.TFE_Py_ForwardAccumulatorJVP(self._accumulator,
|
||||
tensor)
|
||||
if result is None and unconnected_gradients == UnconnectedGradients.ZERO:
|
||||
return array_ops.zeros_like(tensor)
|
||||
return result
|
||||
|
@ -24,7 +24,7 @@ import weakref
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python import pywrap_tensorflow
|
||||
from tensorflow.python import pywrap_tfe
|
||||
from tensorflow.python.distribute import mirrored_strategy
|
||||
from tensorflow.python.eager import backprop
|
||||
from tensorflow.python.eager import def_function
|
||||
@ -236,13 +236,13 @@ class ForwardpropTest(test.TestCase, parameterized.TestCase):
|
||||
x = constant_op.constant(1.)
|
||||
with forwardprop.ForwardAccumulator(x, 2.) as acc:
|
||||
y = x + x
|
||||
pywrap_tensorflow.TFE_Py_RegisterJVPFunction(
|
||||
pywrap_tfe.TFE_Py_RegisterJVPFunction(
|
||||
lambda *args, **kwargs: [constant_op.constant(-15.)])
|
||||
z = x + x
|
||||
self.assertAllClose(4., acc.jvp(y))
|
||||
self.assertAllClose(-15., acc.jvp(z))
|
||||
finally:
|
||||
pywrap_tensorflow.TFE_Py_RegisterJVPFunction(previous_fn)
|
||||
pywrap_tfe.TFE_Py_RegisterJVPFunction(previous_fn)
|
||||
|
||||
@test_util.assert_no_new_pyobjects_executing_eagerly
|
||||
def testFunctionCacheLimited(self):
|
||||
@ -738,19 +738,19 @@ class ForwardpropTest(test.TestCase, parameterized.TestCase):
|
||||
with forwardprop.ForwardAccumulator(c, c_tangent) as acc:
|
||||
with backprop.GradientTape() as tape:
|
||||
self.assertFalse(tape_lib.should_record_backprop([c]))
|
||||
self.assertEqual(
|
||||
1, pywrap_tensorflow.TFE_Py_TapeSetPossibleGradientTypes([c]))
|
||||
self.assertEqual(1,
|
||||
pywrap_tfe.TFE_Py_TapeSetPossibleGradientTypes([c]))
|
||||
tape.watch(c)
|
||||
self.assertEqual(
|
||||
2, pywrap_tensorflow.TFE_Py_TapeSetPossibleGradientTypes([c]))
|
||||
self.assertEqual(2,
|
||||
pywrap_tfe.TFE_Py_TapeSetPossibleGradientTypes([c]))
|
||||
self.assertTrue(tape_lib.should_record_backprop([c]))
|
||||
with tape_lib.stop_recording():
|
||||
self.assertEqual(
|
||||
0, pywrap_tensorflow.TFE_Py_TapeSetPossibleGradientTypes([c]))
|
||||
self.assertEqual(0,
|
||||
pywrap_tfe.TFE_Py_TapeSetPossibleGradientTypes([c]))
|
||||
self.assertFalse(tape_lib.should_record_backprop([c]))
|
||||
d = c * 2.
|
||||
self.assertEqual(
|
||||
2, pywrap_tensorflow.TFE_Py_TapeSetPossibleGradientTypes([c]))
|
||||
self.assertEqual(2,
|
||||
pywrap_tfe.TFE_Py_TapeSetPossibleGradientTypes([c]))
|
||||
self.assertTrue(tape_lib.should_record_backprop([c]))
|
||||
self.assertFalse(tape_lib.should_record_backprop([d]))
|
||||
self.assertIsNone(acc.jvp(d))
|
||||
|
@ -24,7 +24,7 @@ from __future__ import print_function
|
||||
import collections
|
||||
import contextlib
|
||||
|
||||
from tensorflow.python import pywrap_tensorflow
|
||||
from tensorflow.python import pywrap_tfe
|
||||
|
||||
|
||||
class TangentInfo(
|
||||
@ -54,8 +54,7 @@ def pack_tangents(tensors):
|
||||
tangents: A flat list of Tensors. Best interpreted as a sequence to be
|
||||
appended to `tensors`.
|
||||
"""
|
||||
return TangentInfo(
|
||||
*pywrap_tensorflow.TFE_Py_PackJVPs(tensors))
|
||||
return TangentInfo(*pywrap_tfe.TFE_Py_PackJVPs(tensors))
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
@ -73,7 +72,7 @@ def push_forwardprop_state():
|
||||
None (used for its side effect).
|
||||
"""
|
||||
try:
|
||||
pywrap_tensorflow.TFE_Py_ForwardAccumulatorPushState()
|
||||
pywrap_tfe.TFE_Py_ForwardAccumulatorPushState()
|
||||
yield
|
||||
finally:
|
||||
pywrap_tensorflow.TFE_Py_ForwardAccumulatorPopState()
|
||||
pywrap_tfe.TFE_Py_ForwardAccumulatorPopState()
|
||||
|
@ -32,8 +32,9 @@ from six.moves import map
|
||||
|
||||
from tensorflow.core.framework import attr_value_pb2
|
||||
from tensorflow.core.framework import function_pb2
|
||||
from tensorflow.python import pywrap_tensorflow
|
||||
from tensorflow.python import pywrap_tfe
|
||||
from tensorflow.python import _pywrap_utils
|
||||
from tensorflow.python import pywrap_tensorflow
|
||||
from tensorflow.python.eager import backprop
|
||||
from tensorflow.python.eager import backprop_util
|
||||
from tensorflow.python.eager import context
|
||||
@ -1098,7 +1099,7 @@ class _TapeGradientFunctions(object):
|
||||
forward_function.signature.name,
|
||||
forward_outputs, forward_inputs, py_backward, None)
|
||||
output_indices, output_tangents = (
|
||||
pywrap_tensorflow.TFE_Py_PackJVPs(forward_outputs))
|
||||
pywrap_tfe.TFE_Py_PackJVPs(forward_outputs))
|
||||
output_tangents = [forward_wrapper_graph.capture(t)
|
||||
for t in output_tangents]
|
||||
return _ForwardWrapper(
|
||||
@ -1732,7 +1733,7 @@ class ConcreteFunction(object):
|
||||
"Tensor." % (self._func_graph.name, i, str(arg)))
|
||||
args = tensor_inputs + captured_inputs
|
||||
possible_gradient_type = (
|
||||
pywrap_tensorflow.TFE_Py_TapeSetPossibleGradientTypes(args))
|
||||
pywrap_tfe.TFE_Py_TapeSetPossibleGradientTypes(args))
|
||||
if (possible_gradient_type == _POSSIBLE_GRADIENT_TYPES_NONE
|
||||
and executing_eagerly):
|
||||
# No tape is watching; skip to running the function.
|
||||
@ -2552,8 +2553,8 @@ class Function(object):
|
||||
"""Computes the cache key given inputs and execution context."""
|
||||
if self.input_signature is None:
|
||||
inputs = (args, kwargs) if kwargs else args
|
||||
input_signature = pywrap_tensorflow.TFE_Py_EncodeArg(
|
||||
inputs, include_tensor_ranks_only)
|
||||
input_signature = pywrap_tfe.TFE_Py_EncodeArg(inputs,
|
||||
include_tensor_ranks_only)
|
||||
else:
|
||||
del args, kwargs
|
||||
assert not include_tensor_ranks_only
|
||||
|
@ -20,7 +20,7 @@ from __future__ import print_function
|
||||
|
||||
import collections
|
||||
|
||||
from tensorflow.python import pywrap_tensorflow
|
||||
from tensorflow.python import pywrap_tfe
|
||||
from tensorflow.python.ops.unconnected_gradients import UnconnectedGradients
|
||||
from tensorflow.python.util import compat
|
||||
|
||||
@ -68,7 +68,7 @@ def imperative_grad(tape,
|
||||
raise ValueError(
|
||||
"Unknown value for unconnected_gradients: %r" % unconnected_gradients)
|
||||
|
||||
return pywrap_tensorflow.TFE_Py_TapeGradient(
|
||||
return pywrap_tfe.TFE_Py_TapeGradient(
|
||||
tape._tape, # pylint: disable=protected-access
|
||||
target,
|
||||
sources,
|
||||
|
@ -21,80 +21,80 @@ from __future__ import print_function
|
||||
import collections
|
||||
|
||||
from tensorflow.core.framework import summary_pb2
|
||||
from tensorflow.python import pywrap_tensorflow
|
||||
from tensorflow.python.framework import c_api_util
|
||||
from tensorflow.python import pywrap_tfe
|
||||
from tensorflow.python.eager import eager_util as c_api_util
|
||||
from tensorflow.python.util import compat
|
||||
|
||||
_MetricMethod = collections.namedtuple('MetricMethod', 'create delete get_cell')
|
||||
_counter_methods = [
|
||||
_MetricMethod(
|
||||
create=pywrap_tensorflow.TFE_MonitoringNewCounter0,
|
||||
delete=pywrap_tensorflow.TFE_MonitoringDeleteCounter0,
|
||||
get_cell=pywrap_tensorflow.TFE_MonitoringGetCellCounter0),
|
||||
create=pywrap_tfe.TFE_MonitoringNewCounter0,
|
||||
delete=pywrap_tfe.TFE_MonitoringDeleteCounter0,
|
||||
get_cell=pywrap_tfe.TFE_MonitoringGetCellCounter0),
|
||||
_MetricMethod(
|
||||
create=pywrap_tensorflow.TFE_MonitoringNewCounter1,
|
||||
delete=pywrap_tensorflow.TFE_MonitoringDeleteCounter1,
|
||||
get_cell=pywrap_tensorflow.TFE_MonitoringGetCellCounter1),
|
||||
create=pywrap_tfe.TFE_MonitoringNewCounter1,
|
||||
delete=pywrap_tfe.TFE_MonitoringDeleteCounter1,
|
||||
get_cell=pywrap_tfe.TFE_MonitoringGetCellCounter1),
|
||||
_MetricMethod(
|
||||
create=pywrap_tensorflow.TFE_MonitoringNewCounter2,
|
||||
delete=pywrap_tensorflow.TFE_MonitoringDeleteCounter2,
|
||||
get_cell=pywrap_tensorflow.TFE_MonitoringGetCellCounter2),
|
||||
create=pywrap_tfe.TFE_MonitoringNewCounter2,
|
||||
delete=pywrap_tfe.TFE_MonitoringDeleteCounter2,
|
||||
get_cell=pywrap_tfe.TFE_MonitoringGetCellCounter2),
|
||||
]
|
||||
_int_gauge_methods = [
|
||||
_MetricMethod(
|
||||
create=pywrap_tensorflow.TFE_MonitoringNewIntGauge0,
|
||||
delete=pywrap_tensorflow.TFE_MonitoringDeleteIntGauge0,
|
||||
get_cell=pywrap_tensorflow.TFE_MonitoringGetCellIntGauge0),
|
||||
create=pywrap_tfe.TFE_MonitoringNewIntGauge0,
|
||||
delete=pywrap_tfe.TFE_MonitoringDeleteIntGauge0,
|
||||
get_cell=pywrap_tfe.TFE_MonitoringGetCellIntGauge0),
|
||||
_MetricMethod(
|
||||
create=pywrap_tensorflow.TFE_MonitoringNewIntGauge1,
|
||||
delete=pywrap_tensorflow.TFE_MonitoringDeleteIntGauge1,
|
||||
get_cell=pywrap_tensorflow.TFE_MonitoringGetCellIntGauge1),
|
||||
create=pywrap_tfe.TFE_MonitoringNewIntGauge1,
|
||||
delete=pywrap_tfe.TFE_MonitoringDeleteIntGauge1,
|
||||
get_cell=pywrap_tfe.TFE_MonitoringGetCellIntGauge1),
|
||||
_MetricMethod(
|
||||
create=pywrap_tensorflow.TFE_MonitoringNewIntGauge2,
|
||||
delete=pywrap_tensorflow.TFE_MonitoringDeleteIntGauge2,
|
||||
get_cell=pywrap_tensorflow.TFE_MonitoringGetCellIntGauge2),
|
||||
create=pywrap_tfe.TFE_MonitoringNewIntGauge2,
|
||||
delete=pywrap_tfe.TFE_MonitoringDeleteIntGauge2,
|
||||
get_cell=pywrap_tfe.TFE_MonitoringGetCellIntGauge2),
|
||||
]
|
||||
_string_gauge_methods = [
|
||||
_MetricMethod(
|
||||
create=pywrap_tensorflow.TFE_MonitoringNewStringGauge0,
|
||||
delete=pywrap_tensorflow.TFE_MonitoringDeleteStringGauge0,
|
||||
get_cell=pywrap_tensorflow.TFE_MonitoringGetCellStringGauge0),
|
||||
create=pywrap_tfe.TFE_MonitoringNewStringGauge0,
|
||||
delete=pywrap_tfe.TFE_MonitoringDeleteStringGauge0,
|
||||
get_cell=pywrap_tfe.TFE_MonitoringGetCellStringGauge0),
|
||||
_MetricMethod(
|
||||
create=pywrap_tensorflow.TFE_MonitoringNewStringGauge1,
|
||||
delete=pywrap_tensorflow.TFE_MonitoringDeleteStringGauge1,
|
||||
get_cell=pywrap_tensorflow.TFE_MonitoringGetCellStringGauge1),
|
||||
create=pywrap_tfe.TFE_MonitoringNewStringGauge1,
|
||||
delete=pywrap_tfe.TFE_MonitoringDeleteStringGauge1,
|
||||
get_cell=pywrap_tfe.TFE_MonitoringGetCellStringGauge1),
|
||||
_MetricMethod(
|
||||
create=pywrap_tensorflow.TFE_MonitoringNewStringGauge2,
|
||||
delete=pywrap_tensorflow.TFE_MonitoringDeleteStringGauge2,
|
||||
get_cell=pywrap_tensorflow.TFE_MonitoringGetCellStringGauge2),
|
||||
create=pywrap_tfe.TFE_MonitoringNewStringGauge2,
|
||||
delete=pywrap_tfe.TFE_MonitoringDeleteStringGauge2,
|
||||
get_cell=pywrap_tfe.TFE_MonitoringGetCellStringGauge2),
|
||||
]
|
||||
_bool_gauge_methods = [
|
||||
_MetricMethod(
|
||||
create=pywrap_tensorflow.TFE_MonitoringNewBoolGauge0,
|
||||
delete=pywrap_tensorflow.TFE_MonitoringDeleteBoolGauge0,
|
||||
get_cell=pywrap_tensorflow.TFE_MonitoringGetCellBoolGauge0),
|
||||
create=pywrap_tfe.TFE_MonitoringNewBoolGauge0,
|
||||
delete=pywrap_tfe.TFE_MonitoringDeleteBoolGauge0,
|
||||
get_cell=pywrap_tfe.TFE_MonitoringGetCellBoolGauge0),
|
||||
_MetricMethod(
|
||||
create=pywrap_tensorflow.TFE_MonitoringNewBoolGauge1,
|
||||
delete=pywrap_tensorflow.TFE_MonitoringDeleteBoolGauge1,
|
||||
get_cell=pywrap_tensorflow.TFE_MonitoringGetCellBoolGauge1),
|
||||
create=pywrap_tfe.TFE_MonitoringNewBoolGauge1,
|
||||
delete=pywrap_tfe.TFE_MonitoringDeleteBoolGauge1,
|
||||
get_cell=pywrap_tfe.TFE_MonitoringGetCellBoolGauge1),
|
||||
_MetricMethod(
|
||||
create=pywrap_tensorflow.TFE_MonitoringNewBoolGauge2,
|
||||
delete=pywrap_tensorflow.TFE_MonitoringDeleteBoolGauge2,
|
||||
get_cell=pywrap_tensorflow.TFE_MonitoringGetCellBoolGauge2),
|
||||
create=pywrap_tfe.TFE_MonitoringNewBoolGauge2,
|
||||
delete=pywrap_tfe.TFE_MonitoringDeleteBoolGauge2,
|
||||
get_cell=pywrap_tfe.TFE_MonitoringGetCellBoolGauge2),
|
||||
]
|
||||
_sampler_methods = [
|
||||
_MetricMethod(
|
||||
create=pywrap_tensorflow.TFE_MonitoringNewSampler0,
|
||||
delete=pywrap_tensorflow.TFE_MonitoringDeleteSampler0,
|
||||
get_cell=pywrap_tensorflow.TFE_MonitoringGetCellSampler0),
|
||||
create=pywrap_tfe.TFE_MonitoringNewSampler0,
|
||||
delete=pywrap_tfe.TFE_MonitoringDeleteSampler0,
|
||||
get_cell=pywrap_tfe.TFE_MonitoringGetCellSampler0),
|
||||
_MetricMethod(
|
||||
create=pywrap_tensorflow.TFE_MonitoringNewSampler1,
|
||||
delete=pywrap_tensorflow.TFE_MonitoringDeleteSampler1,
|
||||
get_cell=pywrap_tensorflow.TFE_MonitoringGetCellSampler1),
|
||||
create=pywrap_tfe.TFE_MonitoringNewSampler1,
|
||||
delete=pywrap_tfe.TFE_MonitoringDeleteSampler1,
|
||||
get_cell=pywrap_tfe.TFE_MonitoringGetCellSampler1),
|
||||
_MetricMethod(
|
||||
create=pywrap_tensorflow.TFE_MonitoringNewSampler2,
|
||||
delete=pywrap_tensorflow.TFE_MonitoringDeleteSampler2,
|
||||
get_cell=pywrap_tensorflow.TFE_MonitoringGetCellSampler2),
|
||||
create=pywrap_tfe.TFE_MonitoringNewSampler2,
|
||||
delete=pywrap_tfe.TFE_MonitoringDeleteSampler2,
|
||||
get_cell=pywrap_tfe.TFE_MonitoringGetCellSampler2),
|
||||
]
|
||||
|
||||
|
||||
@ -156,11 +156,11 @@ class CounterCell(object):
|
||||
Args:
|
||||
value: non-negative value.
|
||||
"""
|
||||
pywrap_tensorflow.TFE_MonitoringCounterCellIncrementBy(self._cell, value)
|
||||
pywrap_tfe.TFE_MonitoringCounterCellIncrementBy(self._cell, value)
|
||||
|
||||
def value(self):
|
||||
"""Retrieves the current value."""
|
||||
return pywrap_tensorflow.TFE_MonitoringCounterCellValue(self._cell)
|
||||
return pywrap_tfe.TFE_MonitoringCounterCellValue(self._cell)
|
||||
|
||||
|
||||
class Counter(Metric):
|
||||
@ -204,11 +204,11 @@ class IntGaugeCell(object):
|
||||
Args:
|
||||
value: integer value.
|
||||
"""
|
||||
pywrap_tensorflow.TFE_MonitoringIntGaugeCellSet(self._cell, value)
|
||||
pywrap_tfe.TFE_MonitoringIntGaugeCellSet(self._cell, value)
|
||||
|
||||
def value(self):
|
||||
"""Retrieves the current value."""
|
||||
return pywrap_tensorflow.TFE_MonitoringIntGaugeCellValue(self._cell)
|
||||
return pywrap_tfe.TFE_MonitoringIntGaugeCellValue(self._cell)
|
||||
|
||||
|
||||
class IntGauge(Metric):
|
||||
@ -252,13 +252,13 @@ class StringGaugeCell(object):
|
||||
Args:
|
||||
value: string value.
|
||||
"""
|
||||
pywrap_tensorflow.TFE_MonitoringStringGaugeCellSet(self._cell, value)
|
||||
pywrap_tfe.TFE_MonitoringStringGaugeCellSet(self._cell, value)
|
||||
|
||||
def value(self):
|
||||
"""Retrieves the current value."""
|
||||
with c_api_util.tf_buffer() as buffer_:
|
||||
pywrap_tensorflow.TFE_MonitoringStringGaugeCellValue(self._cell, buffer_)
|
||||
value = pywrap_tensorflow.TF_GetBuffer(buffer_).decode('utf-8')
|
||||
pywrap_tfe.TFE_MonitoringStringGaugeCellValue(self._cell, buffer_)
|
||||
value = pywrap_tfe.TF_GetBuffer(buffer_).decode('utf-8')
|
||||
return value
|
||||
|
||||
|
||||
@ -303,11 +303,11 @@ class BoolGaugeCell(object):
|
||||
Args:
|
||||
value: bool value.
|
||||
"""
|
||||
pywrap_tensorflow.TFE_MonitoringBoolGaugeCellSet(self._cell, value)
|
||||
pywrap_tfe.TFE_MonitoringBoolGaugeCellSet(self._cell, value)
|
||||
|
||||
def value(self):
|
||||
"""Retrieves the current value."""
|
||||
return pywrap_tensorflow.TFE_MonitoringBoolGaugeCellValue(self._cell)
|
||||
return pywrap_tfe.TFE_MonitoringBoolGaugeCellValue(self._cell)
|
||||
|
||||
|
||||
class BoolGauge(Metric):
|
||||
@ -351,7 +351,7 @@ class SamplerCell(object):
|
||||
Args:
|
||||
value: float value.
|
||||
"""
|
||||
pywrap_tensorflow.TFE_MonitoringSamplerCellAdd(self._cell, value)
|
||||
pywrap_tfe.TFE_MonitoringSamplerCellAdd(self._cell, value)
|
||||
|
||||
def value(self):
|
||||
"""Retrieves the current distribution of samples.
|
||||
@ -360,8 +360,8 @@ class SamplerCell(object):
|
||||
A HistogramProto describing the distribution of samples.
|
||||
"""
|
||||
with c_api_util.tf_buffer() as buffer_:
|
||||
pywrap_tensorflow.TFE_MonitoringSamplerCellValue(self._cell, buffer_)
|
||||
proto_data = pywrap_tensorflow.TF_GetBuffer(buffer_)
|
||||
pywrap_tfe.TFE_MonitoringSamplerCellValue(self._cell, buffer_)
|
||||
proto_data = pywrap_tfe.TF_GetBuffer(buffer_)
|
||||
histogram_proto = summary_pb2.HistogramProto()
|
||||
histogram_proto.ParseFromString(compat.as_bytes(proto_data))
|
||||
return histogram_proto
|
||||
@ -379,7 +379,7 @@ class Buckets(object):
|
||||
self.buckets = buckets
|
||||
|
||||
def __del__(self):
|
||||
pywrap_tensorflow.TFE_MonitoringDeleteBuckets(self.buckets)
|
||||
pywrap_tfe.TFE_MonitoringDeleteBuckets(self.buckets)
|
||||
|
||||
|
||||
class ExponentialBuckets(Buckets):
|
||||
@ -399,8 +399,8 @@ class ExponentialBuckets(Buckets):
|
||||
bucket_count: integer
|
||||
"""
|
||||
super(ExponentialBuckets, self).__init__(
|
||||
pywrap_tensorflow.TFE_MonitoringNewExponentialBuckets(
|
||||
scale, growth_factor, bucket_count))
|
||||
pywrap_tfe.TFE_MonitoringNewExponentialBuckets(scale, growth_factor,
|
||||
bucket_count))
|
||||
|
||||
|
||||
class Sampler(Metric):
|
||||
|
@ -39,9 +39,9 @@ import os
|
||||
import threading
|
||||
|
||||
from tensorflow.python import _pywrap_events_writer
|
||||
from tensorflow.python import pywrap_tensorflow
|
||||
from tensorflow.python import pywrap_tfe
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import c_api_util
|
||||
from tensorflow.python.eager import eager_util as c_api_util
|
||||
from tensorflow.python.platform import gfile
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.util import compat
|
||||
@ -74,8 +74,8 @@ def start():
|
||||
raise ProfilerAlreadyRunningError('Another profiler is running.')
|
||||
if context.default_execution_mode == context.EAGER_MODE:
|
||||
context.ensure_initialized()
|
||||
_profiler = pywrap_tensorflow.TFE_NewProfiler()
|
||||
if not pywrap_tensorflow.TFE_ProfilerIsOk(_profiler):
|
||||
_profiler = pywrap_tfe.TFE_NewProfiler()
|
||||
if not pywrap_tfe.TFE_ProfilerIsOk(_profiler):
|
||||
logging.warning('Another profiler session is running which is probably '
|
||||
'created by profiler server. Please avoid using profiler '
|
||||
'server and profiler APIs at the same time.')
|
||||
@ -100,11 +100,9 @@ def stop():
|
||||
if context.default_execution_mode == context.EAGER_MODE:
|
||||
context.context().executor.wait()
|
||||
with c_api_util.tf_buffer() as buffer_:
|
||||
pywrap_tensorflow.TFE_ProfilerSerializeToString(
|
||||
_profiler,
|
||||
buffer_)
|
||||
result = pywrap_tensorflow.TF_GetBuffer(buffer_)
|
||||
pywrap_tensorflow.TFE_DeleteProfiler(_profiler)
|
||||
pywrap_tfe.TFE_ProfilerSerializeToString(_profiler, buffer_)
|
||||
result = pywrap_tfe.TF_GetBuffer(buffer_)
|
||||
pywrap_tfe.TFE_DeleteProfiler(_profiler)
|
||||
_profiler = None
|
||||
_run_num += 1
|
||||
return result
|
||||
@ -159,7 +157,7 @@ def start_profiler_server(port):
|
||||
"""
|
||||
if context.default_execution_mode == context.EAGER_MODE:
|
||||
context.ensure_initialized()
|
||||
pywrap_tensorflow.TFE_StartProfilerServer(port)
|
||||
pywrap_tfe.TFE_StartProfilerServer(port)
|
||||
|
||||
|
||||
class Profiler(object):
|
||||
|
@ -18,8 +18,8 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python import pywrap_tensorflow
|
||||
from tensorflow.python.framework import c_api_util
|
||||
from tensorflow.python import pywrap_tfe
|
||||
from tensorflow.python.eager import eager_util as c_api_util
|
||||
from tensorflow.python.framework import errors
|
||||
|
||||
|
||||
@ -46,7 +46,7 @@ def start_tracing(service_addr,
|
||||
Raises:
|
||||
UnavailableError: If no trace event is collected.
|
||||
"""
|
||||
if not pywrap_tensorflow.TFE_ProfilerClientStartTracing(
|
||||
if not pywrap_tfe.TFE_ProfilerClientStartTracing(
|
||||
service_addr, logdir, worker_list, include_dataset_ops, duration_ms,
|
||||
num_tracing_attempts):
|
||||
raise errors.UnavailableError(None, None, 'No trace event is collected.')
|
||||
@ -71,7 +71,7 @@ def monitor(service_addr,
|
||||
A string of monitoring output.
|
||||
"""
|
||||
with c_api_util.tf_buffer() as buffer_:
|
||||
pywrap_tensorflow.TFE_ProfilerClientMonitor(service_addr, duration_ms,
|
||||
monitoring_level,
|
||||
display_timestamp, buffer_)
|
||||
return pywrap_tensorflow.TF_GetBuffer(buffer_)
|
||||
pywrap_tfe.TFE_ProfilerClientMonitor(service_addr, duration_ms,
|
||||
monitoring_level, display_timestamp,
|
||||
buffer_)
|
||||
return pywrap_tfe.TF_GetBuffer(buffer_)
|
||||
|
@ -20,7 +20,7 @@ from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python import pywrap_tensorflow
|
||||
from tensorflow.python import pywrap_tfe
|
||||
from tensorflow.python.eager import backprop
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import core
|
||||
@ -54,14 +54,16 @@ class Tests(test.TestCase):
|
||||
|
||||
self.assertAllClose(
|
||||
math_ops.matmul(a_2_by_2, b_2_by_2),
|
||||
pywrap_tensorflow.TFE_Py_FastPathExecute(
|
||||
ctx._handle, ctx.device_name, "MatMul", None, None, a_2_by_2,
|
||||
b_2_by_2, "transpose_a", False, "transpose_b", False))
|
||||
pywrap_tfe.TFE_Py_FastPathExecute(ctx._handle, ctx.device_name,
|
||||
"MatMul", None, None, a_2_by_2,
|
||||
b_2_by_2, "transpose_a", False,
|
||||
"transpose_b", False))
|
||||
self.assertAllClose(
|
||||
math_ops.matmul(a_100_by_784, b_100_by_784, transpose_b=True),
|
||||
pywrap_tensorflow.TFE_Py_FastPathExecute(
|
||||
ctx._handle, ctx.device_name, "MatMul", None, None, a_100_by_784,
|
||||
b_100_by_784, "transpose_a", False, "transpose_b", True))
|
||||
pywrap_tfe.TFE_Py_FastPathExecute(ctx._handle, ctx.device_name,
|
||||
"MatMul", None, None, a_100_by_784,
|
||||
b_100_by_784, "transpose_a", False,
|
||||
"transpose_b", True))
|
||||
|
||||
@test_util.assert_no_new_tensors
|
||||
@test_util.assert_no_garbage_created
|
||||
@ -71,12 +73,14 @@ class Tests(test.TestCase):
|
||||
|
||||
a_2_by_2 = constant_op.constant(1.0, shape=[2, 2])
|
||||
m = resource_variable_ops.ResourceVariable(a_2_by_2)
|
||||
x = pywrap_tensorflow.TFE_Py_FastPathExecute(
|
||||
ctx._handle, ctx.device_name, "MatMul", None, None, m, m, "transpose_a",
|
||||
False, "transpose_b", False)
|
||||
y = pywrap_tensorflow.TFE_Py_FastPathExecute(
|
||||
ctx._handle, ctx.device_name, "MatMul", None, None, a_2_by_2, a_2_by_2,
|
||||
"transpose_a", False, "transpose_b", False)
|
||||
x = pywrap_tfe.TFE_Py_FastPathExecute(ctx._handle, ctx.device_name,
|
||||
"MatMul", None, None, m, m,
|
||||
"transpose_a", False, "transpose_b",
|
||||
False)
|
||||
y = pywrap_tfe.TFE_Py_FastPathExecute(ctx._handle, ctx.device_name,
|
||||
"MatMul", None, None, a_2_by_2,
|
||||
a_2_by_2, "transpose_a", False,
|
||||
"transpose_b", False)
|
||||
|
||||
self.assertAllEqual(x, y)
|
||||
|
||||
@ -89,9 +93,10 @@ class Tests(test.TestCase):
|
||||
with backprop.GradientTape(persistent=True) as tape:
|
||||
a_2_by_2 = constant_op.constant(1.0, shape=[2, 2])
|
||||
tape.watch(a_2_by_2)
|
||||
z = pywrap_tensorflow.TFE_Py_FastPathExecute(
|
||||
ctx._handle, ctx.device_name, "MatMul", None, None, a_2_by_2,
|
||||
a_2_by_2, "transpose_a", False, "transpose_b", False)
|
||||
z = pywrap_tfe.TFE_Py_FastPathExecute(ctx._handle, ctx.device_name,
|
||||
"MatMul", None, None, a_2_by_2,
|
||||
a_2_by_2, "transpose_a", False,
|
||||
"transpose_b", False)
|
||||
dz_dy = tape.gradient(z, [a_2_by_2])[0]
|
||||
self.assertAllEqual(dz_dy.numpy(),
|
||||
constant_op.constant(4.0, shape=[2, 2]).numpy())
|
||||
@ -106,9 +111,10 @@ class Tests(test.TestCase):
|
||||
a_2_by_2 = constant_op.constant(1.0, shape=[2, 2])
|
||||
m = resource_variable_ops.ResourceVariable(a_2_by_2)
|
||||
tape.watch(m)
|
||||
z = pywrap_tensorflow.TFE_Py_FastPathExecute(
|
||||
ctx._handle, ctx.device_name, "MatMul", None, None, m, m,
|
||||
"transpose_a", False, "transpose_b", False)
|
||||
z = pywrap_tfe.TFE_Py_FastPathExecute(ctx._handle, ctx.device_name,
|
||||
"MatMul", None, None, m, m,
|
||||
"transpose_a", False, "transpose_b",
|
||||
False)
|
||||
dz_dy = tape.gradient(z, [m])[0]
|
||||
self.assertAllEqual(dz_dy.numpy(),
|
||||
constant_op.constant(4.0, shape=[2, 2]).numpy())
|
||||
@ -125,9 +131,8 @@ class Tests(test.TestCase):
|
||||
|
||||
self.assertAllClose(
|
||||
math_ops.add_n([a_2_by_2, b_2_by_2]),
|
||||
pywrap_tensorflow.TFE_Py_FastPathExecute(ctx._handle, ctx.device_name,
|
||||
"AddN", None, None,
|
||||
[a_2_by_2, b_2_by_2]))
|
||||
pywrap_tfe.TFE_Py_FastPathExecute(ctx._handle, ctx.device_name, "AddN",
|
||||
None, None, [a_2_by_2, b_2_by_2]))
|
||||
|
||||
# Tests homogeneous list op
|
||||
@test_util.assert_no_new_tensors
|
||||
@ -142,9 +147,9 @@ class Tests(test.TestCase):
|
||||
with backprop.GradientTape(persistent=True) as tape:
|
||||
tape.watch(a_2_by_2)
|
||||
tape.watch(b_2_by_2)
|
||||
z1 = pywrap_tensorflow.TFE_Py_FastPathExecute(
|
||||
ctx._handle, ctx.device_name, "AddN", None, None,
|
||||
[a_2_by_2, b_2_by_2])
|
||||
z1 = pywrap_tfe.TFE_Py_FastPathExecute(ctx._handle, ctx.device_name,
|
||||
"AddN", None, None,
|
||||
[a_2_by_2, b_2_by_2])
|
||||
z2 = math_ops.add_n([a_2_by_2, b_2_by_2])
|
||||
dz1_dy = tape.gradient(z1, [a_2_by_2])[0]
|
||||
dz2_dy = tape.gradient(z2, [a_2_by_2])[0]
|
||||
@ -162,9 +167,9 @@ class Tests(test.TestCase):
|
||||
|
||||
self.assertAllClose(
|
||||
array_ops.identity_n([a_2_by_2, b_2_by_2]),
|
||||
pywrap_tensorflow.TFE_Py_FastPathExecute(ctx._handle, ctx.device_name,
|
||||
"IdentityN", None, None,
|
||||
[a_2_by_2, b_2_by_2]))
|
||||
pywrap_tfe.TFE_Py_FastPathExecute(ctx._handle, ctx.device_name,
|
||||
"IdentityN", None, None,
|
||||
[a_2_by_2, b_2_by_2]))
|
||||
|
||||
# Tests heterogeneous list op
|
||||
@test_util.assert_no_new_tensors
|
||||
@ -179,9 +184,9 @@ class Tests(test.TestCase):
|
||||
with backprop.GradientTape(persistent=True) as tape:
|
||||
tape.watch(a_2_by_2)
|
||||
tape.watch(b_2_by_2)
|
||||
z1 = pywrap_tensorflow.TFE_Py_FastPathExecute(
|
||||
ctx._handle, ctx.device_name, "IdentityN", None, None,
|
||||
[a_2_by_2, b_2_by_2])
|
||||
z1 = pywrap_tfe.TFE_Py_FastPathExecute(ctx._handle, ctx.device_name,
|
||||
"IdentityN", None, None,
|
||||
[a_2_by_2, b_2_by_2])
|
||||
z2 = array_ops.identity_n([a_2_by_2, b_2_by_2])
|
||||
dz1_dy = tape.gradient(z1[0], [a_2_by_2])[0]
|
||||
dz2_dy = tape.gradient(z2[0], [a_2_by_2])[0]
|
||||
@ -201,19 +206,18 @@ class Tests(test.TestCase):
|
||||
# Not enough base params
|
||||
with self.assertRaisesRegexp(ValueError,
|
||||
"at least 5 items in the input tuple"):
|
||||
pywrap_tensorflow.TFE_Py_FastPathExecute(ctx_handle, ctx.device_name,
|
||||
"Identity")
|
||||
pywrap_tfe.TFE_Py_FastPathExecute(ctx_handle, ctx.device_name, "Identity")
|
||||
|
||||
# Not enough inputs
|
||||
with self.assertRaisesRegexp(ValueError,
|
||||
"Expected to be at least 6, was 5"):
|
||||
pywrap_tensorflow.TFE_Py_FastPathExecute(ctx_handle, ctx_handle,
|
||||
"Identity", None, [])
|
||||
pywrap_tfe.TFE_Py_FastPathExecute(ctx_handle, ctx_handle, "Identity",
|
||||
None, [])
|
||||
|
||||
# Bad type
|
||||
with self.assertRaisesRegexp(TypeError, "expected a string for op_name"):
|
||||
pywrap_tensorflow.TFE_Py_FastPathExecute(ctx_handle, ctx.device_name,
|
||||
ctx_handle, None, [], a_2_by_2)
|
||||
pywrap_tfe.TFE_Py_FastPathExecute(ctx_handle, ctx.device_name, ctx_handle,
|
||||
None, [], a_2_by_2)
|
||||
|
||||
@test_util.assert_no_new_tensors
|
||||
@test_util.assert_no_garbage_created
|
||||
@ -225,9 +229,9 @@ class Tests(test.TestCase):
|
||||
|
||||
ctx_handle = ctx._handle
|
||||
with self.assertRaises(core._FallbackException):
|
||||
pywrap_tensorflow.TFE_Py_FastPathExecute(ctx_handle, ctx.device_name,
|
||||
"Split", None, None, split_dim,
|
||||
value, "num_split", -1)
|
||||
pywrap_tfe.TFE_Py_FastPathExecute(ctx_handle, ctx.device_name, "Split",
|
||||
None, None, split_dim, value,
|
||||
"num_split", -1)
|
||||
|
||||
@test_util.assert_no_new_tensors
|
||||
@test_util.assert_no_garbage_created
|
||||
@ -266,10 +270,9 @@ class Tests(test.TestCase):
|
||||
ctx = context.context()
|
||||
ctx.ensure_initialized()
|
||||
with self.assertRaises(core._FallbackException):
|
||||
pywrap_tensorflow.TFE_Py_FastPathExecute(ctx._handle, ctx.device_name,
|
||||
"MatMul", None, None, m, m,
|
||||
"transpose_a", False,
|
||||
"transpose_b", False)
|
||||
pywrap_tfe.TFE_Py_FastPathExecute(ctx._handle, ctx.device_name, "MatMul",
|
||||
None, None, m, m, "transpose_a", False,
|
||||
"transpose_b", False)
|
||||
|
||||
def testOpDefDefaultType(self):
|
||||
im = np.random.randint(
|
||||
|
@ -22,7 +22,7 @@ import copy
|
||||
from absl import logging
|
||||
|
||||
from tensorflow.core.protobuf.tensorflow_server_pb2 import ServerDef
|
||||
from tensorflow.python import pywrap_tensorflow
|
||||
from tensorflow.python import pywrap_tfe
|
||||
from tensorflow.python.distribute import device_util
|
||||
from tensorflow.python.distribute.cluster_resolver import cluster_resolver
|
||||
from tensorflow.python.eager import context
|
||||
@ -127,7 +127,7 @@ def connect_to_cluster(cluster_spec_or_resolver,
|
||||
|
||||
# Automatically add local job, if not part of the cluster spec.
|
||||
if job_name not in cluster_spec.jobs:
|
||||
local_port = pywrap_tensorflow.TF_PickUnusedPortOrDie()
|
||||
local_port = pywrap_tfe.TF_PickUnusedPortOrDie()
|
||||
job_def = cluster_def.job.add()
|
||||
job_def.name = job_name
|
||||
# TODO(fishx): Update this to make sure remote worker has valid ip address
|
||||
|
@ -20,7 +20,7 @@ from __future__ import print_function
|
||||
|
||||
import contextlib
|
||||
|
||||
from tensorflow.python import pywrap_tensorflow
|
||||
from tensorflow.python import pywrap_tfe
|
||||
from tensorflow.python.util.lazy_loader import LazyLoader
|
||||
|
||||
# There is a circular dependency between this, ops.py, and
|
||||
@ -39,24 +39,23 @@ class Tape(object):
|
||||
self._tape = tape
|
||||
|
||||
def watched_variables(self):
|
||||
return pywrap_tensorflow.TFE_Py_TapeWatchedVariables(self._tape)
|
||||
return pywrap_tfe.TFE_Py_TapeWatchedVariables(self._tape)
|
||||
|
||||
|
||||
def push_new_tape(persistent=False, watch_accessed_variables=True):
|
||||
"""Pushes a new tape onto the tape stack."""
|
||||
tape = pywrap_tensorflow.TFE_Py_TapeSetNew(persistent,
|
||||
watch_accessed_variables)
|
||||
tape = pywrap_tfe.TFE_Py_TapeSetNew(persistent, watch_accessed_variables)
|
||||
return Tape(tape)
|
||||
|
||||
|
||||
def push_tape(tape):
|
||||
"""Pushes an existing tape onto the tape stack."""
|
||||
pywrap_tensorflow.TFE_Py_TapeSetAdd(tape._tape) # pylint: disable=protected-access
|
||||
pywrap_tfe.TFE_Py_TapeSetAdd(tape._tape) # pylint: disable=protected-access
|
||||
|
||||
|
||||
def watch(tape, tensor):
|
||||
"""Marks this tensor to be watched by the given tape."""
|
||||
pywrap_tensorflow.TFE_Py_TapeWatch(tape._tape, tensor) # pylint: disable=protected-access
|
||||
pywrap_tfe.TFE_Py_TapeWatch(tape._tape, tensor) # pylint: disable=protected-access
|
||||
|
||||
|
||||
def watch_variable(tape, variable):
|
||||
@ -68,7 +67,7 @@ def watch_variable(tape, variable):
|
||||
else:
|
||||
variables = strategy.experimental_local_results(variable)
|
||||
for var in variables:
|
||||
pywrap_tensorflow.TFE_Py_TapeWatchVariable(tape._tape, var) # pylint: disable=protected-access
|
||||
pywrap_tfe.TFE_Py_TapeWatchVariable(tape._tape, var) # pylint: disable=protected-access
|
||||
|
||||
|
||||
def variable_accessed(variable):
|
||||
@ -84,7 +83,7 @@ def variable_accessed(variable):
|
||||
else:
|
||||
variables = strategy.experimental_local_results(variable)
|
||||
for var in variables:
|
||||
pywrap_tensorflow.TFE_Py_TapeVariableAccessed(var)
|
||||
pywrap_tfe.TFE_Py_TapeVariableAccessed(var)
|
||||
|
||||
|
||||
def variables_accessed(variables):
|
||||
@ -107,25 +106,25 @@ def variables_accessed(variables):
|
||||
accessed.extend(strategy.experimental_local_results(variable))
|
||||
|
||||
for var in accessed:
|
||||
pywrap_tensorflow.TFE_Py_TapeVariableAccessed(var)
|
||||
pywrap_tfe.TFE_Py_TapeVariableAccessed(var)
|
||||
|
||||
|
||||
def pop_tape(tape):
|
||||
"""Pops the given tape in the stack."""
|
||||
pywrap_tensorflow.TFE_Py_TapeSetRemove(tape._tape) # pylint: disable=protected-access
|
||||
pywrap_tfe.TFE_Py_TapeSetRemove(tape._tape) # pylint: disable=protected-access
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def stop_recording():
|
||||
"""Stop all gradient recording (backprop and forwardprop)."""
|
||||
is_stopped = pywrap_tensorflow.TFE_Py_TapeSetIsStopped()
|
||||
is_stopped = pywrap_tfe.TFE_Py_TapeSetIsStopped()
|
||||
try:
|
||||
if not is_stopped:
|
||||
pywrap_tensorflow.TFE_Py_TapeSetStopOnThread()
|
||||
pywrap_tfe.TFE_Py_TapeSetStopOnThread()
|
||||
yield
|
||||
finally:
|
||||
if not is_stopped:
|
||||
pywrap_tensorflow.TFE_Py_TapeSetRestartOnThread()
|
||||
pywrap_tfe.TFE_Py_TapeSetRestartOnThread()
|
||||
|
||||
|
||||
def should_record_backprop(tensors):
|
||||
@ -139,22 +138,23 @@ def should_record_backprop(tensors):
|
||||
Returns:
|
||||
Boolean, whether any tape watches any of `tensors`.
|
||||
"""
|
||||
return pywrap_tensorflow.TFE_Py_TapeSetShouldRecordBackprop(tensors)
|
||||
return pywrap_tfe.TFE_Py_TapeSetShouldRecordBackprop(tensors)
|
||||
|
||||
|
||||
def record_operation(op_type, output_tensors, input_tensors, backward_function,
|
||||
forward_function=None):
|
||||
"""Records the operation on all tapes in the stack."""
|
||||
pywrap_tensorflow.TFE_Py_TapeSetRecordOperation(
|
||||
op_type, output_tensors, input_tensors, backward_function,
|
||||
forward_function)
|
||||
pywrap_tfe.TFE_Py_TapeSetRecordOperation(op_type, output_tensors,
|
||||
input_tensors, backward_function,
|
||||
forward_function)
|
||||
|
||||
|
||||
def record_operation_backprop_only(op_type, output_tensors, input_tensors,
|
||||
backward_function):
|
||||
"""Records the operation on all backward tapes in the stack."""
|
||||
pywrap_tensorflow.TFE_Py_TapeSetRecordOperationBackprop(
|
||||
op_type, output_tensors, input_tensors, backward_function)
|
||||
pywrap_tfe.TFE_Py_TapeSetRecordOperationBackprop(op_type, output_tensors,
|
||||
input_tensors,
|
||||
backward_function)
|
||||
|
||||
|
||||
def record_operation_forwardprop_only(op_type, output_tensors, input_tensors,
|
||||
@ -174,16 +174,16 @@ def record_operation_forwardprop_only(op_type, output_tensors, input_tensors,
|
||||
Typically these will have come from TFE_Py_PackForwardGradients. May be
|
||||
None or an empty sequence if there are no JVP outputs from the operation.
|
||||
"""
|
||||
pywrap_tensorflow.TFE_Py_TapeSetRecordOperationForwardprop(
|
||||
pywrap_tfe.TFE_Py_TapeSetRecordOperationForwardprop(
|
||||
op_type, output_tensors, input_tensors, backward_function,
|
||||
forwardprop_output_indices)
|
||||
|
||||
|
||||
def delete_trace(tensor_id):
|
||||
"""Deletes traces for this Tensor from all tapes in the stack."""
|
||||
pywrap_tensorflow.TFE_Py_TapeSetDeleteTrace(tensor_id)
|
||||
pywrap_tfe.TFE_Py_TapeSetDeleteTrace(tensor_id)
|
||||
|
||||
|
||||
def could_possibly_record():
|
||||
"""Returns True if any tape is active."""
|
||||
return not pywrap_tensorflow.TFE_Py_TapeSetIsEmpty()
|
||||
return not pywrap_tfe.TFE_Py_TapeSetIsEmpty()
|
||||
|
@ -26,7 +26,7 @@ import unittest
|
||||
import numpy as np
|
||||
import six
|
||||
|
||||
from tensorflow.python import pywrap_tensorflow
|
||||
from tensorflow.python import pywrap_tfe
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import core
|
||||
from tensorflow.python.eager import test
|
||||
@ -435,14 +435,14 @@ class TFETensorUtilTest(test_util.TensorFlowTestCase):
|
||||
t2 = _create_tensor([[1, 2, 5], [3, 4, 5]], dtype=dtypes.int32)
|
||||
t3 = _create_tensor([[1], [3], [5], [6]], dtype=dtypes.int32)
|
||||
|
||||
r = pywrap_tensorflow.TFE_Py_TensorShapeSlice([t1, t2, t3], 0)
|
||||
r = pywrap_tfe.TFE_Py_TensorShapeSlice([t1, t2, t3], 0)
|
||||
self.assertAllEqual(np.array([3, 2, 4]), r.numpy())
|
||||
|
||||
r = pywrap_tensorflow.TFE_Py_TensorShapeSlice([t1, t2, t3], 1)
|
||||
r = pywrap_tfe.TFE_Py_TensorShapeSlice([t1, t2, t3], 1)
|
||||
self.assertAllEqual(np.array([2, 3, 1]), r.numpy())
|
||||
|
||||
def testEmptyTensorList(self):
|
||||
a = pywrap_tensorflow.TFE_Py_TensorShapeSlice([], 0)
|
||||
a = pywrap_tfe.TFE_Py_TensorShapeSlice([], 0)
|
||||
self.assertTrue(isinstance(a, ops.EagerTensor))
|
||||
self.assertEqual(0, a.numpy().size)
|
||||
|
||||
@ -452,12 +452,12 @@ class TFETensorUtilTest(test_util.TensorFlowTestCase):
|
||||
with self.assertRaisesRegexp(
|
||||
TypeError,
|
||||
r"Expected a list of EagerTensors but element 1 has type \"str\""):
|
||||
pywrap_tensorflow.TFE_Py_TensorShapeSlice([t1, "abc"], 0)
|
||||
pywrap_tfe.TFE_Py_TensorShapeSlice([t1, "abc"], 0)
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
TypeError,
|
||||
r"Expected a list of EagerTensors but element 0 has type \"int\""):
|
||||
pywrap_tensorflow.TFE_Py_TensorShapeSlice([2, t1], 0)
|
||||
pywrap_tfe.TFE_Py_TensorShapeSlice([2, t1], 0)
|
||||
|
||||
def testTensorListNotList(self):
|
||||
t1 = _create_tensor([1, 2], dtype=dtypes.int32)
|
||||
@ -465,7 +465,7 @@ class TFETensorUtilTest(test_util.TensorFlowTestCase):
|
||||
with self.assertRaisesRegexp(
|
||||
TypeError,
|
||||
r"tensors argument must be a list or a tuple. Got.*EagerTensor"):
|
||||
pywrap_tensorflow.TFE_Py_TensorShapeSlice(t1, -2)
|
||||
pywrap_tfe.TFE_Py_TensorShapeSlice(t1, -2)
|
||||
|
||||
def testNegativeSliceDim(self):
|
||||
t1 = _create_tensor([1, 2], dtype=dtypes.int32)
|
||||
@ -473,7 +473,7 @@ class TFETensorUtilTest(test_util.TensorFlowTestCase):
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError,
|
||||
r"Slice dimension must be non-negative. Got -2"):
|
||||
pywrap_tensorflow.TFE_Py_TensorShapeSlice([t1], -2)
|
||||
pywrap_tfe.TFE_Py_TensorShapeSlice([t1], -2)
|
||||
|
||||
def testUnicode(self):
|
||||
self.assertEqual(constant_op.constant(u"asdf").numpy(), b"asdf")
|
||||
@ -493,31 +493,31 @@ class TFETensorUtilTest(test_util.TensorFlowTestCase):
|
||||
IndexError,
|
||||
r"Slice dimension \(2\) must be smaller than rank of all tensors, "
|
||||
"but tensor at index 0 has rank 2"):
|
||||
pywrap_tensorflow.TFE_Py_TensorShapeSlice([t1], 2)
|
||||
pywrap_tfe.TFE_Py_TensorShapeSlice([t1], 2)
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
IndexError,
|
||||
r"Slice dimension \(1\) must be smaller than rank of all tensors, "
|
||||
"but tensor at index 0 has rank 1"):
|
||||
pywrap_tensorflow.TFE_Py_TensorShapeSlice([t2], 1)
|
||||
pywrap_tfe.TFE_Py_TensorShapeSlice([t2], 1)
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
IndexError,
|
||||
r"Slice dimension \(1\) must be smaller than rank of all tensors, "
|
||||
"but tensor at index 1 has rank 1"):
|
||||
pywrap_tensorflow.TFE_Py_TensorShapeSlice([t1, t2], 1)
|
||||
pywrap_tfe.TFE_Py_TensorShapeSlice([t1, t2], 1)
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
IndexError,
|
||||
r"Slice dimension \(0\) must be smaller than rank of all tensors, "
|
||||
"but tensor at index 0 has rank 0"):
|
||||
pywrap_tensorflow.TFE_Py_TensorShapeSlice([t3], 0)
|
||||
pywrap_tfe.TFE_Py_TensorShapeSlice([t3], 0)
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
IndexError,
|
||||
r"Slice dimension \(0\) must be smaller than rank of all tensors, "
|
||||
"but tensor at index 2 has rank 0"):
|
||||
pywrap_tensorflow.TFE_Py_TensorShapeSlice([t2, t1, t3], 0)
|
||||
pywrap_tfe.TFE_Py_TensorShapeSlice([t2, t1, t3], 0)
|
||||
|
||||
@test_util.assert_no_new_pyobjects_executing_eagerly
|
||||
def testTensorDir(self):
|
||||
|
@ -36,7 +36,12 @@ from tensorflow.core.framework import node_def_pb2
|
||||
from tensorflow.core.framework import op_def_pb2
|
||||
from tensorflow.core.framework import versions_pb2
|
||||
from tensorflow.core.protobuf import config_pb2
|
||||
# pywrap_tensorflow must be imported first to avoid profobuf issues.
|
||||
# (b/143110113)
|
||||
# pylint: disable=invalid-import-order,g-bad-import-order
|
||||
from tensorflow.python import pywrap_tensorflow as c_api
|
||||
from tensorflow.python import pywrap_tfe as c_api_new
|
||||
# pylint: enable=invalid-import-order,g-bad-import-order
|
||||
from tensorflow.python import tf2
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import core
|
||||
@ -249,7 +254,7 @@ def register_dense_tensor_like_type(tensor_type):
|
||||
|
||||
def uid():
|
||||
"""A unique (within this program execution) integer."""
|
||||
return c_api.TFE_Py_UID()
|
||||
return c_api_new.TFE_Py_UID()
|
||||
|
||||
|
||||
def numpy_text(tensor, is_repr=False):
|
||||
@ -1135,7 +1140,7 @@ class _EagerTensorBase(Tensor):
|
||||
|
||||
# This call creates an EagerTensor class, as a subclass of _EagerTensorBase, and
|
||||
# registers it with the current module.
|
||||
EagerTensor = c_api.TFE_Py_InitEagerTensor(_EagerTensorBase)
|
||||
EagerTensor = c_api_new.TFE_Py_InitEagerTensor(_EagerTensorBase)
|
||||
|
||||
|
||||
register_dense_tensor_like_type(Tensor)
|
||||
|
@ -789,8 +789,7 @@ void GenEagerPythonOp::AddEagerFastPathExecute() {
|
||||
|
||||
strings::StrAppend(&result_, " try:\n");
|
||||
strings::StrAppend(
|
||||
&result_, " ",
|
||||
"_result = _pywrap_tensorflow.TFE_Py_FastPathExecute(\n",
|
||||
&result_, " ", "_result = pywrap_tfe.TFE_Py_FastPathExecute(\n",
|
||||
WordWrap(strings::StrCat(" "),
|
||||
strings::StrCat(fastpath_execute_params, ")"), kRightMargin),
|
||||
"\n");
|
||||
@ -1000,7 +999,7 @@ This file is MACHINE GENERATED! Do not edit.
|
||||
|
||||
import collections
|
||||
|
||||
from tensorflow.python import pywrap_tensorflow as _pywrap_tensorflow
|
||||
from tensorflow.python import pywrap_tfe as pywrap_tfe
|
||||
from tensorflow.python.eager import context as _context
|
||||
from tensorflow.python.eager import core as _core
|
||||
from tensorflow.python.eager import execute as _execute
|
||||
|
@ -19,6 +19,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python import pywrap_tensorflow
|
||||
from tensorflow.python import pywrap_tfe
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
@ -115,8 +116,8 @@ def _ConcatGradHelper(op, grad, start_value_index, end_value_index, dim_index):
|
||||
non_neg_concat_dim = (
|
||||
concat_dim._numpy().item(0) % input_values[0]._rank()) # pylint: disable=protected-access
|
||||
# All inputs are guaranteed to be EagerTensors in eager mode
|
||||
sizes = pywrap_tensorflow.TFE_Py_TensorShapeSlice(input_values,
|
||||
non_neg_concat_dim)
|
||||
sizes = pywrap_tfe.TFE_Py_TensorShapeSlice(input_values,
|
||||
non_neg_concat_dim)
|
||||
out_grads = array_ops.split(grad, sizes, non_neg_concat_dim)
|
||||
else:
|
||||
if constant_op.is_constant(concat_dim):
|
||||
|
@ -26,7 +26,7 @@ import sys
|
||||
from absl import logging
|
||||
import six
|
||||
|
||||
from tensorflow.python import pywrap_tensorflow
|
||||
from tensorflow.python import pywrap_tfe
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
@ -45,7 +45,7 @@ from tensorflow.python.util.tf_export import tf_export
|
||||
# Register printing to the cell output if we are in a Colab or Jupyter Notebook.
|
||||
try:
|
||||
get_ipython() # Exists in an ipython env like Jupyter or Colab
|
||||
pywrap_tensorflow.TFE_Py_EnableInteractivePythonLogging()
|
||||
pywrap_tfe.TFE_Py_EnableInteractivePythonLogging()
|
||||
except NameError:
|
||||
pass
|
||||
|
||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
||||
#include <vector>
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/c/tf_datatype.h"
|
||||
#include "tensorflow/python/lib/core/py_exception_registry.h"
|
||||
|
||||
using tensorflow::uint64;
|
||||
@ -233,7 +234,50 @@ _COPY_TYPEMAPS(unsigned int, mode_t);
|
||||
%define override %enddef
|
||||
#endif
|
||||
|
||||
|
||||
// This was originally included in pywrap_tfe.i, but is used by tf_session.i
|
||||
%include "tensorflow/c/tf_status.h"
|
||||
%include "tensorflow/c/tf_datatype.h"
|
||||
|
||||
%typemap(in) (const void* proto) {
|
||||
char* c_string;
|
||||
Py_ssize_t py_size;
|
||||
// PyBytes_AsStringAndSize() does not copy but simply interprets the input
|
||||
if (PyBytes_AsStringAndSize($input, &c_string, &py_size) == -1) {
|
||||
// Python has raised an error (likely TypeError or UnicodeEncodeError).
|
||||
SWIG_fail;
|
||||
}
|
||||
$1 = static_cast<void*>(c_string);
|
||||
}
|
||||
|
||||
%typemap(in) int64_t {
|
||||
$1 = PyLong_AsLongLong($input);
|
||||
}
|
||||
|
||||
%typemap(out) TF_DataType {
|
||||
$result = PyInt_FromLong($1);
|
||||
}
|
||||
|
||||
%typemap(out) int64_t {
|
||||
$result = PyInt_FromLong($1);
|
||||
}
|
||||
|
||||
%typemap(out) TF_AttrType {
|
||||
$result = PyInt_FromLong($1);
|
||||
}
|
||||
|
||||
%typemap(in, numinputs=0) unsigned char* is_list (unsigned char tmp) {
|
||||
tmp = 0;
|
||||
$1 = &tmp;
|
||||
}
|
||||
|
||||
%typemap(argout) unsigned char* is_list {
|
||||
if (*$1 == 1) {
|
||||
PyObject* list = PyList_New(1);
|
||||
PyList_SetItem(list, 0, $result);
|
||||
$result = list;
|
||||
}
|
||||
}
|
||||
|
||||
// Typemaps to automatically raise a Python exception from bad output TF_Status.
|
||||
// TODO(b/77295559): expand this to all TF_Status* output params and deprecate
|
||||
|
@ -1,515 +0,0 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
%include "tensorflow/python/lib/core/strings.i"
|
||||
%include "tensorflow/python/platform/base.i"
|
||||
|
||||
%{
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/python/lib/core/ndarray_tensor.h"
|
||||
#include "tensorflow/python/lib/core/safe_ptr.h"
|
||||
%}
|
||||
|
||||
|
||||
%include "tensorflow/c/tf_datatype.h"
|
||||
%include "tensorflow/c/tf_status.h"
|
||||
|
||||
%ignoreall;
|
||||
|
||||
%rename("%s") TF_SetXlaEnableLazyCompilation;
|
||||
%rename("%s") TF_SetTfXlaCpuGlobalJit;
|
||||
%rename("%s") TF_SetXlaAutoJitMode;
|
||||
%rename("%s") TF_SetXlaConstantFoldingDisabled;
|
||||
%rename("%s") TF_GetXlaConstantFoldingDisabled;
|
||||
%rename("%s") TF_SetXlaMinClusterSize;
|
||||
%rename("%s") TFE_NewContext;
|
||||
%rename("%s") TFE_DeleteContext;
|
||||
%rename("%s") TFE_ContextListDevices;
|
||||
%rename("%s") TFE_ContextAddFunction;
|
||||
%rename("%s") TFE_ContextAddFunctionDef;
|
||||
%rename("%s") TFE_ContextRemoveFunction;
|
||||
%rename("%s") TFE_ContextHasFunction;
|
||||
%rename("%s") TFE_ContextEnableRunMetadata;
|
||||
%rename("%s") TFE_ContextDisableRunMetadata;
|
||||
%rename("%s") TFE_ContextEnableGraphCollection;
|
||||
%rename("%s") TFE_ContextDisableGraphCollection;
|
||||
%rename("%s") TFE_ContextExportRunMetadata;
|
||||
%rename("%s") TFE_ContextClearCaches;
|
||||
%rename("%s") TFE_ContextGetDevicePlacementPolicy;
|
||||
%rename("%s") TFE_ContextGetMirroringPolicy;
|
||||
%rename("%s") TFE_ContextSetThreadLocalDevicePlacementPolicy;
|
||||
%rename("%s") TFE_ContextSetThreadLocalMirroringPolicy;
|
||||
%rename("%s") TFE_ContextSetServerDef;
|
||||
%rename("%s") TFE_ContextUpdateServerDef;
|
||||
%rename("%s") TFE_ContextCheckAlive;
|
||||
%rename("%s") TFE_NewExecutor;
|
||||
%rename("%s") TFE_DeleteExecutor;
|
||||
%rename("%s") TFE_ExecutorIsAsync;
|
||||
%rename("%s") TFE_ExecutorWaitForAllPendingNodes;
|
||||
%rename("%s") TFE_ExecutorClearError;
|
||||
%rename("%s") TFE_ContextSetExecutorForThread;
|
||||
%rename("%s") TFE_ContextGetExecutorForThread;
|
||||
%rename("%s") TFE_NewProfiler;
|
||||
%rename("%s") TFE_ProfilerIsOk;
|
||||
%rename("%s") TFE_DeleteProfiler;
|
||||
%rename("%s") TFE_ProfilerSerializeToString;
|
||||
%rename("%s") TFE_StartProfilerServer;
|
||||
%rename("%s") TFE_ProfilerClientStartTracing;
|
||||
%rename("%s") TFE_ProfilerClientMonitor;
|
||||
%rename("%s") TFE_OpNameGetAttrType;
|
||||
%rename("%s") TFE_Py_InitEagerTensor;
|
||||
%rename("%s") TFE_Py_SetEagerTensorProfiler;
|
||||
%rename("%s") TFE_Py_RegisterExceptionClass;
|
||||
%rename("%s") TFE_Py_RegisterJVPFunction;
|
||||
%rename("%s") TFE_Py_RegisterGradientFunction;
|
||||
%rename("%s") TFE_Py_RegisterFallbackExceptionClass;
|
||||
%rename("%s") TFE_Py_Execute;
|
||||
%rename("%s") TFE_Py_ExecuteCancelable;
|
||||
%rename("%s") TFE_Py_FastPathExecute;
|
||||
%rename("%s") TFE_Py_RecordGradient;
|
||||
%rename("%s") TFE_Py_UID;
|
||||
%rename("%s") TFE_Py_TapeSetNew;
|
||||
%rename("%s") TFE_Py_TapeSetAdd;
|
||||
%rename("%s") TFE_Py_TapeSetRemove;
|
||||
%rename("%s") TFE_Py_TapeSetStopOnThread;
|
||||
%rename("%s") TFE_Py_TapeSetRestartOnThread;
|
||||
%rename("%s") TFE_Py_TapeSetIsStopped;
|
||||
%rename("%s") TFE_Py_TapeSetIsEmpty;
|
||||
%rename("%s") TFE_Py_TapeSetShouldRecordBackprop;
|
||||
%rename("%s") TFE_Py_TapeSetPossibleGradientTypes;
|
||||
%rename("%s") TFE_Py_TapeSetDeleteTrace;
|
||||
%rename("%s") TFE_Py_TapeSetRecordOperation;
|
||||
%rename("%s") TFE_Py_TapeSetRecordOperationBackprop;
|
||||
%rename("%s") TFE_Py_TapeSetRecordOperationForwardprop;
|
||||
%rename("%s") TFE_Py_TapeGradient;
|
||||
%rename("%s") TFE_Py_TapeVariableAccessed;
|
||||
%rename("%s") TFE_Py_TapeWatch;
|
||||
%rename("%s") TFE_Py_TapeWatchVariable;
|
||||
%rename("%s") TFE_Py_TapeWatchedVariables;
|
||||
%rename("%s") TFE_Py_ForwardAccumulatorNew;
|
||||
%rename("%s") TFE_Py_ForwardAccumulatorSetAdd;
|
||||
%rename("%s") TFE_Py_ForwardAccumulatorSetRemove;
|
||||
%rename("%s") TFE_Py_ForwardAccumulatorWatch;
|
||||
%rename("%s") TFE_Py_ForwardAccumulatorJVP;
|
||||
%rename("%s") TFE_Py_ForwardAccumulatorPushState;
|
||||
%rename("%s") TFE_Py_ForwardAccumulatorPopState;
|
||||
%rename("%s") TFE_Py_PackJVPs;
|
||||
%rename("%s") TFE_NewContextOptions;
|
||||
%rename("%s") TFE_ContextOptionsSetConfig;
|
||||
%rename("%s") TFE_ContextOptionsSetDevicePlacementPolicy;
|
||||
%rename("%s") TFE_ContextOptionsSetMirroringPolicy;
|
||||
%rename("%s") TFE_ContextOptionsSetAsync;
|
||||
%rename("%s") TFE_ContextOptionsSetLazyRemoteInputsCopy;
|
||||
%rename("%s") TFE_DeleteContextOptions;
|
||||
%rename("%s") TFE_Py_TensorShapeSlice;
|
||||
%rename("%s") TFE_Py_TensorShapeOnDevice;
|
||||
%rename("%s") TFE_Py_EnableInteractivePythonLogging;
|
||||
%rename("%s") TFE_Py_SetEagerContext;
|
||||
%rename("%s") TFE_ContextStartStep;
|
||||
%rename("%s") TFE_ContextEndStep;
|
||||
%rename("%s") TFE_Py_RegisterVSpace;
|
||||
%rename("%s") TFE_Py_EncodeArg;
|
||||
%rename("%s") TFE_EnableCollectiveOps;
|
||||
%rename("%s") TF_ListPhysicalDevices;
|
||||
%rename("%s") TF_PickUnusedPortOrDie;
|
||||
%rename("%s") TFE_MonitoringCounterCellIncrementBy;
|
||||
%rename("%s") TFE_MonitoringCounterCellValue;
|
||||
%rename("%s") TFE_MonitoringNewCounter0;
|
||||
%rename("%s") TFE_MonitoringDeleteCounter0;
|
||||
%rename("%s") TFE_MonitoringGetCellCounter0;
|
||||
%rename("%s") TFE_MonitoringNewCounter1;
|
||||
%rename("%s") TFE_MonitoringDeleteCounter1;
|
||||
%rename("%s") TFE_MonitoringGetCellCounter1;
|
||||
%rename("%s") TFE_MonitoringNewCounter2;
|
||||
%rename("%s") TFE_MonitoringDeleteCounter2;
|
||||
%rename("%s") TFE_MonitoringGetCellCounter2;
|
||||
%rename("%s") TFE_MonitoringIntGaugeCellSet;
|
||||
%rename("%s") TFE_MonitoringIntGaugeCellValue;
|
||||
%rename("%s") TFE_MonitoringNewIntGauge0;
|
||||
%rename("%s") TFE_MonitoringDeleteIntGauge0;
|
||||
%rename("%s") TFE_MonitoringGetCellIntGauge0;
|
||||
%rename("%s") TFE_MonitoringNewIntGauge1;
|
||||
%rename("%s") TFE_MonitoringDeleteIntGauge1;
|
||||
%rename("%s") TFE_MonitoringGetCellIntGauge1;
|
||||
%rename("%s") TFE_MonitoringNewIntGauge2;
|
||||
%rename("%s") TFE_MonitoringDeleteIntGauge2;
|
||||
%rename("%s") TFE_MonitoringGetCellIntGauge2;
|
||||
%rename("%s") TFE_MonitoringStringGaugeCellSet;
|
||||
%rename("%s") TFE_MonitoringStringGaugeCellValue;
|
||||
%rename("%s") TFE_MonitoringNewStringGauge0;
|
||||
%rename("%s") TFE_MonitoringDeleteStringGauge0;
|
||||
%rename("%s") TFE_MonitoringGetCellStringGauge0;
|
||||
%rename("%s") TFE_MonitoringNewStringGauge1;
|
||||
%rename("%s") TFE_MonitoringDeleteStringGauge1;
|
||||
%rename("%s") TFE_MonitoringGetCellStringGauge1;
|
||||
%rename("%s") TFE_MonitoringNewStringGauge2;
|
||||
%rename("%s") TFE_MonitoringDeleteStringGauge2;
|
||||
%rename("%s") TFE_MonitoringGetCellStringGauge2;
|
||||
%rename("%s") TFE_MonitoringBoolGaugeCellSet;
|
||||
%rename("%s") TFE_MonitoringBoolGaugeCellValue;
|
||||
%rename("%s") TFE_MonitoringNewBoolGauge0;
|
||||
%rename("%s") TFE_MonitoringDeleteBoolGauge0;
|
||||
%rename("%s") TFE_MonitoringGetCellBoolGauge0;
|
||||
%rename("%s") TFE_MonitoringNewBoolGauge1;
|
||||
%rename("%s") TFE_MonitoringDeleteBoolGauge1;
|
||||
%rename("%s") TFE_MonitoringGetCellBoolGauge1;
|
||||
%rename("%s") TFE_MonitoringNewBoolGauge2;
|
||||
%rename("%s") TFE_MonitoringDeleteBoolGauge2;
|
||||
%rename("%s") TFE_MonitoringGetCellBoolGauge2;
|
||||
%rename("%s") TFE_MonitoringSamplerCellAdd;
|
||||
%rename("%s") TFE_MonitoringSamplerCellValue;
|
||||
%rename("%s") TFE_MonitoringNewExponentialBuckets;
|
||||
%rename("%s") TFE_MonitoringDeleteBuckets;
|
||||
%rename("%s") TFE_MonitoringNewSampler0;
|
||||
%rename("%s") TFE_MonitoringDeleteSampler0;
|
||||
%rename("%s") TFE_MonitoringGetCellSampler0;
|
||||
%rename("%s") TFE_MonitoringNewSampler1;
|
||||
%rename("%s") TFE_MonitoringDeleteSampler1;
|
||||
%rename("%s") TFE_MonitoringGetCellSampler1;
|
||||
%rename("%s") TFE_MonitoringNewSampler2;
|
||||
%rename("%s") TFE_MonitoringDeleteSampler2;
|
||||
%rename("%s") TFE_MonitoringGetCellSampler2;
|
||||
%rename("%s") TFE_NewCancellationManager;
|
||||
%rename("%s") TFE_CancellationManagerIsCancelled;
|
||||
%rename("%s") TFE_CancellationManagerStartCancel;
|
||||
%rename("%s") TFE_DeleteCancellationManager;
|
||||
%rename("%s") TF_ImportGraphDefOptionsSetValidateColocationConstraints;
|
||||
%rename("%s") TFE_ClearScalarCache;
|
||||
|
||||
%{
|
||||
#include "tensorflow/python/eager/pywrap_tfe.h"
|
||||
#include "tensorflow/python/util/util.h"
|
||||
#include "tensorflow/c/c_api_experimental.h"
|
||||
#include "tensorflow/c/tf_status_helper.h"
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
#include "tensorflow/core/common_runtime/device_factory.h"
|
||||
|
||||
static PyObject* TF_ListPhysicalDevices(TF_Status* status) {
|
||||
std::vector<string> devices;
|
||||
tensorflow::Status s = tensorflow::DeviceFactory::ListAllPhysicalDevices(&devices);
|
||||
tensorflow::Set_TF_Status_from_Status(status, s);
|
||||
if (!s.ok()) {
|
||||
Py_RETURN_NONE;
|
||||
};
|
||||
PyObject* result = PyList_New(devices.size());
|
||||
int i = 0;
|
||||
for (auto& dev : devices) {
|
||||
PyObject* dev_obj = PyBytes_FromStringAndSize(dev.data(), dev.size());
|
||||
PyList_SetItem(result, i, dev_obj);
|
||||
++i;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
%}
|
||||
static PyObject* TF_ListPhysicalDevices(TF_Status* status);
|
||||
|
||||
%{
|
||||
#include "tensorflow/python/eager/pywrap_tensor_conversion.h"
|
||||
|
||||
static PyObject* TFE_ClearScalarCache() {
|
||||
tensorflow::TFE_TensorHandleCache::Get()->Clear();
|
||||
Py_RETURN_NONE;
|
||||
}
|
||||
%}
|
||||
static PyObject* TFE_ClearScalarCache();
|
||||
|
||||
%typemap(in) (const void* proto) {
|
||||
char* c_string;
|
||||
Py_ssize_t py_size;
|
||||
// PyBytes_AsStringAndSize() does not copy but simply interprets the input
|
||||
if (PyBytes_AsStringAndSize($input, &c_string, &py_size) == -1) {
|
||||
// Python has raised an error (likely TypeError or UnicodeEncodeError).
|
||||
SWIG_fail;
|
||||
}
|
||||
$1 = static_cast<void*>(c_string);
|
||||
}
|
||||
|
||||
%typemap(in) int64_t {
|
||||
$1 = PyLong_AsLongLong($input);
|
||||
}
|
||||
|
||||
%typemap(out) TF_DataType {
|
||||
$result = PyInt_FromLong($1);
|
||||
}
|
||||
|
||||
%typemap(out) int64_t {
|
||||
$result = PyInt_FromLong($1);
|
||||
}
|
||||
|
||||
%typemap(out) TF_AttrType {
|
||||
$result = PyInt_FromLong($1);
|
||||
}
|
||||
|
||||
%typemap(in, numinputs=0) unsigned char* is_list (unsigned char tmp) {
|
||||
tmp = 0;
|
||||
$1 = &tmp;
|
||||
}
|
||||
|
||||
%typemap(argout) unsigned char* is_list {
|
||||
if (*$1 == 1) {
|
||||
PyObject* list = PyList_New(1);
|
||||
PyList_SetItem(list, 0, $result);
|
||||
$result = list;
|
||||
}
|
||||
}
|
||||
|
||||
// For const parameters in a function, SWIG pretty much ignores the const.
|
||||
// See: http://www.swig.org/Doc2.0/SWIG.html#SWIG_nn13
|
||||
// Hence the 'const_cast'.
|
||||
%typemap(in) const char* serialized_function_def {
|
||||
$1 = const_cast<char*>(TFE_GetPythonString($input));
|
||||
}
|
||||
|
||||
// For const parameters in a function, SWIG pretty much ignores the const.
|
||||
// See: http://www.swig.org/Doc2.0/SWIG.html#SWIG_nn13
|
||||
// Hence the 'const_cast'.
|
||||
%typemap(in) const char* device_name {
|
||||
if ($input == Py_None) {
|
||||
$1 = nullptr;
|
||||
} else {
|
||||
$1 = const_cast<char*>(TFE_GetPythonString($input));
|
||||
}
|
||||
}
|
||||
|
||||
// For const parameters in a function, SWIG pretty much ignores the const.
|
||||
// See: http://www.swig.org/Doc2.0/SWIG.html#SWIG_nn13
|
||||
// Hence the 'const_cast'.
|
||||
%typemap(in) const char* op_name {
|
||||
$1 = const_cast<char*>(TFE_GetPythonString($input));
|
||||
}
|
||||
|
||||
// For const parameters in a function, SWIG pretty much ignores the const.
|
||||
// See: http://www.swig.org/Doc2.0/SWIG.html#SWIG_nn13
|
||||
// Hence the 'const_cast'.
|
||||
%typemap(in) const char* name {
|
||||
$1 = const_cast<char*>(TFE_GetPythonString($input));
|
||||
}
|
||||
|
||||
|
||||
// For const parameters in a function, SWIG pretty much ignores the const.
|
||||
// See: http://www.swig.org/Doc2.0/SWIG.html#SWIG_nn13
|
||||
// Hence the 'const_cast'.
|
||||
%typemap(in) const char* description {
|
||||
$1 = const_cast<char*>(TFE_GetPythonString($input));
|
||||
}
|
||||
|
||||
// For const parameters in a function, SWIG pretty much ignores the const.
|
||||
// See: http://www.swig.org/Doc2.0/SWIG.html#SWIG_nn13
|
||||
// Hence the 'const_cast'.
|
||||
%typemap(in) const char* label {
|
||||
$1 = const_cast<char*>(TFE_GetPythonString($input));
|
||||
}
|
||||
|
||||
// For const parameters in a function, SWIG pretty much ignores the const.
|
||||
// See: http://www.swig.org/Doc2.0/SWIG.html#SWIG_nn13
|
||||
// Hence the 'const_cast'.
|
||||
%typemap(in) const char* label1 {
|
||||
$1 = const_cast<char*>(TFE_GetPythonString($input));
|
||||
}
|
||||
|
||||
// For const parameters in a function, SWIG pretty much ignores the const.
|
||||
// See: http://www.swig.org/Doc2.0/SWIG.html#SWIG_nn13
|
||||
// Hence the 'const_cast'.
|
||||
%typemap(in) const char* label2 {
|
||||
$1 = const_cast<char*>(TFE_GetPythonString($input));
|
||||
}
|
||||
|
||||
// For const parameters in a function, SWIG pretty much ignores the const.
|
||||
// See: http://www.swig.org/Doc2.0/SWIG.html#SWIG_nn13
|
||||
// Hence the 'const_cast'.
|
||||
%typemap(in) const char* service_addr {
|
||||
$1 = const_cast<char*>(TFE_GetPythonString($input));
|
||||
}
|
||||
|
||||
// For const parameters in a function, SWIG pretty much ignores the const.
|
||||
// See: http://www.swig.org/Doc2.0/SWIG.html#SWIG_nn13
|
||||
// Hence the 'const_cast'.
|
||||
%typemap(in) const char* logdir {
|
||||
$1 = const_cast<char*>(TFE_GetPythonString($input));
|
||||
}
|
||||
|
||||
// For const parameters in a function, SWIG pretty much ignores the const.
|
||||
// See: http://www.swig.org/Doc2.0/SWIG.html#SWIG_nn13
|
||||
// Hence the 'const_cast'.
|
||||
%typemap(in) const char* worker_list {
|
||||
$1 = const_cast<char*>(TFE_GetPythonString($input));
|
||||
}
|
||||
|
||||
%typemap(in) (TFE_Context*) {
|
||||
$1 = (TFE_Context*)PyCapsule_GetPointer($input, nullptr);
|
||||
|
||||
}
|
||||
%typemap(out) (TFE_Context*) {
|
||||
// When the TFE_Context* returned is a nullptr, we expect the status is not
|
||||
// OK. This will raise an error (happens in another typemap).
|
||||
if ($1 != nullptr) {
|
||||
$result = PyCapsule_New($1, nullptr, TFE_DeleteContextCapsule);
|
||||
}
|
||||
}
|
||||
|
||||
%rename("%s") TFE_ContextDevicePlacementPolicy;
|
||||
%rename("%s") TFE_DEVICE_PLACEMENT_EXPLICIT;
|
||||
%rename("%s") TFE_DEVICE_PLACEMENT_WARN;
|
||||
%rename("%s") TFE_DEVICE_PLACEMENT_SILENT;
|
||||
%rename("%s") TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32;
|
||||
|
||||
%rename("%s") TFE_ContextMirroringPolicy;
|
||||
%rename("%s") TFE_MIRRORING_NONE;
|
||||
%rename("%s") TFE_MIRRORING_ALL;
|
||||
|
||||
%include "tensorflow/c/eager/c_api.h"
|
||||
|
||||
%typemap(in) TFE_InputTensorHandles* inputs (TFE_InputTensorHandles temp) {
|
||||
$1 = &temp;
|
||||
if ($input != Py_None) {
|
||||
if (!PyList_Check($input)) {
|
||||
SWIG_exception_fail(SWIG_TypeError,
|
||||
"must provide a list of Tensors as inputs");
|
||||
}
|
||||
Py_ssize_t len = PyList_Size($input);
|
||||
$1->resize(len);
|
||||
for (Py_ssize_t i = 0; i < len; ++i) {
|
||||
PyObject* elem = PyList_GetItem($input, i);
|
||||
if (!elem) {
|
||||
SWIG_fail;
|
||||
}
|
||||
if (EagerTensor_CheckExact(elem)) {
|
||||
(*$1)[i] = EagerTensor_Handle(elem);
|
||||
} else if (tensorflow::swig::IsEagerTensorSlow(elem)) {
|
||||
// Use equivalent of object.__getattribute__ to get the underlying
|
||||
// tf wrapped EagerTensor (if there is one).
|
||||
tensorflow::Safe_PyObjectPtr tf_should_use_attr(
|
||||
#if PY_MAJOR_VERSION < 3
|
||||
PyString_InternFromString("_tf_should_use_wrapped_value")
|
||||
#else
|
||||
PyUnicode_InternFromString("_tf_should_use_wrapped_value")
|
||||
#endif
|
||||
);
|
||||
tensorflow::Safe_PyObjectPtr value_attr(
|
||||
PyObject_GenericGetAttr(elem, tf_should_use_attr.get()));
|
||||
if (value_attr) {
|
||||
// This is an EagerTensor wrapped inside a TFShouldUse wrapped object.
|
||||
(*$1)[i] = EagerTensor_Handle(value_attr.get());
|
||||
} else {
|
||||
// This is a subclass of EagerTensor that we don't support.
|
||||
PyErr_Clear();
|
||||
SWIG_exception_fail(
|
||||
SWIG_TypeError,
|
||||
tensorflow::strings::StrCat(
|
||||
"Saw an object that is an instance of a strict subclass of "
|
||||
"EagerTensor, which is not supported. Item ",
|
||||
i, " is type: ", elem->ob_type->tp_name)
|
||||
.c_str());
|
||||
}
|
||||
} else if (tensorflow::swig::IsTensor(elem)) {
|
||||
// If it isnt an EagerTensor, but is still a Tensor, it must be a graph
|
||||
// tensor.
|
||||
tensorflow::Safe_PyObjectPtr name_attr(
|
||||
PyObject_GetAttrString(elem, "name"));
|
||||
SWIG_exception_fail(
|
||||
SWIG_TypeError,
|
||||
tensorflow::strings::StrCat(
|
||||
"An op outside of the function building code is being passed\n"
|
||||
"a \"Graph\" tensor. It is possible to have Graph tensors\n"
|
||||
"leak out of the function building context by including a\n"
|
||||
"tf.init_scope in your function building code.\n"
|
||||
"For example, the following function will fail:\n",
|
||||
" @tf.function\n",
|
||||
" def has_init_scope():\n",
|
||||
" my_constant = tf.constant(1.)\n",
|
||||
" with tf.init_scope():\n",
|
||||
" added = my_constant * 2\n",
|
||||
"The graph tensor has name: ",
|
||||
name_attr ? TFE_GetPythonString(name_attr.get()) : "<unknown>"
|
||||
).c_str());
|
||||
} else {
|
||||
SWIG_exception_fail(
|
||||
SWIG_TypeError,
|
||||
tensorflow::strings::StrCat(
|
||||
"provided list of inputs contains objects other "
|
||||
"than 'EagerTensor'. Item ",
|
||||
i, " is type: ", elem->ob_type->tp_name).c_str());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Temporary for the argout
|
||||
%typemap(in) TFE_OutputTensorHandles* outputs (TFE_OutputTensorHandles temp) {
|
||||
if (!PyInt_Check($input)) {
|
||||
SWIG_exception_fail(SWIG_TypeError,
|
||||
"expected an integer value (size of the number of "
|
||||
"outputs of the operation)");
|
||||
}
|
||||
$1 = &temp;
|
||||
long sz = PyInt_AsLong($input);
|
||||
if (sz > 0) {
|
||||
$1->resize(PyInt_AsLong($input), nullptr);
|
||||
}
|
||||
}
|
||||
|
||||
// Create new Status object.
|
||||
%typemap(in, numinputs=0) TF_Status *out_status {
|
||||
$1 = GetStatus();
|
||||
}
|
||||
|
||||
%typemap(freearg) (TF_Status* out_status) {
|
||||
ReturnStatus($1);
|
||||
}
|
||||
|
||||
%typemap(argout) (TFE_OutputTensorHandles* outputs, TF_Status* out_status) {
|
||||
if (MaybeRaiseExceptionFromTFStatus($2, nullptr)) {
|
||||
SWIG_fail;
|
||||
} else {
|
||||
int num_outputs = $1->size();
|
||||
Py_CLEAR($result);
|
||||
$result = PyList_New(num_outputs);
|
||||
for (int i = 0; i < num_outputs; ++i) {
|
||||
PyObject *output;
|
||||
output = EagerTensorFromHandle($1->at(i));
|
||||
PyList_SetItem($result, i, output);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// SWIG usually unwraps the tuple that the native Python/C interface generates.
|
||||
// Since we wanted to have a function with a variable length of arguments, we
|
||||
// used the native Python/C interface directly (which by default supports
|
||||
// passing all arguments as a tuple).
|
||||
%native(TFE_Py_FastPathExecute) TFE_Py_FastPathExecute_C;
|
||||
|
||||
%include "tensorflow/python/eager/pywrap_tfe.h"
|
||||
%include "tensorflow/c/c_api_experimental.h"
|
||||
%include "tensorflow/c/eager/c_api_experimental.h"
|
||||
|
||||
// Clear all typemaps.
|
||||
%typemap(out) TF_DataType;
|
||||
%typemap(in) int64_t;
|
||||
%typemap(out) int64_t;
|
||||
%typemap(out) TF_AttrType;
|
||||
%typemap(in, numinputs=0) TF_Status *out_status;
|
||||
%typemap(argout) unsigned char* is_list;
|
||||
%typemap(in) const char* description;
|
||||
%typemap(in) const char* label1;
|
||||
%typemap(in) const char* label2;
|
||||
%typemap(in) (TFE_Context*);
|
||||
%typemap(out) (TFE_Context*);
|
||||
%typemap(in) TFE_OutputTensorHandles* outputs (TFE_OutputTensorHandles temp);
|
||||
%typemap(in, numinputs=0) TF_Status *out_status;
|
||||
%typemap(freearg) (TF_Status* out_status);
|
||||
%typemap(argout) (TFE_OutputTensorHandles* outputs, TF_Status* out_status);
|
||||
%typemap(in) (const void* proto);
|
||||
|
||||
%unignoreall
|
29
tensorflow/python/pywrap_tfe.py
Normal file
29
tensorflow/python/pywrap_tfe.py
Normal file
@ -0,0 +1,29 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Python module for TFE ops and functions exported by pybind11.
|
||||
|
||||
This module is created because we are splitting out eager bindings from
|
||||
pywrap_tensorflow. This is causing some issues where Graphs are not properly
|
||||
initialized when running eager code. Once the graph architecture has been
|
||||
removed from pywrap_tensorflow as well, we can remove this file.
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
# pylint: disable=invalid-import-order,g-bad-import-order, wildcard-import, unused-import
|
||||
from tensorflow.python import pywrap_tensorflow
|
||||
from tensorflow.python._pywrap_tfe import *
|
@ -17,8 +17,6 @@ limitations under the License.
|
||||
* The includes are intentionally not alphabetically sorted, as the order of
|
||||
* includes follows dependency order */
|
||||
|
||||
%include "tensorflow/python/pywrap_tfe.i"
|
||||
|
||||
%include "tensorflow/python/client/tf_session.i"
|
||||
|
||||
%include "tensorflow/python/lib/io/file_io.i"
|
||||
|
1099
tensorflow/python/tfe_wrapper.cc
Normal file
1099
tensorflow/python/tfe_wrapper.cc
Normal file
File diff suppressed because it is too large
Load Diff
@ -3,6 +3,7 @@
|
||||
*perftools*gputools*
|
||||
*tf_*
|
||||
*TF_*
|
||||
*Eager*
|
||||
*TFE_*
|
||||
*nsync_*
|
||||
*stream_executor*
|
||||
|
@ -4,6 +4,7 @@ tensorflow {
|
||||
*toco*;
|
||||
*perftools*gputools*;
|
||||
*TF_*;
|
||||
*Eager*;
|
||||
*TFE_*;
|
||||
*nsync_*;
|
||||
*stream_executor*;
|
||||
|
@ -1,4 +1,4 @@
|
||||
[cpp_python_util] # util
|
||||
[cpp_python_util] # util tfe
|
||||
tensorflow::swig::IsSequence
|
||||
tensorflow::swig::IsSequenceOrComposite
|
||||
tensorflow::swig::IsCompositeTensor
|
||||
@ -17,6 +17,7 @@ tensorflow::swig::IsSequenceForData
|
||||
tensorflow::swig::FlattenForData
|
||||
tensorflow::swig::AssertSameStructureForData
|
||||
tensorflow::swig::RegisterType
|
||||
tensorflow::swig::IsEagerTensorSlow
|
||||
|
||||
[util_port] # util_port
|
||||
tensorflow::IsGoogleCudaEnabled
|
||||
@ -74,11 +75,12 @@ tensorflow::Status::code
|
||||
tensorflow::Status::error_message
|
||||
tensorflow::Status::ok()
|
||||
|
||||
[core_cpu_impl] # device_lib
|
||||
[core_cpu_impl] # device_lib tfe
|
||||
tensorflow::Device::attributes
|
||||
tensorflow::DeviceFactory::AddDevices
|
||||
tensorflow::SessionOptions::SessionOptions
|
||||
tensorflow::DoQuantizeTrainingOnSerializedGraphDef
|
||||
tensorflow::DeviceFactory::ListAllPhysicalDevices
|
||||
|
||||
[protos_all] # device_lib, dtypes
|
||||
tensorflow::DataType_IsValid
|
||||
@ -123,3 +125,67 @@ tensorflow::make_safe
|
||||
|
||||
[python_op_gen] # python_op_gen
|
||||
tensorflow::GetPythonWrappers
|
||||
|
||||
[pywrap_tfe_lib] # tfe
|
||||
tensorflow::TFE_TensorHandleCache
|
||||
tensorflow::TFE_TensorHandleCache::Clear
|
||||
EagerTensor_CheckExact
|
||||
EagerTensorFromHandle
|
||||
EagerTensor_Handle
|
||||
TFE_Py_ExecuteCancelable
|
||||
TFE_Py_RegisterExceptionClass
|
||||
TFE_Py_RegisterVSpace
|
||||
TFE_Py_RegisterFallbackExceptionClass
|
||||
TFE_Py_RegisterGradientFunction
|
||||
TFE_Py_RegisterJVPFunction
|
||||
TFE_GetPythonString
|
||||
TFE_Py_UID
|
||||
TFE_DeleteContextCapsule
|
||||
TFE_Py_InitEagerTensor
|
||||
TFE_Py_SetEagerTensorProfiler
|
||||
TFE_Py_TapeSetNew
|
||||
TFE_Py_TapeSetRemove
|
||||
TFE_Py_TapeSetAdd
|
||||
TFE_Py_TapeSetIsEmpty
|
||||
TFE_Py_TapeSetShouldRecordBackprop
|
||||
TFE_Py_TapeSetPossibleGradientTypes
|
||||
TFE_Py_TapeWatch
|
||||
TFE_Py_TapeSetDeleteTrace
|
||||
TFE_Py_TapeSetStopOnThread
|
||||
TFE_Py_TapeSetRestartOnThread
|
||||
TFE_Py_TapeSetIsStopped
|
||||
TFE_Py_TapeSetRecordOperation
|
||||
TFE_Py_TapeSetRecordOperationBackprop
|
||||
TFE_Py_TapeSetRecordOperationForwardprop
|
||||
TFE_Py_TapeVariableAccessed
|
||||
TFE_Py_TapeWatchVariable
|
||||
TFE_Py_TapeGradient
|
||||
TFE_Py_FastPathExecute_C
|
||||
TFE_Py_RecordGradient
|
||||
TFE_Py_TapeWatchedVariables
|
||||
TFE_Py_ForwardAccumulatorNew
|
||||
TFE_Py_ForwardAccumulatorSetAdd
|
||||
TFE_Py_ForwardAccumulatorSetRemove
|
||||
TFE_Py_ForwardAccumulatorWatch
|
||||
TFE_Py_ForwardAccumulatorJVP
|
||||
TFE_Py_ForwardAccumulatorPushState
|
||||
TFE_Py_ForwardAccumulatorPopState
|
||||
TFE_Py_PackJVPs
|
||||
TFE_Py_TensorShapeSlice
|
||||
TFE_Py_TensorShapeOnDevice
|
||||
TFE_Py_EncodeArg
|
||||
TFE_Py_EnableInteractivePythonLogging
|
||||
TFE_Py_SetEagerContext
|
||||
|
||||
[eager_executor] # tfe
|
||||
tensorflow::EagerExecutor::~EagerExecutor
|
||||
tensorflow::EagerContext::WaitForAndCloseRemoteContexts
|
||||
|
||||
[profiler_session] # tfe
|
||||
tensorflow::ProfilerSession::~ProfilerSession
|
||||
|
||||
[tf_status_helper] # tfe
|
||||
tensorflow::Set_TF_Status_from_Status
|
||||
|
||||
[context] # tfe
|
||||
tensorflow::EagerContext::WaitForAndCloseRemoteContexts
|
||||
|
Loading…
Reference in New Issue
Block a user