Introduce Python-only extensions to the C API
Implements an incomplete version of Operation._add_control_input() using a new extension to make sure the plumbing works. This also adds header guards to c_api_internal.h, which were missing. For some reason the missing guards caused problems in the cmake build even though there doesn't appear to be any #include cycles. PiperOrigin-RevId: 161705859
This commit is contained in:
parent
4f54336348
commit
45a58d378e
tensorflow
c
contrib/cmake
python
@ -144,6 +144,19 @@ tf_custom_op_library(
|
||||
srcs = ["test_op.cc"],
|
||||
)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Python API target
|
||||
|
||||
tf_cuda_library(
|
||||
name = "python_api",
|
||||
srcs = ["python_api.cc"],
|
||||
hdrs = ["python_api.h"],
|
||||
visibility = ["//tensorflow/python:__pkg__"],
|
||||
deps = [
|
||||
":c_api_internal",
|
||||
],
|
||||
)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Google-internal targets.
|
||||
|
||||
|
@ -13,6 +13,9 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_C_API_INTERNAL_H_
|
||||
#define TENSORFLOW_C_C_API_INTERNAL_H_
|
||||
|
||||
#include "tensorflow/c/c_api.h"
|
||||
|
||||
#include <vector>
|
||||
@ -125,3 +128,5 @@ class TensorCApi {
|
||||
};
|
||||
|
||||
} // end namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_C_API_INTERNAL_H_
|
||||
|
28
tensorflow/c/python_api.cc
Normal file
28
tensorflow/c/python_api.cc
Normal file
@ -0,0 +1,28 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/c/python_api.h"
|
||||
|
||||
#include "tensorflow/c/c_api_internal.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
void AddControlInput(TF_Graph* graph, TF_Operation* op, TF_Operation* input) {
|
||||
// TODO(skyewm): make sure cycles are prevented
|
||||
mutex_lock l(graph->mu);
|
||||
graph->graph.AddControlEdge(&input->node, &op->node);
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
30
tensorflow/c/python_api.h
Normal file
30
tensorflow/c/python_api.h
Normal file
@ -0,0 +1,30 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef THIRD_PARTY_TENSORFLOW_C_PYTHON_API_H_
|
||||
#define THIRD_PARTY_TENSORFLOW_C_PYTHON_API_H_
|
||||
|
||||
#include "tensorflow/c/c_api.h"
|
||||
|
||||
// These functions can be removed without notice. They exist to facilitate some
|
||||
// refactoring of graph construction code in the Python API.
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
void AddControlInput(TF_Graph* graph, TF_Operation* op, TF_Operation* input);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // THIRD_PARTY_TENSORFLOW_C_PYTHON_API_H_
|
@ -26,3 +26,9 @@ set(tf_c_srcs
|
||||
|
||||
add_library(tf_c OBJECT ${tf_c_srcs})
|
||||
add_dependencies(tf_c tf_cc_framework tf_core_lib tf_protos_cc)
|
||||
|
||||
add_library(tf_c_python_api OBJECT
|
||||
"${tensorflow_source_dir}/tensorflow/c/python_api.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/c/python_api.h"
|
||||
)
|
||||
add_dependencies(tf_c_python_api tf_c tf_cc_framework tf_core_lib tf_protos_cc)
|
||||
|
@ -761,6 +761,7 @@ if(WIN32)
|
||||
add_library(pywrap_tensorflow_internal_static STATIC
|
||||
${pywrap_tensorflow_internal_src}
|
||||
$<TARGET_OBJECTS:tf_c>
|
||||
$<TARGET_OBJECTS:tf_c_python_api>
|
||||
$<TARGET_OBJECTS:tf_core_lib>
|
||||
$<TARGET_OBJECTS:tf_core_cpu>
|
||||
$<TARGET_OBJECTS:tf_core_framework>
|
||||
@ -809,6 +810,7 @@ endif(WIN32)
|
||||
add_library(pywrap_tensorflow_internal SHARED
|
||||
${pywrap_tensorflow_internal_src}
|
||||
$<TARGET_OBJECTS:tf_c>
|
||||
$<TARGET_OBJECTS:tf_c_python_api>
|
||||
$<TARGET_OBJECTS:tf_core_lib>
|
||||
$<TARGET_OBJECTS:tf_core_cpu>
|
||||
$<TARGET_OBJECTS:tf_core_framework>
|
||||
|
@ -2787,6 +2787,7 @@ tf_py_wrap_cc(
|
||||
":tf_session_helper",
|
||||
"//tensorflow/c:c_api",
|
||||
"//tensorflow/c:checkpoint_reader",
|
||||
"//tensorflow/c:python_api",
|
||||
"//tensorflow/c:tf_status_helper",
|
||||
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
|
||||
"//tensorflow/core/distributed_runtime/rpc:grpc_session",
|
||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
||||
|
||||
%{
|
||||
|
||||
#include "tensorflow/c/python_api.h"
|
||||
#include "tensorflow/python/client/tf_session_helper.h"
|
||||
#include "tensorflow/core/framework/session_state.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
@ -269,6 +270,7 @@ bool PyTensorListToVector(PyObject* py_tensor_list,
|
||||
%rename("_TF_NewSessionOptions") TF_NewSessionOptions;
|
||||
|
||||
%include "tensorflow/c/c_api.h"
|
||||
%include "tensorflow/c/python_api.h"
|
||||
|
||||
%ignoreall
|
||||
%insert("python") %{
|
||||
|
@ -1527,9 +1527,10 @@ class Operation(object):
|
||||
TypeError: if op is not an Operation.
|
||||
ValueError: if op is from a different graph.
|
||||
"""
|
||||
assert not self._graph._c_graph, ( # pylint: disable=protected-access
|
||||
"Operation._add_control_input doesn't work with C API")
|
||||
self._add_control_inputs([op])
|
||||
if _USE_C_API:
|
||||
c_api.AddControlInput(self._graph._c_graph, self._c_op, op._c_op) # pylint: disable=protected-access
|
||||
else:
|
||||
self._add_control_inputs([op])
|
||||
|
||||
# Methods below are used when building the NodeDef and Graph proto.
|
||||
def _recompute_node_def(self):
|
||||
|
@ -384,6 +384,15 @@ class OperationTest(test_util.TensorFlowTestCase):
|
||||
self.assertIsInstance(x, dtypes.DType)
|
||||
self.assertEqual([dtypes.string, dtypes.double], l)
|
||||
|
||||
# TODO(skyewm): test adding cycles, other error cases
|
||||
@test_util.enable_c_api
|
||||
def testAddControlInput(self):
|
||||
with ops.Graph().as_default():
|
||||
x = constant_op.constant(1).op
|
||||
y = constant_op.constant(2).op
|
||||
y._add_control_input(x) # pylint: disable=protected-access
|
||||
self.assertEqual(y.control_inputs, [x])
|
||||
|
||||
|
||||
class CreateOpTest(test_util.TensorFlowTestCase):
|
||||
|
||||
@ -1062,29 +1071,26 @@ class ComparisonTest(test_util.TensorFlowTestCase):
|
||||
|
||||
class ControlDependenciesTest(test_util.TensorFlowTestCase):
|
||||
|
||||
@test_util.enable_c_api
|
||||
def testBasic(self):
|
||||
ops._USE_C_API = True # pylint: disable=protected-access
|
||||
try:
|
||||
g = ops.Graph()
|
||||
with g.as_default():
|
||||
# Creating unregistered ops with _apply_op() doesn't work with the C API
|
||||
# TODO(skyewm): address this more consistently. Possible solutions are
|
||||
# to use registered ops in all tests, create a way to register ops in
|
||||
# Python tests, or conditionally disable the op registration check in
|
||||
# the C API.
|
||||
a = constant_op.constant(1.0)
|
||||
b = constant_op.constant(1.0)
|
||||
with g.control_dependencies([a]):
|
||||
c = constant_op.constant(1.0)
|
||||
d = array_ops.identity(b)
|
||||
e = array_ops.identity(c)
|
||||
g = ops.Graph()
|
||||
with g.as_default():
|
||||
# Creating unregistered ops with _apply_op() doesn't work with the C API
|
||||
# TODO(skyewm): address this more consistently. Possible solutions are
|
||||
# to use registered ops in all tests, create a way to register ops in
|
||||
# Python tests, or conditionally disable the op registration check in
|
||||
# the C API.
|
||||
a = constant_op.constant(1.0)
|
||||
b = constant_op.constant(1.0)
|
||||
with g.control_dependencies([a]):
|
||||
c = constant_op.constant(1.0)
|
||||
d = array_ops.identity(b)
|
||||
e = array_ops.identity(c)
|
||||
|
||||
self.assertEqual(c.op.control_inputs, [a.op])
|
||||
self.assertEqual(d.op.control_inputs, [a.op])
|
||||
# e should be dominated by c.
|
||||
self.assertEqual(e.op.control_inputs, [])
|
||||
finally:
|
||||
ops._USE_C_API = False # pylint: disable=protected-access
|
||||
self.assertEqual(c.op.control_inputs, [a.op])
|
||||
self.assertEqual(d.op.control_inputs, [a.op])
|
||||
# e should be dominated by c.
|
||||
self.assertEqual(e.op.control_inputs, [])
|
||||
|
||||
def testBasicWithConversion(self):
|
||||
g = ops.Graph()
|
||||
|
@ -227,6 +227,19 @@ def NCHWToNHWC(input_tensor):
|
||||
return [input_tensor[a] for a in new_axes[ndims]]
|
||||
|
||||
|
||||
# TODO(skyewm): remove this eventually
|
||||
# pylint: disable=protected-access
|
||||
def _use_c_api_wrapper(fn, use_c_api, *args, **kwargs):
|
||||
prev_value = ops._USE_C_API
|
||||
ops._USE_C_API = use_c_api
|
||||
try:
|
||||
with ops.Graph().as_default():
|
||||
fn(*args, **kwargs)
|
||||
finally:
|
||||
ops._USE_C_API = prev_value
|
||||
# pylint: disable=protected-access
|
||||
|
||||
|
||||
# TODO(skyewm): remove this eventually
|
||||
def disable_c_api(fn):
|
||||
"""Decorator for disabling the C API on a test.
|
||||
@ -240,17 +253,23 @@ def disable_c_api(fn):
|
||||
Returns:
|
||||
The wrapped function
|
||||
"""
|
||||
# pylint: disable=protected-access
|
||||
def disable_c_api_wrapper(*args, **kwargs):
|
||||
prev_value = ops._USE_C_API
|
||||
ops._USE_C_API = False
|
||||
try:
|
||||
with ops.Graph().as_default():
|
||||
fn(*args, **kwargs)
|
||||
finally:
|
||||
ops._USE_C_API = prev_value
|
||||
# pylint: disable=protected-access
|
||||
return disable_c_api_wrapper
|
||||
return lambda *args, **kwargs: _use_c_api_wrapper(fn, False, *args, **kwargs)
|
||||
|
||||
|
||||
# TODO(skyewm): remove this eventually
|
||||
def enable_c_api(fn):
|
||||
"""Decorator for enabling the C API on a test.
|
||||
|
||||
Note this enables the C API after running the test class's setup/teardown
|
||||
methods.
|
||||
|
||||
Args:
|
||||
fn: the function to be wrapped
|
||||
|
||||
Returns:
|
||||
The wrapped function
|
||||
"""
|
||||
return lambda *args, **kwargs: _use_c_api_wrapper(fn, True, *args, **kwargs)
|
||||
|
||||
|
||||
class TensorFlowTestCase(googletest.TestCase):
|
||||
|
Loading…
Reference in New Issue
Block a user