Export the Eager classes and functions from C++ to Python with pybind11 instead of swig. This is part of a larger effort to deprecate swig and eventually with modularization break pywrap_tensorflow into smaller components. It will also make exporting C++ ops to Python significantly easier. XLA is using the pybind11 macros already. Please refer to https://github.com/tensorflow/community/blob/master/rfcs/20190208-pybind11.md for more information.

PiperOrigin-RevId: 286110711
Change-Id: I7bf6f6f4ce1d6bf3e8a3e40ef4a83f82333800f6
This commit is contained in:
Amit Patankar 2019-12-17 19:37:26 -08:00 committed by TensorFlower Gardener
parent 03341c4342
commit 7bd345bcbb
47 changed files with 1866 additions and 853 deletions

View File

@ -53,6 +53,20 @@ filegroup(
visibility = ["//visibility:public"],
)
filegroup(
name = "pywrap_eager_hdrs",
srcs = [
"c_api_internal.h",
"tf_status_helper.h",
"tf_status_internal.h",
"tf_tensor_internal.h",
],
visibility = [
"//tensorflow/core:__pkg__",
"//tensorflow/python:__pkg__",
],
)
tf_cuda_library(
name = "c_api_internal",
hdrs = [

View File

@ -88,6 +88,18 @@ tf_cuda_library(
alwayslink = 1,
)
filegroup(
name = "pywrap_eager_hdrs",
srcs = [
"c_api_experimental.h",
"c_api_internal.h",
],
visibility = [
"//tensorflow/core:__pkg__",
"//tensorflow/python:__pkg__",
],
)
tf_cuda_library(
name = "c_api_internal",
srcs = ["c_api_experimental.h"],

View File

@ -439,6 +439,23 @@ tf_cc_test(
],
)
filegroup(
name = "pywrap_eager_hdrs",
srcs = [
"attr_builder.h",
"context.h",
"eager_executor.h",
"eager_operation.h",
"kernel_and_device.h",
"tensor_handle.h",
"tensor_handle_data.h",
],
visibility = [
"//tensorflow/core:__pkg__",
"//tensorflow/python:__pkg__",
],
)
filegroup(
name = "srcs",
srcs = glob(

View File

@ -783,3 +783,20 @@ tf_cc_test(
"//tensorflow/core:worker_proto_cc",
],
)
filegroup(
name = "pywrap_eager_hdrs",
srcs = [
"call_options.h",
"message_wrappers.h",
"rendezvous_mgr_interface.h",
"server_lib.h",
"worker_cache.h",
"worker_env.h",
"worker_interface.h",
],
visibility = [
"//tensorflow/core:__pkg__",
"//tensorflow/python:__pkg__",
],
)

View File

@ -216,3 +216,16 @@ cc_library(
"@com_google_absl//absl/types:optional",
],
)
filegroup(
name = "pywrap_eager_hdrs",
srcs = [
"eager_client.h",
"remote_tensor_handle.h",
"remote_tensor_handle_data.h",
],
visibility = [
"//tensorflow/core:__pkg__",
"//tensorflow/python:__pkg__",
],
)

View File

@ -853,6 +853,18 @@ tf_cc_tests(
],
)
filegroup(
name = "pywrap_eager_hdrs",
srcs = [
"op_gen_lib.h",
"rendezvous.h",
],
visibility = [
"//tensorflow/core:__pkg__",
"//tensorflow/python:__pkg__",
],
)
# All framewrok protos are self-contained, i.e. they only import other
# protos from the same package, so we can build the protos here and then
# link them from core:protos_all without circular dependencies.

View File

@ -523,3 +523,14 @@ tf_cc_test(
"//tensorflow/core:testlib",
],
)
filegroup(
name = "pywrap_eager_hdrs",
srcs = [
"profiler_interface.h",
],
visibility = [
"//tensorflow/core:__pkg__",
"//tensorflow/python:__pkg__",
],
)

View File

@ -43,6 +43,17 @@ tf_cuda_library(
alwayslink = True,
)
filegroup(
name = "pywrap_eager_hdrs",
srcs = [
"profiler_session.h",
],
visibility = [
"//tensorflow/core:__pkg__",
"//tensorflow/python:__pkg__",
],
)
cc_library(
name = "traceme",
hdrs = ["traceme.h"],

View File

@ -171,6 +171,7 @@ py_library(
":platform",
":proto_ops",
":pywrap_tensorflow",
":pywrap_tfe",
":rnn_ops_gen",
":saver_test_utils",
":script_ops",
@ -251,6 +252,7 @@ py_library(
deps = [
":_pywrap_util_port",
":lib",
":pywrap_tfe",
":util",
"//tensorflow/core:protos_all_py",
"@absl_py//absl:app",
@ -477,13 +479,13 @@ cc_library(
cc_library(
name = "pybind11_status",
hdrs = [
"lib/core/py_exception_registry.h",
"lib/core/pybind11_status.h",
"//tensorflow/c:headers",
],
features = ["-parse_headers"],
visibility = tf_external_workspace_visible(visibility),
deps = [
":py_exception_registry",
"//tensorflow/c:tf_status_headers",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
@ -1110,6 +1112,7 @@ py_library(
":lib",
":platform",
":pywrap_tensorflow",
":pywrap_tfe",
":random_seed",
":sparse_tensor",
":tensor_spec",
@ -5492,7 +5495,6 @@ tf_py_wrap_cc(
"lib/io/py_record_reader.i",
"lib/io/py_record_writer.i",
"platform/base.i",
"pywrap_tfe.i",
"//tensorflow/compiler/mlir/python:mlir.i",
],
# add win_def_file for pywrap_tensorflow
@ -5573,7 +5575,12 @@ WIN_LIB_FILES_FOR_EXPORTED_SYMBOLS = [
":safe_ptr", # checkpoint_reader
":python_op_gen", # python_op_gen
":bfloat16_lib", # bfloat16
"//tensorflow/python/eager:pywrap_tfe_lib", # pywrap_tfe_lib
"//tensorflow/core/util/tensor_bundle", # checkpoint_reader
"//tensorflow/core/common_runtime/eager:eager_executor", # tfe
"//tensorflow/core/common_runtime/eager:context", # tfe
"//tensorflow/core/profiler/lib:profiler_session", # tfe
"//tensorflow/c:tf_status_helper", # tfe
]
# Filter the DEF file to reduce the number of symbols to 64K or less.
@ -7555,6 +7562,67 @@ py_library(
],
)
py_library(
name = "pywrap_tfe",
srcs = ["pywrap_tfe.py"],
visibility = ["//visibility:public"],
deps = [
":_pywrap_tfe",
":pywrap_tensorflow",
],
)
tf_python_pybind_extension(
name = "_pywrap_tfe",
srcs = ["tfe_wrapper.cc"],
hdrs = [
"lib/core/safe_ptr.h",
"util/util.h",
":py_exception_registry_hdr",
"//tensorflow/c:headers",
"//tensorflow/c:pywrap_eager_hdrs",
"//tensorflow/c/eager:headers",
"//tensorflow/c/eager:pywrap_eager_hdrs",
"//tensorflow/core/common_runtime/eager:pywrap_eager_hdrs",
"//tensorflow/core/distributed_runtime:pywrap_eager_hdrs",
"//tensorflow/core/distributed_runtime/eager:pywrap_eager_hdrs",
"//tensorflow/core/framework:pywrap_eager_hdrs",
"//tensorflow/core/profiler/internal:pywrap_eager_hdrs",
"//tensorflow/core/profiler/lib:pywrap_eager_hdrs",
"//tensorflow/python/eager:pywrap_eager_hdrs",
],
module_name = "_pywrap_tfe",
deps = [
":pybind11_lib",
":pybind11_status",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/hash",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
"@pybind11",
"//third_party/python_runtime:headers",
"//tensorflow/core:core_cpu_headers_lib",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/platform:platform",
"//tensorflow/core/profiler/protobuf:xplane_proto_cc",
] + if_static(
extra_deps = [
"//tensorflow/core:eager_service_proto_cc",
"//tensorflow/core:master_proto_cc",
"//tensorflow/core:worker_proto_cc",
],
otherwise = [
"//tensorflow/core:eager_service_proto_cc_headers_only",
"//tensorflow/core:master_proto_cc_headers_only",
"//tensorflow/core:worker_proto_cc_headers_only",
],
),
)
tf_python_pybind_extension(
name = "_pywrap_graph_analyzer",
srcs = ["grappler/graph_analyzer_tool_wrapper.cc"],

View File

@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
%include "tensorflow/python/lib/core/strings.i"
%include "tensorflow/python/platform/base.i"
%{
@ -23,6 +24,13 @@ limitations under the License.
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/public/version.h"
#include "tensorflow/python/client/tf_session_helper.h"
#include "tensorflow/c/c_api_experimental.h"
#include "tensorflow/python/lib/core/safe_ptr.h"
#include "tensorflow/python/eager/pywrap_tfe.h"
// We were getting lucky on imports with safe_ptr.h being placed prior to
// tf_session which imported safe_ptr. We also need pywrap_tfe.h to cast
// one of the inputs to a graph function from a Python string to const char*.
// Helper function to convert a Python list of Tensors to a C++ vector of
// TF_Outputs.
@ -78,6 +86,9 @@ void PyInt64ListToVector(PyObject* py_int_seq, std::vector<int64_t>* vec) {
%}
%include "tensorflow/c/tf_datatype.h"
%include "tensorflow/c/tf_status.h"
%include "tensorflow/python/client/tf_sessionrun_wrapper.i"
// Required to use PyArray_* functions.
@ -85,6 +96,14 @@ void PyInt64ListToVector(PyObject* py_int_seq, std::vector<int64_t>* vec) {
tensorflow::ImportNumpy();
%}
// For const parameters in a function, SWIG pretty much ignores the const.
// See: http://www.swig.org/Doc2.0/SWIG.html#SWIG_nn13
// Hence the 'const_cast'.
%typemap(in) const char* op_name {
$1 = const_cast<char*>(TFE_GetPythonString($input));
}
// TensorFlow version and GraphDef versions
%constant const char* __version__ = TF_VERSION_STRING;
%constant int GRAPH_DEF_VERSION = TF_GRAPH_DEF_VERSION;
@ -174,6 +193,12 @@ tensorflow::ImportNumpy();
// See comment for "%noexception TF_SessionRun_wrapper;"
%noexception TF_OperationGetControlInputs_wrapper;
// Migrate one function from pywrap_tfe.i
%include "tensorflow/c/c_api_experimental.h"
%unignore TF_ImportGraphDefOptionsSetValidateColocationConstraints;
%noexception TF_ImportGraphDefOptionsSetValidateColocationConstraints;
// Build a Python list of TF_Operation* and return it.
%typemap(out) std::vector<TF_Operation*> tensorflow::TF_OperationGetControlInputs_wrapper {
$result = PyList_New($1.size());

View File

@ -268,7 +268,7 @@ py_library(
"//tensorflow/python:device",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:pywrap_tensorflow",
"//tensorflow/python:pywrap_tfe",
"//tensorflow/python:summary_ops_v2",
"//tensorflow/python:tensor_util",
"//tensorflow/python:training",

View File

@ -24,7 +24,7 @@ import functools
import threading
import weakref
from tensorflow.python import pywrap_tensorflow
from tensorflow.python import pywrap_tfe
from tensorflow.python.autograph.core import ag_ctx
from tensorflow.python.autograph.impl import api as autograph
from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
@ -944,7 +944,7 @@ class _MirroredReplicaThread(threading.Thread):
self.record_thread_local_summary_state()
self.record_thread_local_eager_context_state()
self.context_device_policy = (
pywrap_tensorflow.TFE_ContextGetDevicePlacementPolicy(
pywrap_tfe.TFE_ContextGetDevicePlacementPolicy(
ctx._context_handle)) # pylint: disable=protected-access
self.graph = ops.get_default_graph()
with ops.init_scope():

View File

@ -56,6 +56,18 @@ cc_library(
],
)
filegroup(
name = "pywrap_eager_hdrs",
srcs = [
"pywrap_tensor_conversion.h",
"pywrap_tfe.h",
],
visibility = [
"//tensorflow/core:__pkg__",
"//tensorflow/python:__pkg__",
],
)
# Transitive dependencies of this target will be included in the pip package.
py_library(
name = "eager_pip",
@ -90,7 +102,7 @@ py_library(
deps = [
":context",
"//tensorflow/python:errors",
"//tensorflow/python:pywrap_tensorflow",
"//tensorflow/python:pywrap_tfe",
],
)
@ -100,7 +112,7 @@ py_library(
srcs_version = "PY2AND3",
visibility = ["//tensorflow:internal"],
deps = [
"//tensorflow/python:pywrap_tensorflow",
"//tensorflow/python:pywrap_tfe",
],
)
@ -121,7 +133,7 @@ py_library(
srcs_version = "PY2AND3",
visibility = ["//tensorflow:internal"],
deps = [
"//tensorflow/python:pywrap_tensorflow",
"//tensorflow/python:pywrap_tfe",
],
)
@ -131,13 +143,14 @@ py_library(
srcs_version = "PY2AND3",
visibility = ["//tensorflow:internal"],
deps = [
":eager_util",
":executor",
":monitoring",
"//tensorflow/python:device",
"//tensorflow/python:device_spec",
"//tensorflow/python:errors",
"//tensorflow/python:platform",
"//tensorflow/python:pywrap_tensorflow",
"//tensorflow/python:pywrap_tfe",
"//tensorflow/python:tf2",
"//tensorflow/python:util",
"//third_party/py/numpy",
@ -164,8 +177,8 @@ py_library(
"//third_party/py/tf_agents:__subpackages__",
],
deps = [
"//tensorflow/python:c_api_util",
"//tensorflow/python:pywrap_tensorflow",
":eager_util",
"//tensorflow/python:pywrap_tfe",
"//tensorflow/python:util",
],
)
@ -187,7 +200,8 @@ py_library(
visibility = ["//tensorflow:internal"],
deps = [
":context",
"//tensorflow/python:pywrap_tensorflow",
":eager_util",
"//tensorflow/python:pywrap_tfe",
"//tensorflow/python:util",
],
)
@ -209,7 +223,8 @@ py_library(
srcs_version = "PY2AND3",
visibility = ["//tensorflow:internal"],
deps = [
"//tensorflow/python:pywrap_tensorflow",
":eager_util",
"//tensorflow/python:pywrap_tfe",
],
)
@ -298,7 +313,7 @@ cuda_py_test(
"//tensorflow/python:errors",
"//tensorflow/python:framework_ops",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:pywrap_tensorflow",
"//tensorflow/python:pywrap_tfe",
"//third_party/py/numpy",
],
)
@ -410,7 +425,7 @@ py_library(
"//tensorflow/core:protos_all_py",
"//tensorflow/python:dtypes",
"//tensorflow/python:lib",
"//tensorflow/python:pywrap_tensorflow",
"//tensorflow/python:pywrap_tfe",
"//tensorflow/python:tensor_shape",
"//tensorflow/python:util",
"@six_archive//:six",
@ -496,7 +511,7 @@ py_library(
"//tensorflow/python:errors",
"//tensorflow/python:framework_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:pywrap_tensorflow",
"//tensorflow/python:pywrap_tfe",
"//tensorflow/python:tensor_shape",
"//tensorflow/python:unconnected_gradients",
"//tensorflow/python:util",
@ -524,7 +539,7 @@ py_library(
deps = [
":forwardprop_util",
"//tensorflow/python:platform",
"//tensorflow/python:pywrap_tensorflow",
"//tensorflow/python:pywrap_tfe",
"//tensorflow/python:util",
],
)
@ -535,7 +550,18 @@ py_library(
srcs_version = "PY2AND3",
visibility = ["//tensorflow:internal"],
deps = [
"//tensorflow/python:pywrap_tensorflow",
"//tensorflow/python:pywrap_tfe",
],
)
py_library(
name = "eager_util",
srcs = ["eager_util.py"],
srcs_version = "PY2AND3",
visibility = ["//tensorflow:internal"],
deps = [
"//tensorflow/python:pywrap_tfe",
"//tensorflow/python:util",
],
)
@ -552,7 +578,7 @@ cuda_py_test(
":remote",
":test",
"//tensorflow/python:math_ops",
"//tensorflow/python:pywrap_tensorflow",
"//tensorflow/python:pywrap_tfe",
"//tensorflow/python:random_ops",
"//tensorflow/python/keras",
"//third_party/py/numpy",
@ -637,7 +663,7 @@ tf_py_test(
":test",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:math_ops",
"//tensorflow/python:pywrap_tensorflow",
"//tensorflow/python:pywrap_tfe",
"//tensorflow/python:random_ops",
"//tensorflow/python:test_ops",
"//third_party/py/numpy",
@ -649,7 +675,7 @@ py_library(
srcs = ["imperative_grad.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:pywrap_tensorflow",
"//tensorflow/python:pywrap_tfe",
"//tensorflow/python:unconnected_gradients",
"//tensorflow/python:util",
],

View File

@ -24,7 +24,7 @@ import sys
import six
from tensorflow.python import pywrap_tensorflow
from tensorflow.python import pywrap_tfe
from tensorflow.python import _pywrap_utils
from tensorflow.python.eager import backprop_util
from tensorflow.python.eager import context
@ -71,19 +71,25 @@ def op_attr_type(op_type, attr_name):
except KeyError:
context.ensure_initialized()
h = context.context()._handle # pylint: disable=protected-access
attr_type = pywrap_tensorflow.TFE_OpNameGetAttrType(h, op_type, attr_name)
attr_type = pywrap_tfe.TFE_OpNameGetAttrType(h, op_type, attr_name)
_op_attr_type_cache[(op_type, attr_name)] = attr_type
return attr_type
def make_attr(attr_type, value):
if attr_type == pywrap_tensorflow.TF_ATTR_TYPE:
# pybind11 enums do not return the raw value like SWIG enums do. They are
# useful when comparing amongst each other but not direct integers as we are
# doing in most tests.
# https://pybind11.readthedocs.io/en/stable/classes.html#enumerations-and-internal-types
# TODO(amitpatankar): After all SWIG transitions, convert the enum comparisons
# from integer value to class.
if attr_type == int(pywrap_tfe.TF_ATTR_TYPE):
return dtypes.as_dtype(value)
elif attr_type == [pywrap_tensorflow.TF_ATTR_TYPE]:
elif attr_type == [int(pywrap_tfe.TF_ATTR_TYPE)]:
return [dtypes.as_dtype(v) for v in value]
elif attr_type == pywrap_tensorflow.TF_ATTR_SHAPE:
elif attr_type == int(pywrap_tfe.TF_ATTR_SHAPE):
return tensor_shape.as_shape(value).as_proto()
elif attr_type == [pywrap_tensorflow.TF_ATTR_SHAPE]:
elif attr_type == [int(pywrap_tfe.TF_ATTR_SHAPE)]:
return [tensor_shape.as_shape(v).as_proto() for v in value]
elif isinstance(value, str):
return value.encode()
@ -141,16 +147,15 @@ def _gradient_function(op_name, attr_tuple, num_inputs, inputs, outputs,
return grad_fn(mock_op, *out_grads)
pywrap_tensorflow.TFE_Py_RegisterGradientFunction(_gradient_function)
pywrap_tfe.TFE_Py_RegisterGradientFunction(_gradient_function)
def _must_record_gradient():
return not pywrap_tensorflow.TFE_Py_TapeSetIsEmpty()
return not pywrap_tfe.TFE_Py_TapeSetIsEmpty()
def _record_gradient(op_name, inputs, attrs, results):
return pywrap_tensorflow.TFE_Py_RecordGradient(op_name, inputs, attrs,
results)
return pywrap_tfe.TFE_Py_RecordGradient(op_name, inputs, attrs, results)
execute.must_record_gradient = _must_record_gradient
@ -688,7 +693,7 @@ _default_vspace = imperative_grad.VSpace(
zeros_like_fn=default_gradient.zeros_like,
ones_like_fn=default_gradient.ones_like,
graph_shape_fn=gen_array_ops.shape)
pywrap_tensorflow.TFE_Py_RegisterVSpace(_default_vspace)
pywrap_tfe.TFE_Py_RegisterVSpace(_default_vspace)
def _handle_or_self(x):

View File

@ -21,7 +21,7 @@ import functools
from absl.testing import parameterized
import numpy as np
from tensorflow.python import pywrap_tensorflow
from tensorflow.python import pywrap_tfe
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
@ -1014,19 +1014,19 @@ class BackpropTest(test.TestCase, parameterized.TestCase):
def testGetAttrType(self):
typ = backprop.op_attr_type('Add', 'T')
self.assertEqual(typ, pywrap_tensorflow.TF_ATTR_TYPE)
self.assertEqual(typ, int(pywrap_tfe.TF_ATTR_TYPE))
def testGetAttrList(self):
typ = backprop.op_attr_type('MaxPool', 'ksize')
self.assertEqual(typ, [pywrap_tensorflow.TF_ATTR_INT])
self.assertEqual(typ, [int(pywrap_tfe.TF_ATTR_INT)])
def testMakeAttrType(self):
self.assertEqual(dtypes.float32,
backprop.make_attr(pywrap_tensorflow.TF_ATTR_TYPE, 1))
backprop.make_attr(int(pywrap_tfe.TF_ATTR_TYPE), 1))
def testMakeAttrTypeList(self):
self.assertEqual([dtypes.float32],
backprop.make_attr([pywrap_tensorflow.TF_ATTR_TYPE], [1]))
backprop.make_attr([int(pywrap_tfe.TF_ATTR_TYPE)], [1]))
def testMulType(self):
@ -1040,7 +1040,7 @@ class BackpropTest(test.TestCase, parameterized.TestCase):
def testMakeAttrShape(self):
for s in ([], None, [1, 2, 3], [None, None], [1, None, 3]):
expected = tensor_shape.TensorShape(s).as_proto()
actual = backprop.make_attr(pywrap_tensorflow.TF_ATTR_SHAPE, s)
actual = backprop.make_attr(int(pywrap_tfe.TF_ATTR_SHAPE), s)
self.assertEqual(
expected,
actual,
@ -1051,7 +1051,7 @@ class BackpropTest(test.TestCase, parameterized.TestCase):
shape_list = [[], None, [1, 2, 3], [None, None], [1, None, 3]]
self.assertEqual(
[tensor_shape.TensorShape(s).as_proto() for s in shape_list],
backprop.make_attr([pywrap_tensorflow.TF_ATTR_SHAPE], shape_list))
backprop.make_attr([int(pywrap_tfe.TF_ATTR_SHAPE)], shape_list))
def testArgsGradientFunction(self):

View File

@ -39,7 +39,7 @@ import six
from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.python import keras
from tensorflow.python import pywrap_tensorflow
from tensorflow.python import pywrap_tfe
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import backprop # pylint: disable=unused-import
from tensorflow.python.eager import context
@ -76,10 +76,10 @@ def c_tfe_py_fastpath_execute(a,
assert ctx.executing_eagerly(
), "The prototype doesn't contain C code for graph construction"
try:
return pywrap_tensorflow.TFE_Py_FastPathExecute(
ctx._handle, ctx.device_name, "MatMul", name,
ctx.op_callbacks, a, b, "transpose_a", transpose_a,
"transpose_b", transpose_b)
return pywrap_tfe.TFE_Py_FastPathExecute(ctx._handle, ctx.device_name,
"MatMul", name, ctx.op_callbacks,
a, b, "transpose_a", transpose_a,
"transpose_b", transpose_b)
except core._NotOkStatusException as e:
if name is not None:
message = e.message + " name: " + name
@ -339,8 +339,7 @@ class MicroBenchmarks(test.Benchmark):
inputs = [m]
def f():
pywrap_tensorflow.TFE_Py_Execute(ctx_handle, None, "Identity", inputs,
attrs, 1)
pywrap_tfe.TFE_Py_Execute(ctx_handle, None, "Identity", inputs, attrs, 1)
self._run(f, 30000)
@ -406,8 +405,7 @@ class MicroBenchmarks(test.Benchmark):
m.dtype.as_datatype_enum)
def func():
pywrap_tensorflow.TFE_Py_Execute(ctx_handle, device, "MatMul", inputs,
attrs, 1)
pywrap_tfe.TFE_Py_Execute(ctx_handle, device, "MatMul", inputs, attrs, 1)
self._run(func, num_iters)

View File

@ -18,27 +18,27 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python import pywrap_tensorflow
from tensorflow.python import pywrap_tfe
class CancellationManager(object):
"""A mechanism for cancelling blocking computation."""
def __init__(self):
self._impl = pywrap_tensorflow.TFE_NewCancellationManager()
self._impl = pywrap_tfe.TFE_NewCancellationManager()
@property
def is_cancelled(self):
"""Returns `True` if `CancellationManager.start_cancel` has been called."""
return pywrap_tensorflow.TFE_CancellationManagerIsCancelled(self._impl)
return pywrap_tfe.TFE_CancellationManagerIsCancelled(self._impl)
def start_cancel(self):
"""Cancels blocking operations that have been registered with this object."""
pywrap_tensorflow.TFE_CancellationManagerStartCancel(self._impl)
pywrap_tfe.TFE_CancellationManagerStartCancel(self._impl)
def get_cancelable_function(self, concrete_function):
# pylint: disable=protected-access
return concrete_function._experimental_with_cancellation_manager(self)
def __del__(self):
pywrap_tensorflow.TFE_DeleteCancellationManager(self._impl)
pywrap_tfe.TFE_DeleteCancellationManager(self._impl)

View File

@ -29,11 +29,11 @@ import six
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python import pywrap_tensorflow
from tensorflow.python import pywrap_tfe
from tensorflow.python import tf2
from tensorflow.python.eager import eager_util as c_api_util
from tensorflow.python.eager import executor
from tensorflow.python.eager import monitoring
from tensorflow.python.framework import c_api_util
from tensorflow.python.framework import device as pydev
from tensorflow.python.util import compat
from tensorflow.python.util import is_in_graph_mode
@ -54,17 +54,17 @@ _starting_device_spec = pydev.DeviceSpec.from_string("")
_MAXINT32 = 2**31 - 1
DEVICE_PLACEMENT_EXPLICIT = pywrap_tensorflow.TFE_DEVICE_PLACEMENT_EXPLICIT
DEVICE_PLACEMENT_WARN = pywrap_tensorflow.TFE_DEVICE_PLACEMENT_WARN
DEVICE_PLACEMENT_SILENT = pywrap_tensorflow.TFE_DEVICE_PLACEMENT_SILENT
DEVICE_PLACEMENT_EXPLICIT = pywrap_tfe.TFE_DEVICE_PLACEMENT_EXPLICIT
DEVICE_PLACEMENT_WARN = pywrap_tfe.TFE_DEVICE_PLACEMENT_WARN
DEVICE_PLACEMENT_SILENT = pywrap_tfe.TFE_DEVICE_PLACEMENT_SILENT
DEVICE_PLACEMENT_SILENT_FOR_INT32 = (
pywrap_tensorflow.TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32)
pywrap_tfe.TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32)
SYNC = 0
ASYNC = 1
MIRRORING_NONE = pywrap_tensorflow.TFE_MIRRORING_NONE
MIRRORING_ALL = pywrap_tensorflow.TFE_MIRRORING_ALL
MIRRORING_NONE = pywrap_tfe.TFE_MIRRORING_NONE
MIRRORING_ALL = pywrap_tfe.TFE_MIRRORING_ALL
_KEEP_ALIVE_SECS = 600
@ -444,7 +444,7 @@ class Context(object):
self._rng = random.Random(seed)
# Also clear the kernel cache, to reset any existing seeds
if self._context_handle is not None:
pywrap_tensorflow.TFE_ContextClearCaches(self._context_handle)
pywrap_tfe.TFE_ContextClearCaches(self._context_handle)
def _internal_operation_seed(self):
"""Returns a fake operation seed.
@ -463,12 +463,11 @@ class Context(object):
# Store list of devices
logical_devices = []
context_devices = []
device_list = pywrap_tensorflow.TFE_ContextListDevices(
self._context_handle)
device_list = pywrap_tfe.TFE_ContextListDevices(self._context_handle)
try:
self._num_gpus = 0
for i in range(pywrap_tensorflow.TF_DeviceListCount(device_list)):
dev_name = pywrap_tensorflow.TF_DeviceListName(device_list, i)
for i in range(pywrap_tfe.TF_DeviceListCount(device_list)):
dev_name = pywrap_tfe.TF_DeviceListName(device_list, i)
context_devices.append(pydev.canonical_name(dev_name))
spec = pydev.DeviceSpec.from_string(dev_name)
# If the job is localhost, we assume that the cluster has not yet been
@ -477,14 +476,14 @@ class Context(object):
spec = spec.replace(job=None, replica=None, task=None)
logical_devices.append(
LogicalDevice(name=spec.to_string(), device_type=spec.device_type))
dev_type = pywrap_tensorflow.TF_DeviceListType(device_list, i)
dev_type = pywrap_tfe.TF_DeviceListType(device_list, i)
if dev_type == "GPU":
self._num_gpus += 1
finally:
self._logical_devices = logical_devices
self._context_devices = context_devices
pywrap_tensorflow.TF_DeleteDeviceList(device_list)
pywrap_tfe.TF_DeleteDeviceList(device_list)
def ensure_initialized(self):
"""Initialize handle and devices if not already done so."""
@ -494,36 +493,34 @@ class Context(object):
if self._initialized:
return
assert self._context_devices is None
opts = pywrap_tensorflow.TFE_NewContextOptions()
opts = pywrap_tfe.TFE_NewContextOptions()
try:
config_str = self.config.SerializeToString()
pywrap_tensorflow.TFE_ContextOptionsSetConfig(opts, config_str)
pywrap_tfe.TFE_ContextOptionsSetConfig(opts, config_str)
if self._device_policy is not None:
pywrap_tensorflow.TFE_ContextOptionsSetDevicePlacementPolicy(
pywrap_tfe.TFE_ContextOptionsSetDevicePlacementPolicy(
opts, self._device_policy)
if self._mirroring_policy is not None:
pywrap_tensorflow.TFE_ContextOptionsSetMirroringPolicy(
pywrap_tfe.TFE_ContextOptionsSetMirroringPolicy(
opts, self._mirroring_policy)
if self._default_is_async == ASYNC:
pywrap_tensorflow.TFE_ContextOptionsSetAsync(opts, True)
pywrap_tfe.TFE_ContextOptionsSetAsync(opts, True)
if self._lazy_remote_inputs_copy is not None:
pywrap_tensorflow.TFE_ContextOptionsSetLazyRemoteInputsCopy(
pywrap_tfe.TFE_ContextOptionsSetLazyRemoteInputsCopy(
opts, self._lazy_remote_inputs_copy)
context_handle = pywrap_tensorflow.TFE_NewContext(opts)
context_handle = pywrap_tfe.TFE_NewContext(opts)
finally:
pywrap_tensorflow.TFE_DeleteContextOptions(opts)
pywrap_tfe.TFE_DeleteContextOptions(opts)
assert not (self._server_def and self._collective_ops_server_def), (
"Cannot enable remote execution as well as collective ops at the "
"moment. If this is important to you, please file an issue.")
if self._server_def is not None:
server_def_str = self._server_def.SerializeToString()
pywrap_tensorflow.TFE_ContextSetServerDef(context_handle,
_KEEP_ALIVE_SECS,
server_def_str)
pywrap_tfe.TFE_ContextSetServerDef(context_handle, _KEEP_ALIVE_SECS,
server_def_str)
elif self._collective_ops_server_def is not None:
server_def_str = self._collective_ops_server_def.SerializeToString()
pywrap_tensorflow.TFE_EnableCollectiveOps(context_handle,
server_def_str)
pywrap_tfe.TFE_EnableCollectiveOps(context_handle, server_def_str)
self._context_handle = context_handle
self._initialize_logical_devices()
@ -532,7 +529,7 @@ class Context(object):
def _clear_caches(self):
self.ones_rank_cache().flush()
self.zeros_cache().flush()
pywrap_tensorflow.TFE_ClearScalarCache()
pywrap_tfe.TFE_ClearScalarCache()
def get_server_def(self):
return self._server_def
@ -563,8 +560,8 @@ class Context(object):
if self._context_handle:
server_def_str = server_def.SerializeToString()
pywrap_tensorflow.TFE_ContextSetServerDef(self._context_handle,
keep_alive_secs, server_def_str)
pywrap_tfe.TFE_ContextSetServerDef(self._context_handle, keep_alive_secs,
server_def_str)
self._initialize_logical_devices()
# Clear all the caches in case there are remote tensors in them.
@ -592,9 +589,8 @@ class Context(object):
if self._context_handle:
server_def_str = server_def.SerializeToString()
pywrap_tensorflow.TFE_ContextUpdateServerDef(self._context_handle,
keep_alive_secs,
server_def_str)
pywrap_tfe.TFE_ContextUpdateServerDef(self._context_handle,
keep_alive_secs, server_def_str)
self._initialize_logical_devices()
self._clear_caches()
@ -614,8 +610,7 @@ class Context(object):
"""
# TODO(yuefengz): support checking multiple workers.
if self._context_handle:
return pywrap_tensorflow.TFE_ContextCheckAlive(self._context_handle,
worker_name)
return pywrap_tfe.TFE_ContextCheckAlive(self._context_handle, worker_name)
else:
raise ValueError("Context is not initialized.")
@ -808,8 +803,8 @@ class Context(object):
self.executor.wait()
executor_new = executor.new_executor(enable_async)
self._thread_local_data.executor = executor_new
pywrap_tensorflow.TFE_ContextSetExecutorForThread(
self._context_handle, executor_new.handle())
pywrap_tfe.TFE_ContextSetExecutorForThread(self._context_handle,
executor_new.handle())
else:
self._default_is_async = enable_async
@ -823,13 +818,12 @@ class Context(object):
def executor(self):
ensure_initialized()
return executor.Executor(
pywrap_tensorflow.TFE_ContextGetExecutorForThread(self._context_handle))
pywrap_tfe.TFE_ContextGetExecutorForThread(self._context_handle))
@executor.setter
def executor(self, e):
ensure_initialized()
pywrap_tensorflow.TFE_ContextSetExecutorForThread(self._context_handle,
e.handle())
pywrap_tfe.TFE_ContextSetExecutorForThread(self._context_handle, e.handle())
@property
def config(self):
@ -1015,7 +1009,7 @@ class Context(object):
fn: A wrapped TF_Function (returned from TF_GraphToFunction_wrapper).
"""
self.ensure_initialized()
pywrap_tensorflow.TFE_ContextAddFunction(self._handle, fn)
pywrap_tfe.TFE_ContextAddFunction(self._handle, fn)
def add_function_def(self, fdef):
"""Add a function definition to the context.
@ -1028,8 +1022,8 @@ class Context(object):
"""
self.ensure_initialized()
fdef_string = fdef.SerializeToString()
pywrap_tensorflow.TFE_ContextAddFunctionDef(
self._handle, fdef_string, len(fdef_string))
pywrap_tfe.TFE_ContextAddFunctionDef(self._handle, fdef_string,
len(fdef_string))
def remove_function(self, name):
"""Remove a function from the context.
@ -1040,12 +1034,12 @@ class Context(object):
name: function signature name.
"""
self.ensure_initialized()
pywrap_tensorflow.TFE_ContextRemoveFunction(self._handle, name)
pywrap_tfe.TFE_ContextRemoveFunction(self._handle, name)
def has_function(self, name):
"""Check if a function `name` is registered."""
self.ensure_initialized()
return bool(pywrap_tensorflow.TFE_ContextHasFunction(self._handle, name))
return bool(pywrap_tfe.TFE_ContextHasFunction(self._handle, name))
def add_op_callback(self, callback):
"""Add a post-op callback to the context.
@ -1101,7 +1095,7 @@ class Context(object):
if self._physical_devices is not None:
return
devs = pywrap_tensorflow.TF_ListPhysicalDevices()
devs = pywrap_tfe.TF_ListPhysicalDevices()
self._physical_devices = [
PhysicalDevice(name=d.decode(),
device_type=d.decode().split(":")[1]) for d in devs]
@ -1434,7 +1428,7 @@ class Context(object):
def device_policy(self):
# Only get the policy from the context if it has already been initialized
if self._context_handle is not None:
return pywrap_tensorflow.TFE_ContextGetDevicePlacementPolicy(self._handle)
return pywrap_tfe.TFE_ContextGetDevicePlacementPolicy(self._handle)
return self._device_policy
@ -1448,14 +1442,14 @@ class Context(object):
# Only set the policy if the context has already been initialized
if self._context_handle is not None:
pywrap_tensorflow.TFE_ContextSetThreadLocalDevicePlacementPolicy(
pywrap_tfe.TFE_ContextSetThreadLocalDevicePlacementPolicy(
self._handle, self._device_policy)
@property
def mirroring_policy(self):
# Only get the policy from the context if it has already been initialized
if self._context_handle is not None:
return pywrap_tensorflow.TFE_ContextGetMirroringPolicy(self._handle)
return pywrap_tfe.TFE_ContextGetMirroringPolicy(self._handle)
return self._mirroring_policy
@ -1469,7 +1463,7 @@ class Context(object):
# Only set the policy if the context has already been initialized
if self._context_handle is not None:
pywrap_tensorflow.TFE_ContextSetThreadLocalMirroringPolicy(
pywrap_tfe.TFE_ContextSetThreadLocalMirroringPolicy(
self._handle, self._mirroring_policy)
@property
@ -1495,13 +1489,13 @@ class Context(object):
and to stop tracing call context.disable_run_metadata().
"""
self.ensure_initialized()
pywrap_tensorflow.TFE_ContextEnableRunMetadata(self._handle)
pywrap_tfe.TFE_ContextEnableRunMetadata(self._handle)
def disable_run_metadata(self):
"""Disables tracing of op execution via RunMetadata."""
if not self._context_handle:
return
pywrap_tensorflow.TFE_ContextDisableRunMetadata(self._context_handle)
pywrap_tfe.TFE_ContextDisableRunMetadata(self._context_handle)
def enable_graph_collection(self):
"""Enables graph collection of executed functions.
@ -1510,13 +1504,13 @@ class Context(object):
and to stop collecting graphs call context.disable_graph_collection().
"""
self.ensure_initialized()
pywrap_tensorflow.TFE_ContextEnableGraphCollection(self._handle)
pywrap_tfe.TFE_ContextEnableGraphCollection(self._handle)
def disable_graph_collection(self):
"""Disables graph collection of executed functions."""
if not self._context_handle:
return
pywrap_tensorflow.TFE_ContextDisableGraphCollection(self._context_handle)
pywrap_tfe.TFE_ContextDisableGraphCollection(self._context_handle)
def export_run_metadata(self):
"""Returns a RunMetadata proto with accumulated information.
@ -1530,9 +1524,8 @@ class Context(object):
if not self._context_handle:
return None
with c_api_util.tf_buffer() as buffer_:
pywrap_tensorflow.TFE_ContextExportRunMetadata(
self._context_handle, buffer_)
proto_data = pywrap_tensorflow.TF_GetBuffer(buffer_)
pywrap_tfe.TFE_ContextExportRunMetadata(self._context_handle, buffer_)
proto_data = pywrap_tfe.TF_GetBuffer(buffer_)
run_metadata = config_pb2.RunMetadata()
run_metadata.ParseFromString(compat.as_bytes(proto_data))
return run_metadata
@ -1543,10 +1536,10 @@ class Context(object):
return self._context_switches
def start_step(self):
pywrap_tensorflow.TFE_ContextStartStep(self._handle)
pywrap_tfe.TFE_ContextStartStep(self._handle)
def end_step(self):
pywrap_tensorflow.TFE_ContextEndStep(self._handle)
pywrap_tfe.TFE_ContextEndStep(self._handle)
class _EagerDeviceContext(object):
@ -1608,7 +1601,7 @@ _context_lock = threading.Lock()
def _set_context_locked(ctx):
global _context
pywrap_tensorflow.TFE_Py_SetEagerContext(ctx)
pywrap_tfe.TFE_Py_SetEagerContext(ctx)
_context = ctx

View File

@ -18,7 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python import pywrap_tensorflow
from tensorflow.python import pywrap_tfe
from tensorflow.python.framework import errors
# Trace of execution and memory usage.
@ -46,7 +46,7 @@ class _NotOkStatusException(Exception):
return "%s: %s" % (e.__class__.__name__, e)
pywrap_tensorflow.TFE_Py_RegisterExceptionClass(_NotOkStatusException)
pywrap_tfe.TFE_Py_RegisterExceptionClass(_NotOkStatusException)
class _FallbackException(Exception):
@ -71,4 +71,4 @@ class _SymbolicException(Exception):
pass
pywrap_tensorflow.TFE_Py_RegisterFallbackExceptionClass(_FallbackException)
pywrap_tfe.TFE_Py_RegisterFallbackExceptionClass(_FallbackException)

View File

@ -26,7 +26,7 @@ import threading
import numpy as np
from tensorflow.core.protobuf import config_pb2
from tensorflow.python import pywrap_tensorflow
from tensorflow.python import pywrap_tfe
from tensorflow.python.eager import context
from tensorflow.python.eager import core
from tensorflow.python.eager import def_function
@ -602,8 +602,8 @@ class TFETest(test_util.TensorFlowTestCase):
def testRegisterExceptionClass(self):
with self.assertRaises(TypeError):
pywrap_tensorflow.TFE_Py_RegisterExceptionClass(str)
pywrap_tensorflow.TFE_Py_RegisterExceptionClass(core._NotOkStatusException) # pylint: disable=protected-access
pywrap_tfe.TFE_Py_RegisterExceptionClass(str)
pywrap_tfe.TFE_Py_RegisterExceptionClass(core._NotOkStatusException) # pylint: disable=protected-access
# TODO(agarwal): add tests passing incorrect typed values to attrs.
def testExecuteBasic(self):

View File

@ -0,0 +1,61 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for using the TensorFlow Eager using the C API."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python import pywrap_tfe as c_api
from tensorflow.python.util import compat
from tensorflow.python.util import tf_contextlib
# We temporarily need a duplicate tf_buffer function in eager_util. The
# c_api_util is still relying on SWIG and is thus incompatible until
# we migrate over. We can delete this once we migrate tf_session.i
@tf_contextlib.contextmanager
def tf_buffer(data=None):
"""Context manager that creates and deletes TF_Buffer.
Example usage:
with tf_buffer() as buf:
# get serialized graph def into buf
...
proto_data = c_api.TF_GetBuffer(buf)
graph_def.ParseFromString(compat.as_bytes(proto_data))
# buf has been deleted
with tf_buffer(some_string) as buf:
c_api.TF_SomeFunction(buf)
# buf has been deleted
Args:
data: An optional `bytes`, `str`, or `unicode` object. If not None, the
yielded buffer will contain this data.
Yields:
Created TF_Buffer
"""
if data:
buf = c_api.TF_NewBufferFromString(compat.as_bytes(data))
else:
buf = c_api.TF_NewBuffer()
try:
yield buf
finally:
c_api.TF_DeleteBuffer(buf)

View File

@ -22,7 +22,7 @@ import six
from google.protobuf import text_format
from tensorflow.core.framework import tensor_pb2
from tensorflow.python import pywrap_tensorflow
from tensorflow.python import pywrap_tfe
from tensorflow.python.eager import core
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@ -56,9 +56,8 @@ def quick_execute(op_name, num_outputs, inputs, attrs, ctx, name=None):
# pylint: disable=protected-access
try:
ctx.ensure_initialized()
tensors = pywrap_tensorflow.TFE_Py_Execute(ctx._handle, device_name,
op_name, inputs, attrs,
num_outputs)
tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
inputs, attrs, num_outputs)
except core._NotOkStatusException as e:
if name is not None:
message = e.message + " name: " + name
@ -111,9 +110,10 @@ def execute_with_cancellation(op_name,
# pylint: disable=protected-access
try:
ctx.ensure_initialized()
tensors = pywrap_tensorflow.TFE_Py_ExecuteCancelable(
ctx._handle, device_name, op_name, inputs, attrs,
cancellation_manager._impl, num_outputs)
tensors = pywrap_tfe.TFE_Py_ExecuteCancelable(ctx._handle, device_name,
op_name, inputs, attrs,
cancellation_manager._impl,
num_outputs)
except core._NotOkStatusException as e:
if name is not None:
message = e.message + " name: " + name

View File

@ -18,7 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python import pywrap_tensorflow
from tensorflow.python import pywrap_tfe
class Executor(object):
@ -45,8 +45,8 @@ class Executor(object):
def __del__(self):
try:
# pywrap_tensorflow.TFE_ExecutorWaitForAllPendingNodes(self._handle)
pywrap_tensorflow.TFE_DeleteExecutor(self._handle)
# pywrap_tfe.TFE_ExecutorWaitForAllPendingNodes(self._handle)
pywrap_tfe.TFE_DeleteExecutor(self._handle)
except TypeError:
# Suppress some exceptions, mainly for the case when we're running on
# module deletion. Things that can go wrong include the pywrap module
@ -57,20 +57,20 @@ class Executor(object):
# partially unloaded.
def is_async(self):
return pywrap_tensorflow.TFE_ExecutorIsAsync(self._handle)
return pywrap_tfe.TFE_ExecutorIsAsync(self._handle)
def handle(self):
return self._handle
def wait(self):
"""Waits for ops dispatched in this executor to finish."""
pywrap_tensorflow.TFE_ExecutorWaitForAllPendingNodes(self._handle)
pywrap_tfe.TFE_ExecutorWaitForAllPendingNodes(self._handle)
def clear_error(self):
"""Clears errors raised in this executor during execution."""
pywrap_tensorflow.TFE_ExecutorClearError(self._handle)
pywrap_tfe.TFE_ExecutorClearError(self._handle)
def new_executor(enable_async):
handle = pywrap_tensorflow.TFE_NewExecutor(enable_async)
handle = pywrap_tfe.TFE_NewExecutor(enable_async)
return Executor(handle)

View File

@ -20,7 +20,7 @@ from __future__ import print_function
import threading
from tensorflow.python import pywrap_tensorflow
from tensorflow.python import pywrap_tfe
from tensorflow.python.eager import backprop
from tensorflow.python.eager import backprop_util
from tensorflow.python.eager import def_function
@ -166,7 +166,8 @@ def _jvp_dispatch(op_name, attr_tuple, inputs, outputs, tangents):
return _jvp_relaxed_shapes(
op_name, attr_tuple, inputs, outputs, tangents)
pywrap_tensorflow.TFE_Py_RegisterJVPFunction(_jvp_dispatch)
pywrap_tfe.TFE_Py_RegisterJVPFunction(_jvp_dispatch)
@tf_export("autodiff.ForwardAccumulator", v1=[])
@ -300,7 +301,7 @@ class ForwardAccumulator(object):
ValueError: If the same tensor or variable is specified multiple times in
`primals`.
"""
self._accumulator = pywrap_tensorflow.TFE_Py_ForwardAccumulatorNew()
self._accumulator = pywrap_tfe.TFE_Py_ForwardAccumulatorNew()
self._recording = False
primal_ids = set()
for primal in nest.flatten(primals):
@ -323,13 +324,13 @@ class ForwardAccumulator(object):
def _push_accumulator(self):
if self._recording:
raise ValueError("Accumulator is already recording.")
pywrap_tensorflow.TFE_Py_ForwardAccumulatorSetAdd(self._accumulator)
pywrap_tfe.TFE_Py_ForwardAccumulatorSetAdd(self._accumulator)
self._recording = True
def _pop_accumulator(self):
if not self._recording:
raise ValueError("Accumulator is not recording.")
pywrap_tensorflow.TFE_Py_ForwardAccumulatorSetRemove(self._accumulator)
pywrap_tfe.TFE_Py_ForwardAccumulatorSetRemove(self._accumulator)
self._recording = False
def _watch(self, primals, tangents):
@ -358,7 +359,7 @@ class ForwardAccumulator(object):
# Run convert_to_tensor to get the captured handle from whichever
# function we're running if necessary.
t = ops.convert_to_tensor(t.handle)
pywrap_tensorflow.TFE_Py_ForwardAccumulatorWatch(self._accumulator, t, g)
pywrap_tfe.TFE_Py_ForwardAccumulatorWatch(self._accumulator, t, g)
def jvp(self, primals, unconnected_gradients=UnconnectedGradients.NONE):
"""Fetches the Jacobian-vector product computed for `primals`.
@ -384,8 +385,8 @@ class ForwardAccumulator(object):
def _fetch_jvp(tensor):
if hasattr(tensor, "handle"):
tensor = ops.convert_to_tensor(tensor.handle)
result = pywrap_tensorflow.TFE_Py_ForwardAccumulatorJVP(
self._accumulator, tensor)
result = pywrap_tfe.TFE_Py_ForwardAccumulatorJVP(self._accumulator,
tensor)
if result is None and unconnected_gradients == UnconnectedGradients.ZERO:
return array_ops.zeros_like(tensor)
return result

View File

@ -24,7 +24,7 @@ import weakref
from absl.testing import parameterized
import numpy as np
from tensorflow.python import pywrap_tensorflow
from tensorflow.python import pywrap_tfe
from tensorflow.python.distribute import mirrored_strategy
from tensorflow.python.eager import backprop
from tensorflow.python.eager import def_function
@ -236,13 +236,13 @@ class ForwardpropTest(test.TestCase, parameterized.TestCase):
x = constant_op.constant(1.)
with forwardprop.ForwardAccumulator(x, 2.) as acc:
y = x + x
pywrap_tensorflow.TFE_Py_RegisterJVPFunction(
pywrap_tfe.TFE_Py_RegisterJVPFunction(
lambda *args, **kwargs: [constant_op.constant(-15.)])
z = x + x
self.assertAllClose(4., acc.jvp(y))
self.assertAllClose(-15., acc.jvp(z))
finally:
pywrap_tensorflow.TFE_Py_RegisterJVPFunction(previous_fn)
pywrap_tfe.TFE_Py_RegisterJVPFunction(previous_fn)
@test_util.assert_no_new_pyobjects_executing_eagerly
def testFunctionCacheLimited(self):
@ -738,19 +738,19 @@ class ForwardpropTest(test.TestCase, parameterized.TestCase):
with forwardprop.ForwardAccumulator(c, c_tangent) as acc:
with backprop.GradientTape() as tape:
self.assertFalse(tape_lib.should_record_backprop([c]))
self.assertEqual(
1, pywrap_tensorflow.TFE_Py_TapeSetPossibleGradientTypes([c]))
self.assertEqual(1,
pywrap_tfe.TFE_Py_TapeSetPossibleGradientTypes([c]))
tape.watch(c)
self.assertEqual(
2, pywrap_tensorflow.TFE_Py_TapeSetPossibleGradientTypes([c]))
self.assertEqual(2,
pywrap_tfe.TFE_Py_TapeSetPossibleGradientTypes([c]))
self.assertTrue(tape_lib.should_record_backprop([c]))
with tape_lib.stop_recording():
self.assertEqual(
0, pywrap_tensorflow.TFE_Py_TapeSetPossibleGradientTypes([c]))
self.assertEqual(0,
pywrap_tfe.TFE_Py_TapeSetPossibleGradientTypes([c]))
self.assertFalse(tape_lib.should_record_backprop([c]))
d = c * 2.
self.assertEqual(
2, pywrap_tensorflow.TFE_Py_TapeSetPossibleGradientTypes([c]))
self.assertEqual(2,
pywrap_tfe.TFE_Py_TapeSetPossibleGradientTypes([c]))
self.assertTrue(tape_lib.should_record_backprop([c]))
self.assertFalse(tape_lib.should_record_backprop([d]))
self.assertIsNone(acc.jvp(d))

View File

@ -24,7 +24,7 @@ from __future__ import print_function
import collections
import contextlib
from tensorflow.python import pywrap_tensorflow
from tensorflow.python import pywrap_tfe
class TangentInfo(
@ -54,8 +54,7 @@ def pack_tangents(tensors):
tangents: A flat list of Tensors. Best interpreted as a sequence to be
appended to `tensors`.
"""
return TangentInfo(
*pywrap_tensorflow.TFE_Py_PackJVPs(tensors))
return TangentInfo(*pywrap_tfe.TFE_Py_PackJVPs(tensors))
@contextlib.contextmanager
@ -73,7 +72,7 @@ def push_forwardprop_state():
None (used for its side effect).
"""
try:
pywrap_tensorflow.TFE_Py_ForwardAccumulatorPushState()
pywrap_tfe.TFE_Py_ForwardAccumulatorPushState()
yield
finally:
pywrap_tensorflow.TFE_Py_ForwardAccumulatorPopState()
pywrap_tfe.TFE_Py_ForwardAccumulatorPopState()

View File

@ -32,8 +32,9 @@ from six.moves import map
from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import function_pb2
from tensorflow.python import pywrap_tensorflow
from tensorflow.python import pywrap_tfe
from tensorflow.python import _pywrap_utils
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.eager import backprop
from tensorflow.python.eager import backprop_util
from tensorflow.python.eager import context
@ -1098,7 +1099,7 @@ class _TapeGradientFunctions(object):
forward_function.signature.name,
forward_outputs, forward_inputs, py_backward, None)
output_indices, output_tangents = (
pywrap_tensorflow.TFE_Py_PackJVPs(forward_outputs))
pywrap_tfe.TFE_Py_PackJVPs(forward_outputs))
output_tangents = [forward_wrapper_graph.capture(t)
for t in output_tangents]
return _ForwardWrapper(
@ -1732,7 +1733,7 @@ class ConcreteFunction(object):
"Tensor." % (self._func_graph.name, i, str(arg)))
args = tensor_inputs + captured_inputs
possible_gradient_type = (
pywrap_tensorflow.TFE_Py_TapeSetPossibleGradientTypes(args))
pywrap_tfe.TFE_Py_TapeSetPossibleGradientTypes(args))
if (possible_gradient_type == _POSSIBLE_GRADIENT_TYPES_NONE
and executing_eagerly):
# No tape is watching; skip to running the function.
@ -2552,8 +2553,8 @@ class Function(object):
"""Computes the cache key given inputs and execution context."""
if self.input_signature is None:
inputs = (args, kwargs) if kwargs else args
input_signature = pywrap_tensorflow.TFE_Py_EncodeArg(
inputs, include_tensor_ranks_only)
input_signature = pywrap_tfe.TFE_Py_EncodeArg(inputs,
include_tensor_ranks_only)
else:
del args, kwargs
assert not include_tensor_ranks_only

View File

@ -20,7 +20,7 @@ from __future__ import print_function
import collections
from tensorflow.python import pywrap_tensorflow
from tensorflow.python import pywrap_tfe
from tensorflow.python.ops.unconnected_gradients import UnconnectedGradients
from tensorflow.python.util import compat
@ -68,7 +68,7 @@ def imperative_grad(tape,
raise ValueError(
"Unknown value for unconnected_gradients: %r" % unconnected_gradients)
return pywrap_tensorflow.TFE_Py_TapeGradient(
return pywrap_tfe.TFE_Py_TapeGradient(
tape._tape, # pylint: disable=protected-access
target,
sources,

View File

@ -21,80 +21,80 @@ from __future__ import print_function
import collections
from tensorflow.core.framework import summary_pb2
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.framework import c_api_util
from tensorflow.python import pywrap_tfe
from tensorflow.python.eager import eager_util as c_api_util
from tensorflow.python.util import compat
_MetricMethod = collections.namedtuple('MetricMethod', 'create delete get_cell')
_counter_methods = [
_MetricMethod(
create=pywrap_tensorflow.TFE_MonitoringNewCounter0,
delete=pywrap_tensorflow.TFE_MonitoringDeleteCounter0,
get_cell=pywrap_tensorflow.TFE_MonitoringGetCellCounter0),
create=pywrap_tfe.TFE_MonitoringNewCounter0,
delete=pywrap_tfe.TFE_MonitoringDeleteCounter0,
get_cell=pywrap_tfe.TFE_MonitoringGetCellCounter0),
_MetricMethod(
create=pywrap_tensorflow.TFE_MonitoringNewCounter1,
delete=pywrap_tensorflow.TFE_MonitoringDeleteCounter1,
get_cell=pywrap_tensorflow.TFE_MonitoringGetCellCounter1),
create=pywrap_tfe.TFE_MonitoringNewCounter1,
delete=pywrap_tfe.TFE_MonitoringDeleteCounter1,
get_cell=pywrap_tfe.TFE_MonitoringGetCellCounter1),
_MetricMethod(
create=pywrap_tensorflow.TFE_MonitoringNewCounter2,
delete=pywrap_tensorflow.TFE_MonitoringDeleteCounter2,
get_cell=pywrap_tensorflow.TFE_MonitoringGetCellCounter2),
create=pywrap_tfe.TFE_MonitoringNewCounter2,
delete=pywrap_tfe.TFE_MonitoringDeleteCounter2,
get_cell=pywrap_tfe.TFE_MonitoringGetCellCounter2),
]
_int_gauge_methods = [
_MetricMethod(
create=pywrap_tensorflow.TFE_MonitoringNewIntGauge0,
delete=pywrap_tensorflow.TFE_MonitoringDeleteIntGauge0,
get_cell=pywrap_tensorflow.TFE_MonitoringGetCellIntGauge0),
create=pywrap_tfe.TFE_MonitoringNewIntGauge0,
delete=pywrap_tfe.TFE_MonitoringDeleteIntGauge0,
get_cell=pywrap_tfe.TFE_MonitoringGetCellIntGauge0),
_MetricMethod(
create=pywrap_tensorflow.TFE_MonitoringNewIntGauge1,
delete=pywrap_tensorflow.TFE_MonitoringDeleteIntGauge1,
get_cell=pywrap_tensorflow.TFE_MonitoringGetCellIntGauge1),
create=pywrap_tfe.TFE_MonitoringNewIntGauge1,
delete=pywrap_tfe.TFE_MonitoringDeleteIntGauge1,
get_cell=pywrap_tfe.TFE_MonitoringGetCellIntGauge1),
_MetricMethod(
create=pywrap_tensorflow.TFE_MonitoringNewIntGauge2,
delete=pywrap_tensorflow.TFE_MonitoringDeleteIntGauge2,
get_cell=pywrap_tensorflow.TFE_MonitoringGetCellIntGauge2),
create=pywrap_tfe.TFE_MonitoringNewIntGauge2,
delete=pywrap_tfe.TFE_MonitoringDeleteIntGauge2,
get_cell=pywrap_tfe.TFE_MonitoringGetCellIntGauge2),
]
_string_gauge_methods = [
_MetricMethod(
create=pywrap_tensorflow.TFE_MonitoringNewStringGauge0,
delete=pywrap_tensorflow.TFE_MonitoringDeleteStringGauge0,
get_cell=pywrap_tensorflow.TFE_MonitoringGetCellStringGauge0),
create=pywrap_tfe.TFE_MonitoringNewStringGauge0,
delete=pywrap_tfe.TFE_MonitoringDeleteStringGauge0,
get_cell=pywrap_tfe.TFE_MonitoringGetCellStringGauge0),
_MetricMethod(
create=pywrap_tensorflow.TFE_MonitoringNewStringGauge1,
delete=pywrap_tensorflow.TFE_MonitoringDeleteStringGauge1,
get_cell=pywrap_tensorflow.TFE_MonitoringGetCellStringGauge1),
create=pywrap_tfe.TFE_MonitoringNewStringGauge1,
delete=pywrap_tfe.TFE_MonitoringDeleteStringGauge1,
get_cell=pywrap_tfe.TFE_MonitoringGetCellStringGauge1),
_MetricMethod(
create=pywrap_tensorflow.TFE_MonitoringNewStringGauge2,
delete=pywrap_tensorflow.TFE_MonitoringDeleteStringGauge2,
get_cell=pywrap_tensorflow.TFE_MonitoringGetCellStringGauge2),
create=pywrap_tfe.TFE_MonitoringNewStringGauge2,
delete=pywrap_tfe.TFE_MonitoringDeleteStringGauge2,
get_cell=pywrap_tfe.TFE_MonitoringGetCellStringGauge2),
]
_bool_gauge_methods = [
_MetricMethod(
create=pywrap_tensorflow.TFE_MonitoringNewBoolGauge0,
delete=pywrap_tensorflow.TFE_MonitoringDeleteBoolGauge0,
get_cell=pywrap_tensorflow.TFE_MonitoringGetCellBoolGauge0),
create=pywrap_tfe.TFE_MonitoringNewBoolGauge0,
delete=pywrap_tfe.TFE_MonitoringDeleteBoolGauge0,
get_cell=pywrap_tfe.TFE_MonitoringGetCellBoolGauge0),
_MetricMethod(
create=pywrap_tensorflow.TFE_MonitoringNewBoolGauge1,
delete=pywrap_tensorflow.TFE_MonitoringDeleteBoolGauge1,
get_cell=pywrap_tensorflow.TFE_MonitoringGetCellBoolGauge1),
create=pywrap_tfe.TFE_MonitoringNewBoolGauge1,
delete=pywrap_tfe.TFE_MonitoringDeleteBoolGauge1,
get_cell=pywrap_tfe.TFE_MonitoringGetCellBoolGauge1),
_MetricMethod(
create=pywrap_tensorflow.TFE_MonitoringNewBoolGauge2,
delete=pywrap_tensorflow.TFE_MonitoringDeleteBoolGauge2,
get_cell=pywrap_tensorflow.TFE_MonitoringGetCellBoolGauge2),
create=pywrap_tfe.TFE_MonitoringNewBoolGauge2,
delete=pywrap_tfe.TFE_MonitoringDeleteBoolGauge2,
get_cell=pywrap_tfe.TFE_MonitoringGetCellBoolGauge2),
]
_sampler_methods = [
_MetricMethod(
create=pywrap_tensorflow.TFE_MonitoringNewSampler0,
delete=pywrap_tensorflow.TFE_MonitoringDeleteSampler0,
get_cell=pywrap_tensorflow.TFE_MonitoringGetCellSampler0),
create=pywrap_tfe.TFE_MonitoringNewSampler0,
delete=pywrap_tfe.TFE_MonitoringDeleteSampler0,
get_cell=pywrap_tfe.TFE_MonitoringGetCellSampler0),
_MetricMethod(
create=pywrap_tensorflow.TFE_MonitoringNewSampler1,
delete=pywrap_tensorflow.TFE_MonitoringDeleteSampler1,
get_cell=pywrap_tensorflow.TFE_MonitoringGetCellSampler1),
create=pywrap_tfe.TFE_MonitoringNewSampler1,
delete=pywrap_tfe.TFE_MonitoringDeleteSampler1,
get_cell=pywrap_tfe.TFE_MonitoringGetCellSampler1),
_MetricMethod(
create=pywrap_tensorflow.TFE_MonitoringNewSampler2,
delete=pywrap_tensorflow.TFE_MonitoringDeleteSampler2,
get_cell=pywrap_tensorflow.TFE_MonitoringGetCellSampler2),
create=pywrap_tfe.TFE_MonitoringNewSampler2,
delete=pywrap_tfe.TFE_MonitoringDeleteSampler2,
get_cell=pywrap_tfe.TFE_MonitoringGetCellSampler2),
]
@ -156,11 +156,11 @@ class CounterCell(object):
Args:
value: non-negative value.
"""
pywrap_tensorflow.TFE_MonitoringCounterCellIncrementBy(self._cell, value)
pywrap_tfe.TFE_MonitoringCounterCellIncrementBy(self._cell, value)
def value(self):
"""Retrieves the current value."""
return pywrap_tensorflow.TFE_MonitoringCounterCellValue(self._cell)
return pywrap_tfe.TFE_MonitoringCounterCellValue(self._cell)
class Counter(Metric):
@ -204,11 +204,11 @@ class IntGaugeCell(object):
Args:
value: integer value.
"""
pywrap_tensorflow.TFE_MonitoringIntGaugeCellSet(self._cell, value)
pywrap_tfe.TFE_MonitoringIntGaugeCellSet(self._cell, value)
def value(self):
"""Retrieves the current value."""
return pywrap_tensorflow.TFE_MonitoringIntGaugeCellValue(self._cell)
return pywrap_tfe.TFE_MonitoringIntGaugeCellValue(self._cell)
class IntGauge(Metric):
@ -252,13 +252,13 @@ class StringGaugeCell(object):
Args:
value: string value.
"""
pywrap_tensorflow.TFE_MonitoringStringGaugeCellSet(self._cell, value)
pywrap_tfe.TFE_MonitoringStringGaugeCellSet(self._cell, value)
def value(self):
"""Retrieves the current value."""
with c_api_util.tf_buffer() as buffer_:
pywrap_tensorflow.TFE_MonitoringStringGaugeCellValue(self._cell, buffer_)
value = pywrap_tensorflow.TF_GetBuffer(buffer_).decode('utf-8')
pywrap_tfe.TFE_MonitoringStringGaugeCellValue(self._cell, buffer_)
value = pywrap_tfe.TF_GetBuffer(buffer_).decode('utf-8')
return value
@ -303,11 +303,11 @@ class BoolGaugeCell(object):
Args:
value: bool value.
"""
pywrap_tensorflow.TFE_MonitoringBoolGaugeCellSet(self._cell, value)
pywrap_tfe.TFE_MonitoringBoolGaugeCellSet(self._cell, value)
def value(self):
"""Retrieves the current value."""
return pywrap_tensorflow.TFE_MonitoringBoolGaugeCellValue(self._cell)
return pywrap_tfe.TFE_MonitoringBoolGaugeCellValue(self._cell)
class BoolGauge(Metric):
@ -351,7 +351,7 @@ class SamplerCell(object):
Args:
value: float value.
"""
pywrap_tensorflow.TFE_MonitoringSamplerCellAdd(self._cell, value)
pywrap_tfe.TFE_MonitoringSamplerCellAdd(self._cell, value)
def value(self):
"""Retrieves the current distribution of samples.
@ -360,8 +360,8 @@ class SamplerCell(object):
A HistogramProto describing the distribution of samples.
"""
with c_api_util.tf_buffer() as buffer_:
pywrap_tensorflow.TFE_MonitoringSamplerCellValue(self._cell, buffer_)
proto_data = pywrap_tensorflow.TF_GetBuffer(buffer_)
pywrap_tfe.TFE_MonitoringSamplerCellValue(self._cell, buffer_)
proto_data = pywrap_tfe.TF_GetBuffer(buffer_)
histogram_proto = summary_pb2.HistogramProto()
histogram_proto.ParseFromString(compat.as_bytes(proto_data))
return histogram_proto
@ -379,7 +379,7 @@ class Buckets(object):
self.buckets = buckets
def __del__(self):
pywrap_tensorflow.TFE_MonitoringDeleteBuckets(self.buckets)
pywrap_tfe.TFE_MonitoringDeleteBuckets(self.buckets)
class ExponentialBuckets(Buckets):
@ -399,8 +399,8 @@ class ExponentialBuckets(Buckets):
bucket_count: integer
"""
super(ExponentialBuckets, self).__init__(
pywrap_tensorflow.TFE_MonitoringNewExponentialBuckets(
scale, growth_factor, bucket_count))
pywrap_tfe.TFE_MonitoringNewExponentialBuckets(scale, growth_factor,
bucket_count))
class Sampler(Metric):

View File

@ -39,9 +39,9 @@ import os
import threading
from tensorflow.python import _pywrap_events_writer
from tensorflow.python import pywrap_tensorflow
from tensorflow.python import pywrap_tfe
from tensorflow.python.eager import context
from tensorflow.python.framework import c_api_util
from tensorflow.python.eager import eager_util as c_api_util
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import compat
@ -74,8 +74,8 @@ def start():
raise ProfilerAlreadyRunningError('Another profiler is running.')
if context.default_execution_mode == context.EAGER_MODE:
context.ensure_initialized()
_profiler = pywrap_tensorflow.TFE_NewProfiler()
if not pywrap_tensorflow.TFE_ProfilerIsOk(_profiler):
_profiler = pywrap_tfe.TFE_NewProfiler()
if not pywrap_tfe.TFE_ProfilerIsOk(_profiler):
logging.warning('Another profiler session is running which is probably '
'created by profiler server. Please avoid using profiler '
'server and profiler APIs at the same time.')
@ -100,11 +100,9 @@ def stop():
if context.default_execution_mode == context.EAGER_MODE:
context.context().executor.wait()
with c_api_util.tf_buffer() as buffer_:
pywrap_tensorflow.TFE_ProfilerSerializeToString(
_profiler,
buffer_)
result = pywrap_tensorflow.TF_GetBuffer(buffer_)
pywrap_tensorflow.TFE_DeleteProfiler(_profiler)
pywrap_tfe.TFE_ProfilerSerializeToString(_profiler, buffer_)
result = pywrap_tfe.TF_GetBuffer(buffer_)
pywrap_tfe.TFE_DeleteProfiler(_profiler)
_profiler = None
_run_num += 1
return result
@ -159,7 +157,7 @@ def start_profiler_server(port):
"""
if context.default_execution_mode == context.EAGER_MODE:
context.ensure_initialized()
pywrap_tensorflow.TFE_StartProfilerServer(port)
pywrap_tfe.TFE_StartProfilerServer(port)
class Profiler(object):

View File

@ -18,8 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.framework import c_api_util
from tensorflow.python import pywrap_tfe
from tensorflow.python.eager import eager_util as c_api_util
from tensorflow.python.framework import errors
@ -46,7 +46,7 @@ def start_tracing(service_addr,
Raises:
UnavailableError: If no trace event is collected.
"""
if not pywrap_tensorflow.TFE_ProfilerClientStartTracing(
if not pywrap_tfe.TFE_ProfilerClientStartTracing(
service_addr, logdir, worker_list, include_dataset_ops, duration_ms,
num_tracing_attempts):
raise errors.UnavailableError(None, None, 'No trace event is collected.')
@ -71,7 +71,7 @@ def monitor(service_addr,
A string of monitoring output.
"""
with c_api_util.tf_buffer() as buffer_:
pywrap_tensorflow.TFE_ProfilerClientMonitor(service_addr, duration_ms,
monitoring_level,
display_timestamp, buffer_)
return pywrap_tensorflow.TF_GetBuffer(buffer_)
pywrap_tfe.TFE_ProfilerClientMonitor(service_addr, duration_ms,
monitoring_level, display_timestamp,
buffer_)
return pywrap_tfe.TF_GetBuffer(buffer_)

View File

@ -20,7 +20,7 @@ from __future__ import print_function
import numpy as np
from tensorflow.python import pywrap_tensorflow
from tensorflow.python import pywrap_tfe
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import core
@ -54,14 +54,16 @@ class Tests(test.TestCase):
self.assertAllClose(
math_ops.matmul(a_2_by_2, b_2_by_2),
pywrap_tensorflow.TFE_Py_FastPathExecute(
ctx._handle, ctx.device_name, "MatMul", None, None, a_2_by_2,
b_2_by_2, "transpose_a", False, "transpose_b", False))
pywrap_tfe.TFE_Py_FastPathExecute(ctx._handle, ctx.device_name,
"MatMul", None, None, a_2_by_2,
b_2_by_2, "transpose_a", False,
"transpose_b", False))
self.assertAllClose(
math_ops.matmul(a_100_by_784, b_100_by_784, transpose_b=True),
pywrap_tensorflow.TFE_Py_FastPathExecute(
ctx._handle, ctx.device_name, "MatMul", None, None, a_100_by_784,
b_100_by_784, "transpose_a", False, "transpose_b", True))
pywrap_tfe.TFE_Py_FastPathExecute(ctx._handle, ctx.device_name,
"MatMul", None, None, a_100_by_784,
b_100_by_784, "transpose_a", False,
"transpose_b", True))
@test_util.assert_no_new_tensors
@test_util.assert_no_garbage_created
@ -71,12 +73,14 @@ class Tests(test.TestCase):
a_2_by_2 = constant_op.constant(1.0, shape=[2, 2])
m = resource_variable_ops.ResourceVariable(a_2_by_2)
x = pywrap_tensorflow.TFE_Py_FastPathExecute(
ctx._handle, ctx.device_name, "MatMul", None, None, m, m, "transpose_a",
False, "transpose_b", False)
y = pywrap_tensorflow.TFE_Py_FastPathExecute(
ctx._handle, ctx.device_name, "MatMul", None, None, a_2_by_2, a_2_by_2,
"transpose_a", False, "transpose_b", False)
x = pywrap_tfe.TFE_Py_FastPathExecute(ctx._handle, ctx.device_name,
"MatMul", None, None, m, m,
"transpose_a", False, "transpose_b",
False)
y = pywrap_tfe.TFE_Py_FastPathExecute(ctx._handle, ctx.device_name,
"MatMul", None, None, a_2_by_2,
a_2_by_2, "transpose_a", False,
"transpose_b", False)
self.assertAllEqual(x, y)
@ -89,9 +93,10 @@ class Tests(test.TestCase):
with backprop.GradientTape(persistent=True) as tape:
a_2_by_2 = constant_op.constant(1.0, shape=[2, 2])
tape.watch(a_2_by_2)
z = pywrap_tensorflow.TFE_Py_FastPathExecute(
ctx._handle, ctx.device_name, "MatMul", None, None, a_2_by_2,
a_2_by_2, "transpose_a", False, "transpose_b", False)
z = pywrap_tfe.TFE_Py_FastPathExecute(ctx._handle, ctx.device_name,
"MatMul", None, None, a_2_by_2,
a_2_by_2, "transpose_a", False,
"transpose_b", False)
dz_dy = tape.gradient(z, [a_2_by_2])[0]
self.assertAllEqual(dz_dy.numpy(),
constant_op.constant(4.0, shape=[2, 2]).numpy())
@ -106,9 +111,10 @@ class Tests(test.TestCase):
a_2_by_2 = constant_op.constant(1.0, shape=[2, 2])
m = resource_variable_ops.ResourceVariable(a_2_by_2)
tape.watch(m)
z = pywrap_tensorflow.TFE_Py_FastPathExecute(
ctx._handle, ctx.device_name, "MatMul", None, None, m, m,
"transpose_a", False, "transpose_b", False)
z = pywrap_tfe.TFE_Py_FastPathExecute(ctx._handle, ctx.device_name,
"MatMul", None, None, m, m,
"transpose_a", False, "transpose_b",
False)
dz_dy = tape.gradient(z, [m])[0]
self.assertAllEqual(dz_dy.numpy(),
constant_op.constant(4.0, shape=[2, 2]).numpy())
@ -125,9 +131,8 @@ class Tests(test.TestCase):
self.assertAllClose(
math_ops.add_n([a_2_by_2, b_2_by_2]),
pywrap_tensorflow.TFE_Py_FastPathExecute(ctx._handle, ctx.device_name,
"AddN", None, None,
[a_2_by_2, b_2_by_2]))
pywrap_tfe.TFE_Py_FastPathExecute(ctx._handle, ctx.device_name, "AddN",
None, None, [a_2_by_2, b_2_by_2]))
# Tests homogeneous list op
@test_util.assert_no_new_tensors
@ -142,9 +147,9 @@ class Tests(test.TestCase):
with backprop.GradientTape(persistent=True) as tape:
tape.watch(a_2_by_2)
tape.watch(b_2_by_2)
z1 = pywrap_tensorflow.TFE_Py_FastPathExecute(
ctx._handle, ctx.device_name, "AddN", None, None,
[a_2_by_2, b_2_by_2])
z1 = pywrap_tfe.TFE_Py_FastPathExecute(ctx._handle, ctx.device_name,
"AddN", None, None,
[a_2_by_2, b_2_by_2])
z2 = math_ops.add_n([a_2_by_2, b_2_by_2])
dz1_dy = tape.gradient(z1, [a_2_by_2])[0]
dz2_dy = tape.gradient(z2, [a_2_by_2])[0]
@ -162,9 +167,9 @@ class Tests(test.TestCase):
self.assertAllClose(
array_ops.identity_n([a_2_by_2, b_2_by_2]),
pywrap_tensorflow.TFE_Py_FastPathExecute(ctx._handle, ctx.device_name,
"IdentityN", None, None,
[a_2_by_2, b_2_by_2]))
pywrap_tfe.TFE_Py_FastPathExecute(ctx._handle, ctx.device_name,
"IdentityN", None, None,
[a_2_by_2, b_2_by_2]))
# Tests heterogeneous list op
@test_util.assert_no_new_tensors
@ -179,9 +184,9 @@ class Tests(test.TestCase):
with backprop.GradientTape(persistent=True) as tape:
tape.watch(a_2_by_2)
tape.watch(b_2_by_2)
z1 = pywrap_tensorflow.TFE_Py_FastPathExecute(
ctx._handle, ctx.device_name, "IdentityN", None, None,
[a_2_by_2, b_2_by_2])
z1 = pywrap_tfe.TFE_Py_FastPathExecute(ctx._handle, ctx.device_name,
"IdentityN", None, None,
[a_2_by_2, b_2_by_2])
z2 = array_ops.identity_n([a_2_by_2, b_2_by_2])
dz1_dy = tape.gradient(z1[0], [a_2_by_2])[0]
dz2_dy = tape.gradient(z2[0], [a_2_by_2])[0]
@ -201,19 +206,18 @@ class Tests(test.TestCase):
# Not enough base params
with self.assertRaisesRegexp(ValueError,
"at least 5 items in the input tuple"):
pywrap_tensorflow.TFE_Py_FastPathExecute(ctx_handle, ctx.device_name,
"Identity")
pywrap_tfe.TFE_Py_FastPathExecute(ctx_handle, ctx.device_name, "Identity")
# Not enough inputs
with self.assertRaisesRegexp(ValueError,
"Expected to be at least 6, was 5"):
pywrap_tensorflow.TFE_Py_FastPathExecute(ctx_handle, ctx_handle,
"Identity", None, [])
pywrap_tfe.TFE_Py_FastPathExecute(ctx_handle, ctx_handle, "Identity",
None, [])
# Bad type
with self.assertRaisesRegexp(TypeError, "expected a string for op_name"):
pywrap_tensorflow.TFE_Py_FastPathExecute(ctx_handle, ctx.device_name,
ctx_handle, None, [], a_2_by_2)
pywrap_tfe.TFE_Py_FastPathExecute(ctx_handle, ctx.device_name, ctx_handle,
None, [], a_2_by_2)
@test_util.assert_no_new_tensors
@test_util.assert_no_garbage_created
@ -225,9 +229,9 @@ class Tests(test.TestCase):
ctx_handle = ctx._handle
with self.assertRaises(core._FallbackException):
pywrap_tensorflow.TFE_Py_FastPathExecute(ctx_handle, ctx.device_name,
"Split", None, None, split_dim,
value, "num_split", -1)
pywrap_tfe.TFE_Py_FastPathExecute(ctx_handle, ctx.device_name, "Split",
None, None, split_dim, value,
"num_split", -1)
@test_util.assert_no_new_tensors
@test_util.assert_no_garbage_created
@ -266,10 +270,9 @@ class Tests(test.TestCase):
ctx = context.context()
ctx.ensure_initialized()
with self.assertRaises(core._FallbackException):
pywrap_tensorflow.TFE_Py_FastPathExecute(ctx._handle, ctx.device_name,
"MatMul", None, None, m, m,
"transpose_a", False,
"transpose_b", False)
pywrap_tfe.TFE_Py_FastPathExecute(ctx._handle, ctx.device_name, "MatMul",
None, None, m, m, "transpose_a", False,
"transpose_b", False)
def testOpDefDefaultType(self):
im = np.random.randint(

View File

@ -22,7 +22,7 @@ import copy
from absl import logging
from tensorflow.core.protobuf.tensorflow_server_pb2 import ServerDef
from tensorflow.python import pywrap_tensorflow
from tensorflow.python import pywrap_tfe
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute.cluster_resolver import cluster_resolver
from tensorflow.python.eager import context
@ -127,7 +127,7 @@ def connect_to_cluster(cluster_spec_or_resolver,
# Automatically add local job, if not part of the cluster spec.
if job_name not in cluster_spec.jobs:
local_port = pywrap_tensorflow.TF_PickUnusedPortOrDie()
local_port = pywrap_tfe.TF_PickUnusedPortOrDie()
job_def = cluster_def.job.add()
job_def.name = job_name
# TODO(fishx): Update this to make sure remote worker has valid ip address

View File

@ -20,7 +20,7 @@ from __future__ import print_function
import contextlib
from tensorflow.python import pywrap_tensorflow
from tensorflow.python import pywrap_tfe
from tensorflow.python.util.lazy_loader import LazyLoader
# There is a circular dependency between this, ops.py, and
@ -39,24 +39,23 @@ class Tape(object):
self._tape = tape
def watched_variables(self):
return pywrap_tensorflow.TFE_Py_TapeWatchedVariables(self._tape)
return pywrap_tfe.TFE_Py_TapeWatchedVariables(self._tape)
def push_new_tape(persistent=False, watch_accessed_variables=True):
"""Pushes a new tape onto the tape stack."""
tape = pywrap_tensorflow.TFE_Py_TapeSetNew(persistent,
watch_accessed_variables)
tape = pywrap_tfe.TFE_Py_TapeSetNew(persistent, watch_accessed_variables)
return Tape(tape)
def push_tape(tape):
"""Pushes an existing tape onto the tape stack."""
pywrap_tensorflow.TFE_Py_TapeSetAdd(tape._tape) # pylint: disable=protected-access
pywrap_tfe.TFE_Py_TapeSetAdd(tape._tape) # pylint: disable=protected-access
def watch(tape, tensor):
"""Marks this tensor to be watched by the given tape."""
pywrap_tensorflow.TFE_Py_TapeWatch(tape._tape, tensor) # pylint: disable=protected-access
pywrap_tfe.TFE_Py_TapeWatch(tape._tape, tensor) # pylint: disable=protected-access
def watch_variable(tape, variable):
@ -68,7 +67,7 @@ def watch_variable(tape, variable):
else:
variables = strategy.experimental_local_results(variable)
for var in variables:
pywrap_tensorflow.TFE_Py_TapeWatchVariable(tape._tape, var) # pylint: disable=protected-access
pywrap_tfe.TFE_Py_TapeWatchVariable(tape._tape, var) # pylint: disable=protected-access
def variable_accessed(variable):
@ -84,7 +83,7 @@ def variable_accessed(variable):
else:
variables = strategy.experimental_local_results(variable)
for var in variables:
pywrap_tensorflow.TFE_Py_TapeVariableAccessed(var)
pywrap_tfe.TFE_Py_TapeVariableAccessed(var)
def variables_accessed(variables):
@ -107,25 +106,25 @@ def variables_accessed(variables):
accessed.extend(strategy.experimental_local_results(variable))
for var in accessed:
pywrap_tensorflow.TFE_Py_TapeVariableAccessed(var)
pywrap_tfe.TFE_Py_TapeVariableAccessed(var)
def pop_tape(tape):
"""Pops the given tape in the stack."""
pywrap_tensorflow.TFE_Py_TapeSetRemove(tape._tape) # pylint: disable=protected-access
pywrap_tfe.TFE_Py_TapeSetRemove(tape._tape) # pylint: disable=protected-access
@contextlib.contextmanager
def stop_recording():
"""Stop all gradient recording (backprop and forwardprop)."""
is_stopped = pywrap_tensorflow.TFE_Py_TapeSetIsStopped()
is_stopped = pywrap_tfe.TFE_Py_TapeSetIsStopped()
try:
if not is_stopped:
pywrap_tensorflow.TFE_Py_TapeSetStopOnThread()
pywrap_tfe.TFE_Py_TapeSetStopOnThread()
yield
finally:
if not is_stopped:
pywrap_tensorflow.TFE_Py_TapeSetRestartOnThread()
pywrap_tfe.TFE_Py_TapeSetRestartOnThread()
def should_record_backprop(tensors):
@ -139,22 +138,23 @@ def should_record_backprop(tensors):
Returns:
Boolean, whether any tape watches any of `tensors`.
"""
return pywrap_tensorflow.TFE_Py_TapeSetShouldRecordBackprop(tensors)
return pywrap_tfe.TFE_Py_TapeSetShouldRecordBackprop(tensors)
def record_operation(op_type, output_tensors, input_tensors, backward_function,
forward_function=None):
"""Records the operation on all tapes in the stack."""
pywrap_tensorflow.TFE_Py_TapeSetRecordOperation(
op_type, output_tensors, input_tensors, backward_function,
forward_function)
pywrap_tfe.TFE_Py_TapeSetRecordOperation(op_type, output_tensors,
input_tensors, backward_function,
forward_function)
def record_operation_backprop_only(op_type, output_tensors, input_tensors,
backward_function):
"""Records the operation on all backward tapes in the stack."""
pywrap_tensorflow.TFE_Py_TapeSetRecordOperationBackprop(
op_type, output_tensors, input_tensors, backward_function)
pywrap_tfe.TFE_Py_TapeSetRecordOperationBackprop(op_type, output_tensors,
input_tensors,
backward_function)
def record_operation_forwardprop_only(op_type, output_tensors, input_tensors,
@ -174,16 +174,16 @@ def record_operation_forwardprop_only(op_type, output_tensors, input_tensors,
Typically these will have come from TFE_Py_PackForwardGradients. May be
None or an empty sequence if there are no JVP outputs from the operation.
"""
pywrap_tensorflow.TFE_Py_TapeSetRecordOperationForwardprop(
pywrap_tfe.TFE_Py_TapeSetRecordOperationForwardprop(
op_type, output_tensors, input_tensors, backward_function,
forwardprop_output_indices)
def delete_trace(tensor_id):
"""Deletes traces for this Tensor from all tapes in the stack."""
pywrap_tensorflow.TFE_Py_TapeSetDeleteTrace(tensor_id)
pywrap_tfe.TFE_Py_TapeSetDeleteTrace(tensor_id)
def could_possibly_record():
"""Returns True if any tape is active."""
return not pywrap_tensorflow.TFE_Py_TapeSetIsEmpty()
return not pywrap_tfe.TFE_Py_TapeSetIsEmpty()

View File

@ -26,7 +26,7 @@ import unittest
import numpy as np
import six
from tensorflow.python import pywrap_tensorflow
from tensorflow.python import pywrap_tfe
from tensorflow.python.eager import context
from tensorflow.python.eager import core
from tensorflow.python.eager import test
@ -435,14 +435,14 @@ class TFETensorUtilTest(test_util.TensorFlowTestCase):
t2 = _create_tensor([[1, 2, 5], [3, 4, 5]], dtype=dtypes.int32)
t3 = _create_tensor([[1], [3], [5], [6]], dtype=dtypes.int32)
r = pywrap_tensorflow.TFE_Py_TensorShapeSlice([t1, t2, t3], 0)
r = pywrap_tfe.TFE_Py_TensorShapeSlice([t1, t2, t3], 0)
self.assertAllEqual(np.array([3, 2, 4]), r.numpy())
r = pywrap_tensorflow.TFE_Py_TensorShapeSlice([t1, t2, t3], 1)
r = pywrap_tfe.TFE_Py_TensorShapeSlice([t1, t2, t3], 1)
self.assertAllEqual(np.array([2, 3, 1]), r.numpy())
def testEmptyTensorList(self):
a = pywrap_tensorflow.TFE_Py_TensorShapeSlice([], 0)
a = pywrap_tfe.TFE_Py_TensorShapeSlice([], 0)
self.assertTrue(isinstance(a, ops.EagerTensor))
self.assertEqual(0, a.numpy().size)
@ -452,12 +452,12 @@ class TFETensorUtilTest(test_util.TensorFlowTestCase):
with self.assertRaisesRegexp(
TypeError,
r"Expected a list of EagerTensors but element 1 has type \"str\""):
pywrap_tensorflow.TFE_Py_TensorShapeSlice([t1, "abc"], 0)
pywrap_tfe.TFE_Py_TensorShapeSlice([t1, "abc"], 0)
with self.assertRaisesRegexp(
TypeError,
r"Expected a list of EagerTensors but element 0 has type \"int\""):
pywrap_tensorflow.TFE_Py_TensorShapeSlice([2, t1], 0)
pywrap_tfe.TFE_Py_TensorShapeSlice([2, t1], 0)
def testTensorListNotList(self):
t1 = _create_tensor([1, 2], dtype=dtypes.int32)
@ -465,7 +465,7 @@ class TFETensorUtilTest(test_util.TensorFlowTestCase):
with self.assertRaisesRegexp(
TypeError,
r"tensors argument must be a list or a tuple. Got.*EagerTensor"):
pywrap_tensorflow.TFE_Py_TensorShapeSlice(t1, -2)
pywrap_tfe.TFE_Py_TensorShapeSlice(t1, -2)
def testNegativeSliceDim(self):
t1 = _create_tensor([1, 2], dtype=dtypes.int32)
@ -473,7 +473,7 @@ class TFETensorUtilTest(test_util.TensorFlowTestCase):
with self.assertRaisesRegexp(
ValueError,
r"Slice dimension must be non-negative. Got -2"):
pywrap_tensorflow.TFE_Py_TensorShapeSlice([t1], -2)
pywrap_tfe.TFE_Py_TensorShapeSlice([t1], -2)
def testUnicode(self):
self.assertEqual(constant_op.constant(u"asdf").numpy(), b"asdf")
@ -493,31 +493,31 @@ class TFETensorUtilTest(test_util.TensorFlowTestCase):
IndexError,
r"Slice dimension \(2\) must be smaller than rank of all tensors, "
"but tensor at index 0 has rank 2"):
pywrap_tensorflow.TFE_Py_TensorShapeSlice([t1], 2)
pywrap_tfe.TFE_Py_TensorShapeSlice([t1], 2)
with self.assertRaisesRegexp(
IndexError,
r"Slice dimension \(1\) must be smaller than rank of all tensors, "
"but tensor at index 0 has rank 1"):
pywrap_tensorflow.TFE_Py_TensorShapeSlice([t2], 1)
pywrap_tfe.TFE_Py_TensorShapeSlice([t2], 1)
with self.assertRaisesRegexp(
IndexError,
r"Slice dimension \(1\) must be smaller than rank of all tensors, "
"but tensor at index 1 has rank 1"):
pywrap_tensorflow.TFE_Py_TensorShapeSlice([t1, t2], 1)
pywrap_tfe.TFE_Py_TensorShapeSlice([t1, t2], 1)
with self.assertRaisesRegexp(
IndexError,
r"Slice dimension \(0\) must be smaller than rank of all tensors, "
"but tensor at index 0 has rank 0"):
pywrap_tensorflow.TFE_Py_TensorShapeSlice([t3], 0)
pywrap_tfe.TFE_Py_TensorShapeSlice([t3], 0)
with self.assertRaisesRegexp(
IndexError,
r"Slice dimension \(0\) must be smaller than rank of all tensors, "
"but tensor at index 2 has rank 0"):
pywrap_tensorflow.TFE_Py_TensorShapeSlice([t2, t1, t3], 0)
pywrap_tfe.TFE_Py_TensorShapeSlice([t2, t1, t3], 0)
@test_util.assert_no_new_pyobjects_executing_eagerly
def testTensorDir(self):

View File

@ -36,7 +36,12 @@ from tensorflow.core.framework import node_def_pb2
from tensorflow.core.framework import op_def_pb2
from tensorflow.core.framework import versions_pb2
from tensorflow.core.protobuf import config_pb2
# pywrap_tensorflow must be imported first to avoid profobuf issues.
# (b/143110113)
# pylint: disable=invalid-import-order,g-bad-import-order
from tensorflow.python import pywrap_tensorflow as c_api
from tensorflow.python import pywrap_tfe as c_api_new
# pylint: enable=invalid-import-order,g-bad-import-order
from tensorflow.python import tf2
from tensorflow.python.eager import context
from tensorflow.python.eager import core
@ -249,7 +254,7 @@ def register_dense_tensor_like_type(tensor_type):
def uid():
"""A unique (within this program execution) integer."""
return c_api.TFE_Py_UID()
return c_api_new.TFE_Py_UID()
def numpy_text(tensor, is_repr=False):
@ -1135,7 +1140,7 @@ class _EagerTensorBase(Tensor):
# This call creates an EagerTensor class, as a subclass of _EagerTensorBase, and
# registers it with the current module.
EagerTensor = c_api.TFE_Py_InitEagerTensor(_EagerTensorBase)
EagerTensor = c_api_new.TFE_Py_InitEagerTensor(_EagerTensorBase)
register_dense_tensor_like_type(Tensor)

View File

@ -789,8 +789,7 @@ void GenEagerPythonOp::AddEagerFastPathExecute() {
strings::StrAppend(&result_, " try:\n");
strings::StrAppend(
&result_, " ",
"_result = _pywrap_tensorflow.TFE_Py_FastPathExecute(\n",
&result_, " ", "_result = pywrap_tfe.TFE_Py_FastPathExecute(\n",
WordWrap(strings::StrCat(" "),
strings::StrCat(fastpath_execute_params, ")"), kRightMargin),
"\n");
@ -1000,7 +999,7 @@ This file is MACHINE GENERATED! Do not edit.
import collections
from tensorflow.python import pywrap_tensorflow as _pywrap_tensorflow
from tensorflow.python import pywrap_tfe as pywrap_tfe
from tensorflow.python.eager import context as _context
from tensorflow.python.eager import core as _core
from tensorflow.python.eager import execute as _execute

View File

@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python import pywrap_tensorflow
from tensorflow.python import pywrap_tfe
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@ -115,8 +116,8 @@ def _ConcatGradHelper(op, grad, start_value_index, end_value_index, dim_index):
non_neg_concat_dim = (
concat_dim._numpy().item(0) % input_values[0]._rank()) # pylint: disable=protected-access
# All inputs are guaranteed to be EagerTensors in eager mode
sizes = pywrap_tensorflow.TFE_Py_TensorShapeSlice(input_values,
non_neg_concat_dim)
sizes = pywrap_tfe.TFE_Py_TensorShapeSlice(input_values,
non_neg_concat_dim)
out_grads = array_ops.split(grad, sizes, non_neg_concat_dim)
else:
if constant_op.is_constant(concat_dim):

View File

@ -26,7 +26,7 @@ import sys
from absl import logging
import six
from tensorflow.python import pywrap_tensorflow
from tensorflow.python import pywrap_tfe
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
@ -45,7 +45,7 @@ from tensorflow.python.util.tf_export import tf_export
# Register printing to the cell output if we are in a Colab or Jupyter Notebook.
try:
get_ipython() # Exists in an ipython env like Jupyter or Colab
pywrap_tensorflow.TFE_Py_EnableInteractivePythonLogging()
pywrap_tfe.TFE_Py_EnableInteractivePythonLogging()
except NameError:
pass

View File

@ -20,6 +20,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/c/tf_status.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/python/lib/core/py_exception_registry.h"
using tensorflow::uint64;
@ -233,7 +234,50 @@ _COPY_TYPEMAPS(unsigned int, mode_t);
%define override %enddef
#endif
// This was originally included in pywrap_tfe.i, but is used by tf_session.i
%include "tensorflow/c/tf_status.h"
%include "tensorflow/c/tf_datatype.h"
%typemap(in) (const void* proto) {
char* c_string;
Py_ssize_t py_size;
// PyBytes_AsStringAndSize() does not copy but simply interprets the input
if (PyBytes_AsStringAndSize($input, &c_string, &py_size) == -1) {
// Python has raised an error (likely TypeError or UnicodeEncodeError).
SWIG_fail;
}
$1 = static_cast<void*>(c_string);
}
%typemap(in) int64_t {
$1 = PyLong_AsLongLong($input);
}
%typemap(out) TF_DataType {
$result = PyInt_FromLong($1);
}
%typemap(out) int64_t {
$result = PyInt_FromLong($1);
}
%typemap(out) TF_AttrType {
$result = PyInt_FromLong($1);
}
%typemap(in, numinputs=0) unsigned char* is_list (unsigned char tmp) {
tmp = 0;
$1 = &tmp;
}
%typemap(argout) unsigned char* is_list {
if (*$1 == 1) {
PyObject* list = PyList_New(1);
PyList_SetItem(list, 0, $result);
$result = list;
}
}
// Typemaps to automatically raise a Python exception from bad output TF_Status.
// TODO(b/77295559): expand this to all TF_Status* output params and deprecate

View File

@ -1,515 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
%include "tensorflow/python/lib/core/strings.i"
%include "tensorflow/python/platform/base.i"
%{
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/python/lib/core/ndarray_tensor.h"
#include "tensorflow/python/lib/core/safe_ptr.h"
%}
%include "tensorflow/c/tf_datatype.h"
%include "tensorflow/c/tf_status.h"
%ignoreall;
%rename("%s") TF_SetXlaEnableLazyCompilation;
%rename("%s") TF_SetTfXlaCpuGlobalJit;
%rename("%s") TF_SetXlaAutoJitMode;
%rename("%s") TF_SetXlaConstantFoldingDisabled;
%rename("%s") TF_GetXlaConstantFoldingDisabled;
%rename("%s") TF_SetXlaMinClusterSize;
%rename("%s") TFE_NewContext;
%rename("%s") TFE_DeleteContext;
%rename("%s") TFE_ContextListDevices;
%rename("%s") TFE_ContextAddFunction;
%rename("%s") TFE_ContextAddFunctionDef;
%rename("%s") TFE_ContextRemoveFunction;
%rename("%s") TFE_ContextHasFunction;
%rename("%s") TFE_ContextEnableRunMetadata;
%rename("%s") TFE_ContextDisableRunMetadata;
%rename("%s") TFE_ContextEnableGraphCollection;
%rename("%s") TFE_ContextDisableGraphCollection;
%rename("%s") TFE_ContextExportRunMetadata;
%rename("%s") TFE_ContextClearCaches;
%rename("%s") TFE_ContextGetDevicePlacementPolicy;
%rename("%s") TFE_ContextGetMirroringPolicy;
%rename("%s") TFE_ContextSetThreadLocalDevicePlacementPolicy;
%rename("%s") TFE_ContextSetThreadLocalMirroringPolicy;
%rename("%s") TFE_ContextSetServerDef;
%rename("%s") TFE_ContextUpdateServerDef;
%rename("%s") TFE_ContextCheckAlive;
%rename("%s") TFE_NewExecutor;
%rename("%s") TFE_DeleteExecutor;
%rename("%s") TFE_ExecutorIsAsync;
%rename("%s") TFE_ExecutorWaitForAllPendingNodes;
%rename("%s") TFE_ExecutorClearError;
%rename("%s") TFE_ContextSetExecutorForThread;
%rename("%s") TFE_ContextGetExecutorForThread;
%rename("%s") TFE_NewProfiler;
%rename("%s") TFE_ProfilerIsOk;
%rename("%s") TFE_DeleteProfiler;
%rename("%s") TFE_ProfilerSerializeToString;
%rename("%s") TFE_StartProfilerServer;
%rename("%s") TFE_ProfilerClientStartTracing;
%rename("%s") TFE_ProfilerClientMonitor;
%rename("%s") TFE_OpNameGetAttrType;
%rename("%s") TFE_Py_InitEagerTensor;
%rename("%s") TFE_Py_SetEagerTensorProfiler;
%rename("%s") TFE_Py_RegisterExceptionClass;
%rename("%s") TFE_Py_RegisterJVPFunction;
%rename("%s") TFE_Py_RegisterGradientFunction;
%rename("%s") TFE_Py_RegisterFallbackExceptionClass;
%rename("%s") TFE_Py_Execute;
%rename("%s") TFE_Py_ExecuteCancelable;
%rename("%s") TFE_Py_FastPathExecute;
%rename("%s") TFE_Py_RecordGradient;
%rename("%s") TFE_Py_UID;
%rename("%s") TFE_Py_TapeSetNew;
%rename("%s") TFE_Py_TapeSetAdd;
%rename("%s") TFE_Py_TapeSetRemove;
%rename("%s") TFE_Py_TapeSetStopOnThread;
%rename("%s") TFE_Py_TapeSetRestartOnThread;
%rename("%s") TFE_Py_TapeSetIsStopped;
%rename("%s") TFE_Py_TapeSetIsEmpty;
%rename("%s") TFE_Py_TapeSetShouldRecordBackprop;
%rename("%s") TFE_Py_TapeSetPossibleGradientTypes;
%rename("%s") TFE_Py_TapeSetDeleteTrace;
%rename("%s") TFE_Py_TapeSetRecordOperation;
%rename("%s") TFE_Py_TapeSetRecordOperationBackprop;
%rename("%s") TFE_Py_TapeSetRecordOperationForwardprop;
%rename("%s") TFE_Py_TapeGradient;
%rename("%s") TFE_Py_TapeVariableAccessed;
%rename("%s") TFE_Py_TapeWatch;
%rename("%s") TFE_Py_TapeWatchVariable;
%rename("%s") TFE_Py_TapeWatchedVariables;
%rename("%s") TFE_Py_ForwardAccumulatorNew;
%rename("%s") TFE_Py_ForwardAccumulatorSetAdd;
%rename("%s") TFE_Py_ForwardAccumulatorSetRemove;
%rename("%s") TFE_Py_ForwardAccumulatorWatch;
%rename("%s") TFE_Py_ForwardAccumulatorJVP;
%rename("%s") TFE_Py_ForwardAccumulatorPushState;
%rename("%s") TFE_Py_ForwardAccumulatorPopState;
%rename("%s") TFE_Py_PackJVPs;
%rename("%s") TFE_NewContextOptions;
%rename("%s") TFE_ContextOptionsSetConfig;
%rename("%s") TFE_ContextOptionsSetDevicePlacementPolicy;
%rename("%s") TFE_ContextOptionsSetMirroringPolicy;
%rename("%s") TFE_ContextOptionsSetAsync;
%rename("%s") TFE_ContextOptionsSetLazyRemoteInputsCopy;
%rename("%s") TFE_DeleteContextOptions;
%rename("%s") TFE_Py_TensorShapeSlice;
%rename("%s") TFE_Py_TensorShapeOnDevice;
%rename("%s") TFE_Py_EnableInteractivePythonLogging;
%rename("%s") TFE_Py_SetEagerContext;
%rename("%s") TFE_ContextStartStep;
%rename("%s") TFE_ContextEndStep;
%rename("%s") TFE_Py_RegisterVSpace;
%rename("%s") TFE_Py_EncodeArg;
%rename("%s") TFE_EnableCollectiveOps;
%rename("%s") TF_ListPhysicalDevices;
%rename("%s") TF_PickUnusedPortOrDie;
%rename("%s") TFE_MonitoringCounterCellIncrementBy;
%rename("%s") TFE_MonitoringCounterCellValue;
%rename("%s") TFE_MonitoringNewCounter0;
%rename("%s") TFE_MonitoringDeleteCounter0;
%rename("%s") TFE_MonitoringGetCellCounter0;
%rename("%s") TFE_MonitoringNewCounter1;
%rename("%s") TFE_MonitoringDeleteCounter1;
%rename("%s") TFE_MonitoringGetCellCounter1;
%rename("%s") TFE_MonitoringNewCounter2;
%rename("%s") TFE_MonitoringDeleteCounter2;
%rename("%s") TFE_MonitoringGetCellCounter2;
%rename("%s") TFE_MonitoringIntGaugeCellSet;
%rename("%s") TFE_MonitoringIntGaugeCellValue;
%rename("%s") TFE_MonitoringNewIntGauge0;
%rename("%s") TFE_MonitoringDeleteIntGauge0;
%rename("%s") TFE_MonitoringGetCellIntGauge0;
%rename("%s") TFE_MonitoringNewIntGauge1;
%rename("%s") TFE_MonitoringDeleteIntGauge1;
%rename("%s") TFE_MonitoringGetCellIntGauge1;
%rename("%s") TFE_MonitoringNewIntGauge2;
%rename("%s") TFE_MonitoringDeleteIntGauge2;
%rename("%s") TFE_MonitoringGetCellIntGauge2;
%rename("%s") TFE_MonitoringStringGaugeCellSet;
%rename("%s") TFE_MonitoringStringGaugeCellValue;
%rename("%s") TFE_MonitoringNewStringGauge0;
%rename("%s") TFE_MonitoringDeleteStringGauge0;
%rename("%s") TFE_MonitoringGetCellStringGauge0;
%rename("%s") TFE_MonitoringNewStringGauge1;
%rename("%s") TFE_MonitoringDeleteStringGauge1;
%rename("%s") TFE_MonitoringGetCellStringGauge1;
%rename("%s") TFE_MonitoringNewStringGauge2;
%rename("%s") TFE_MonitoringDeleteStringGauge2;
%rename("%s") TFE_MonitoringGetCellStringGauge2;
%rename("%s") TFE_MonitoringBoolGaugeCellSet;
%rename("%s") TFE_MonitoringBoolGaugeCellValue;
%rename("%s") TFE_MonitoringNewBoolGauge0;
%rename("%s") TFE_MonitoringDeleteBoolGauge0;
%rename("%s") TFE_MonitoringGetCellBoolGauge0;
%rename("%s") TFE_MonitoringNewBoolGauge1;
%rename("%s") TFE_MonitoringDeleteBoolGauge1;
%rename("%s") TFE_MonitoringGetCellBoolGauge1;
%rename("%s") TFE_MonitoringNewBoolGauge2;
%rename("%s") TFE_MonitoringDeleteBoolGauge2;
%rename("%s") TFE_MonitoringGetCellBoolGauge2;
%rename("%s") TFE_MonitoringSamplerCellAdd;
%rename("%s") TFE_MonitoringSamplerCellValue;
%rename("%s") TFE_MonitoringNewExponentialBuckets;
%rename("%s") TFE_MonitoringDeleteBuckets;
%rename("%s") TFE_MonitoringNewSampler0;
%rename("%s") TFE_MonitoringDeleteSampler0;
%rename("%s") TFE_MonitoringGetCellSampler0;
%rename("%s") TFE_MonitoringNewSampler1;
%rename("%s") TFE_MonitoringDeleteSampler1;
%rename("%s") TFE_MonitoringGetCellSampler1;
%rename("%s") TFE_MonitoringNewSampler2;
%rename("%s") TFE_MonitoringDeleteSampler2;
%rename("%s") TFE_MonitoringGetCellSampler2;
%rename("%s") TFE_NewCancellationManager;
%rename("%s") TFE_CancellationManagerIsCancelled;
%rename("%s") TFE_CancellationManagerStartCancel;
%rename("%s") TFE_DeleteCancellationManager;
%rename("%s") TF_ImportGraphDefOptionsSetValidateColocationConstraints;
%rename("%s") TFE_ClearScalarCache;
%{
#include "tensorflow/python/eager/pywrap_tfe.h"
#include "tensorflow/python/util/util.h"
#include "tensorflow/c/c_api_experimental.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/core/common_runtime/device_factory.h"
static PyObject* TF_ListPhysicalDevices(TF_Status* status) {
std::vector<string> devices;
tensorflow::Status s = tensorflow::DeviceFactory::ListAllPhysicalDevices(&devices);
tensorflow::Set_TF_Status_from_Status(status, s);
if (!s.ok()) {
Py_RETURN_NONE;
};
PyObject* result = PyList_New(devices.size());
int i = 0;
for (auto& dev : devices) {
PyObject* dev_obj = PyBytes_FromStringAndSize(dev.data(), dev.size());
PyList_SetItem(result, i, dev_obj);
++i;
}
return result;
}
%}
static PyObject* TF_ListPhysicalDevices(TF_Status* status);
%{
#include "tensorflow/python/eager/pywrap_tensor_conversion.h"
static PyObject* TFE_ClearScalarCache() {
tensorflow::TFE_TensorHandleCache::Get()->Clear();
Py_RETURN_NONE;
}
%}
static PyObject* TFE_ClearScalarCache();
%typemap(in) (const void* proto) {
char* c_string;
Py_ssize_t py_size;
// PyBytes_AsStringAndSize() does not copy but simply interprets the input
if (PyBytes_AsStringAndSize($input, &c_string, &py_size) == -1) {
// Python has raised an error (likely TypeError or UnicodeEncodeError).
SWIG_fail;
}
$1 = static_cast<void*>(c_string);
}
%typemap(in) int64_t {
$1 = PyLong_AsLongLong($input);
}
%typemap(out) TF_DataType {
$result = PyInt_FromLong($1);
}
%typemap(out) int64_t {
$result = PyInt_FromLong($1);
}
%typemap(out) TF_AttrType {
$result = PyInt_FromLong($1);
}
%typemap(in, numinputs=0) unsigned char* is_list (unsigned char tmp) {
tmp = 0;
$1 = &tmp;
}
%typemap(argout) unsigned char* is_list {
if (*$1 == 1) {
PyObject* list = PyList_New(1);
PyList_SetItem(list, 0, $result);
$result = list;
}
}
// For const parameters in a function, SWIG pretty much ignores the const.
// See: http://www.swig.org/Doc2.0/SWIG.html#SWIG_nn13
// Hence the 'const_cast'.
%typemap(in) const char* serialized_function_def {
$1 = const_cast<char*>(TFE_GetPythonString($input));
}
// For const parameters in a function, SWIG pretty much ignores the const.
// See: http://www.swig.org/Doc2.0/SWIG.html#SWIG_nn13
// Hence the 'const_cast'.
%typemap(in) const char* device_name {
if ($input == Py_None) {
$1 = nullptr;
} else {
$1 = const_cast<char*>(TFE_GetPythonString($input));
}
}
// For const parameters in a function, SWIG pretty much ignores the const.
// See: http://www.swig.org/Doc2.0/SWIG.html#SWIG_nn13
// Hence the 'const_cast'.
%typemap(in) const char* op_name {
$1 = const_cast<char*>(TFE_GetPythonString($input));
}
// For const parameters in a function, SWIG pretty much ignores the const.
// See: http://www.swig.org/Doc2.0/SWIG.html#SWIG_nn13
// Hence the 'const_cast'.
%typemap(in) const char* name {
$1 = const_cast<char*>(TFE_GetPythonString($input));
}
// For const parameters in a function, SWIG pretty much ignores the const.
// See: http://www.swig.org/Doc2.0/SWIG.html#SWIG_nn13
// Hence the 'const_cast'.
%typemap(in) const char* description {
$1 = const_cast<char*>(TFE_GetPythonString($input));
}
// For const parameters in a function, SWIG pretty much ignores the const.
// See: http://www.swig.org/Doc2.0/SWIG.html#SWIG_nn13
// Hence the 'const_cast'.
%typemap(in) const char* label {
$1 = const_cast<char*>(TFE_GetPythonString($input));
}
// For const parameters in a function, SWIG pretty much ignores the const.
// See: http://www.swig.org/Doc2.0/SWIG.html#SWIG_nn13
// Hence the 'const_cast'.
%typemap(in) const char* label1 {
$1 = const_cast<char*>(TFE_GetPythonString($input));
}
// For const parameters in a function, SWIG pretty much ignores the const.
// See: http://www.swig.org/Doc2.0/SWIG.html#SWIG_nn13
// Hence the 'const_cast'.
%typemap(in) const char* label2 {
$1 = const_cast<char*>(TFE_GetPythonString($input));
}
// For const parameters in a function, SWIG pretty much ignores the const.
// See: http://www.swig.org/Doc2.0/SWIG.html#SWIG_nn13
// Hence the 'const_cast'.
%typemap(in) const char* service_addr {
$1 = const_cast<char*>(TFE_GetPythonString($input));
}
// For const parameters in a function, SWIG pretty much ignores the const.
// See: http://www.swig.org/Doc2.0/SWIG.html#SWIG_nn13
// Hence the 'const_cast'.
%typemap(in) const char* logdir {
$1 = const_cast<char*>(TFE_GetPythonString($input));
}
// For const parameters in a function, SWIG pretty much ignores the const.
// See: http://www.swig.org/Doc2.0/SWIG.html#SWIG_nn13
// Hence the 'const_cast'.
%typemap(in) const char* worker_list {
$1 = const_cast<char*>(TFE_GetPythonString($input));
}
%typemap(in) (TFE_Context*) {
$1 = (TFE_Context*)PyCapsule_GetPointer($input, nullptr);
}
%typemap(out) (TFE_Context*) {
// When the TFE_Context* returned is a nullptr, we expect the status is not
// OK. This will raise an error (happens in another typemap).
if ($1 != nullptr) {
$result = PyCapsule_New($1, nullptr, TFE_DeleteContextCapsule);
}
}
%rename("%s") TFE_ContextDevicePlacementPolicy;
%rename("%s") TFE_DEVICE_PLACEMENT_EXPLICIT;
%rename("%s") TFE_DEVICE_PLACEMENT_WARN;
%rename("%s") TFE_DEVICE_PLACEMENT_SILENT;
%rename("%s") TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32;
%rename("%s") TFE_ContextMirroringPolicy;
%rename("%s") TFE_MIRRORING_NONE;
%rename("%s") TFE_MIRRORING_ALL;
%include "tensorflow/c/eager/c_api.h"
%typemap(in) TFE_InputTensorHandles* inputs (TFE_InputTensorHandles temp) {
$1 = &temp;
if ($input != Py_None) {
if (!PyList_Check($input)) {
SWIG_exception_fail(SWIG_TypeError,
"must provide a list of Tensors as inputs");
}
Py_ssize_t len = PyList_Size($input);
$1->resize(len);
for (Py_ssize_t i = 0; i < len; ++i) {
PyObject* elem = PyList_GetItem($input, i);
if (!elem) {
SWIG_fail;
}
if (EagerTensor_CheckExact(elem)) {
(*$1)[i] = EagerTensor_Handle(elem);
} else if (tensorflow::swig::IsEagerTensorSlow(elem)) {
// Use equivalent of object.__getattribute__ to get the underlying
// tf wrapped EagerTensor (if there is one).
tensorflow::Safe_PyObjectPtr tf_should_use_attr(
#if PY_MAJOR_VERSION < 3
PyString_InternFromString("_tf_should_use_wrapped_value")
#else
PyUnicode_InternFromString("_tf_should_use_wrapped_value")
#endif
);
tensorflow::Safe_PyObjectPtr value_attr(
PyObject_GenericGetAttr(elem, tf_should_use_attr.get()));
if (value_attr) {
// This is an EagerTensor wrapped inside a TFShouldUse wrapped object.
(*$1)[i] = EagerTensor_Handle(value_attr.get());
} else {
// This is a subclass of EagerTensor that we don't support.
PyErr_Clear();
SWIG_exception_fail(
SWIG_TypeError,
tensorflow::strings::StrCat(
"Saw an object that is an instance of a strict subclass of "
"EagerTensor, which is not supported. Item ",
i, " is type: ", elem->ob_type->tp_name)
.c_str());
}
} else if (tensorflow::swig::IsTensor(elem)) {
// If it isnt an EagerTensor, but is still a Tensor, it must be a graph
// tensor.
tensorflow::Safe_PyObjectPtr name_attr(
PyObject_GetAttrString(elem, "name"));
SWIG_exception_fail(
SWIG_TypeError,
tensorflow::strings::StrCat(
"An op outside of the function building code is being passed\n"
"a \"Graph\" tensor. It is possible to have Graph tensors\n"
"leak out of the function building context by including a\n"
"tf.init_scope in your function building code.\n"
"For example, the following function will fail:\n",
" @tf.function\n",
" def has_init_scope():\n",
" my_constant = tf.constant(1.)\n",
" with tf.init_scope():\n",
" added = my_constant * 2\n",
"The graph tensor has name: ",
name_attr ? TFE_GetPythonString(name_attr.get()) : "<unknown>"
).c_str());
} else {
SWIG_exception_fail(
SWIG_TypeError,
tensorflow::strings::StrCat(
"provided list of inputs contains objects other "
"than 'EagerTensor'. Item ",
i, " is type: ", elem->ob_type->tp_name).c_str());
}
}
}
}
// Temporary for the argout
%typemap(in) TFE_OutputTensorHandles* outputs (TFE_OutputTensorHandles temp) {
if (!PyInt_Check($input)) {
SWIG_exception_fail(SWIG_TypeError,
"expected an integer value (size of the number of "
"outputs of the operation)");
}
$1 = &temp;
long sz = PyInt_AsLong($input);
if (sz > 0) {
$1->resize(PyInt_AsLong($input), nullptr);
}
}
// Create new Status object.
%typemap(in, numinputs=0) TF_Status *out_status {
$1 = GetStatus();
}
%typemap(freearg) (TF_Status* out_status) {
ReturnStatus($1);
}
%typemap(argout) (TFE_OutputTensorHandles* outputs, TF_Status* out_status) {
if (MaybeRaiseExceptionFromTFStatus($2, nullptr)) {
SWIG_fail;
} else {
int num_outputs = $1->size();
Py_CLEAR($result);
$result = PyList_New(num_outputs);
for (int i = 0; i < num_outputs; ++i) {
PyObject *output;
output = EagerTensorFromHandle($1->at(i));
PyList_SetItem($result, i, output);
}
}
}
// SWIG usually unwraps the tuple that the native Python/C interface generates.
// Since we wanted to have a function with a variable length of arguments, we
// used the native Python/C interface directly (which by default supports
// passing all arguments as a tuple).
%native(TFE_Py_FastPathExecute) TFE_Py_FastPathExecute_C;
%include "tensorflow/python/eager/pywrap_tfe.h"
%include "tensorflow/c/c_api_experimental.h"
%include "tensorflow/c/eager/c_api_experimental.h"
// Clear all typemaps.
%typemap(out) TF_DataType;
%typemap(in) int64_t;
%typemap(out) int64_t;
%typemap(out) TF_AttrType;
%typemap(in, numinputs=0) TF_Status *out_status;
%typemap(argout) unsigned char* is_list;
%typemap(in) const char* description;
%typemap(in) const char* label1;
%typemap(in) const char* label2;
%typemap(in) (TFE_Context*);
%typemap(out) (TFE_Context*);
%typemap(in) TFE_OutputTensorHandles* outputs (TFE_OutputTensorHandles temp);
%typemap(in, numinputs=0) TF_Status *out_status;
%typemap(freearg) (TF_Status* out_status);
%typemap(argout) (TFE_OutputTensorHandles* outputs, TF_Status* out_status);
%typemap(in) (const void* proto);
%unignoreall

View File

@ -0,0 +1,29 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Python module for TFE ops and functions exported by pybind11.
This module is created because we are splitting out eager bindings from
pywrap_tensorflow. This is causing some issues where Graphs are not properly
initialized when running eager code. Once the graph architecture has been
removed from pywrap_tensorflow as well, we can remove this file.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# pylint: disable=invalid-import-order,g-bad-import-order, wildcard-import, unused-import
from tensorflow.python import pywrap_tensorflow
from tensorflow.python._pywrap_tfe import *

View File

@ -17,8 +17,6 @@ limitations under the License.
* The includes are intentionally not alphabetically sorted, as the order of
* includes follows dependency order */
%include "tensorflow/python/pywrap_tfe.i"
%include "tensorflow/python/client/tf_session.i"
%include "tensorflow/python/lib/io/file_io.i"

File diff suppressed because it is too large Load Diff

View File

@ -3,6 +3,7 @@
*perftools*gputools*
*tf_*
*TF_*
*Eager*
*TFE_*
*nsync_*
*stream_executor*

View File

@ -4,6 +4,7 @@ tensorflow {
*toco*;
*perftools*gputools*;
*TF_*;
*Eager*;
*TFE_*;
*nsync_*;
*stream_executor*;

View File

@ -1,4 +1,4 @@
[cpp_python_util] # util
[cpp_python_util] # util tfe
tensorflow::swig::IsSequence
tensorflow::swig::IsSequenceOrComposite
tensorflow::swig::IsCompositeTensor
@ -17,6 +17,7 @@ tensorflow::swig::IsSequenceForData
tensorflow::swig::FlattenForData
tensorflow::swig::AssertSameStructureForData
tensorflow::swig::RegisterType
tensorflow::swig::IsEagerTensorSlow
[util_port] # util_port
tensorflow::IsGoogleCudaEnabled
@ -74,11 +75,12 @@ tensorflow::Status::code
tensorflow::Status::error_message
tensorflow::Status::ok()
[core_cpu_impl] # device_lib
[core_cpu_impl] # device_lib tfe
tensorflow::Device::attributes
tensorflow::DeviceFactory::AddDevices
tensorflow::SessionOptions::SessionOptions
tensorflow::DoQuantizeTrainingOnSerializedGraphDef
tensorflow::DeviceFactory::ListAllPhysicalDevices
[protos_all] # device_lib, dtypes
tensorflow::DataType_IsValid
@ -123,3 +125,67 @@ tensorflow::make_safe
[python_op_gen] # python_op_gen
tensorflow::GetPythonWrappers
[pywrap_tfe_lib] # tfe
tensorflow::TFE_TensorHandleCache
tensorflow::TFE_TensorHandleCache::Clear
EagerTensor_CheckExact
EagerTensorFromHandle
EagerTensor_Handle
TFE_Py_ExecuteCancelable
TFE_Py_RegisterExceptionClass
TFE_Py_RegisterVSpace
TFE_Py_RegisterFallbackExceptionClass
TFE_Py_RegisterGradientFunction
TFE_Py_RegisterJVPFunction
TFE_GetPythonString
TFE_Py_UID
TFE_DeleteContextCapsule
TFE_Py_InitEagerTensor
TFE_Py_SetEagerTensorProfiler
TFE_Py_TapeSetNew
TFE_Py_TapeSetRemove
TFE_Py_TapeSetAdd
TFE_Py_TapeSetIsEmpty
TFE_Py_TapeSetShouldRecordBackprop
TFE_Py_TapeSetPossibleGradientTypes
TFE_Py_TapeWatch
TFE_Py_TapeSetDeleteTrace
TFE_Py_TapeSetStopOnThread
TFE_Py_TapeSetRestartOnThread
TFE_Py_TapeSetIsStopped
TFE_Py_TapeSetRecordOperation
TFE_Py_TapeSetRecordOperationBackprop
TFE_Py_TapeSetRecordOperationForwardprop
TFE_Py_TapeVariableAccessed
TFE_Py_TapeWatchVariable
TFE_Py_TapeGradient
TFE_Py_FastPathExecute_C
TFE_Py_RecordGradient
TFE_Py_TapeWatchedVariables
TFE_Py_ForwardAccumulatorNew
TFE_Py_ForwardAccumulatorSetAdd
TFE_Py_ForwardAccumulatorSetRemove
TFE_Py_ForwardAccumulatorWatch
TFE_Py_ForwardAccumulatorJVP
TFE_Py_ForwardAccumulatorPushState
TFE_Py_ForwardAccumulatorPopState
TFE_Py_PackJVPs
TFE_Py_TensorShapeSlice
TFE_Py_TensorShapeOnDevice
TFE_Py_EncodeArg
TFE_Py_EnableInteractivePythonLogging
TFE_Py_SetEagerContext
[eager_executor] # tfe
tensorflow::EagerExecutor::~EagerExecutor
tensorflow::EagerContext::WaitForAndCloseRemoteContexts
[profiler_session] # tfe
tensorflow::ProfilerSession::~ProfilerSession
[tf_status_helper] # tfe
tensorflow::Set_TF_Status_from_Status
[context] # tfe
tensorflow::EagerContext::WaitForAndCloseRemoteContexts