Introduce Python-only extensions to the C API

Implements an incomplete version of Operation._add_control_input()
using a new extension to make sure the plumbing works.

This also adds header guards to c_api_internal.h, which were missing. For some reason the missing guards caused problems in the cmake build even though there doesn't appear to be any #include cycles.

PiperOrigin-RevId: 161705859
This commit is contained in:
Skye Wanderman-Milne 2017-07-12 13:08:51 -07:00 committed by TensorFlower Gardener
parent 4f54336348
commit 45a58d378e
11 changed files with 148 additions and 35 deletions

View File

@ -144,6 +144,19 @@ tf_custom_op_library(
srcs = ["test_op.cc"],
)
# -----------------------------------------------------------------------------
# Python API target
tf_cuda_library(
name = "python_api",
srcs = ["python_api.cc"],
hdrs = ["python_api.h"],
visibility = ["//tensorflow/python:__pkg__"],
deps = [
":c_api_internal",
],
)
# -----------------------------------------------------------------------------
# Google-internal targets.

View File

@ -13,6 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_C_API_INTERNAL_H_
#define TENSORFLOW_C_C_API_INTERNAL_H_
#include "tensorflow/c/c_api.h"
#include <vector>
@ -125,3 +128,5 @@ class TensorCApi {
};
} // end namespace tensorflow
#endif // TENSORFLOW_C_C_API_INTERNAL_H_

View File

@ -0,0 +1,28 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/python_api.h"
#include "tensorflow/c/c_api_internal.h"
namespace tensorflow {
void AddControlInput(TF_Graph* graph, TF_Operation* op, TF_Operation* input) {
// TODO(skyewm): make sure cycles are prevented
mutex_lock l(graph->mu);
graph->graph.AddControlEdge(&input->node, &op->node);
}
} // namespace tensorflow

30
tensorflow/c/python_api.h Normal file
View File

@ -0,0 +1,30 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef THIRD_PARTY_TENSORFLOW_C_PYTHON_API_H_
#define THIRD_PARTY_TENSORFLOW_C_PYTHON_API_H_
#include "tensorflow/c/c_api.h"
// These functions can be removed without notice. They exist to facilitate some
// refactoring of graph construction code in the Python API.
namespace tensorflow {
void AddControlInput(TF_Graph* graph, TF_Operation* op, TF_Operation* input);
} // namespace tensorflow
#endif // THIRD_PARTY_TENSORFLOW_C_PYTHON_API_H_

View File

@ -26,3 +26,9 @@ set(tf_c_srcs
add_library(tf_c OBJECT ${tf_c_srcs})
add_dependencies(tf_c tf_cc_framework tf_core_lib tf_protos_cc)
add_library(tf_c_python_api OBJECT
"${tensorflow_source_dir}/tensorflow/c/python_api.cc"
"${tensorflow_source_dir}/tensorflow/c/python_api.h"
)
add_dependencies(tf_c_python_api tf_c tf_cc_framework tf_core_lib tf_protos_cc)

View File

@ -761,6 +761,7 @@ if(WIN32)
add_library(pywrap_tensorflow_internal_static STATIC
${pywrap_tensorflow_internal_src}
$<TARGET_OBJECTS:tf_c>
$<TARGET_OBJECTS:tf_c_python_api>
$<TARGET_OBJECTS:tf_core_lib>
$<TARGET_OBJECTS:tf_core_cpu>
$<TARGET_OBJECTS:tf_core_framework>
@ -809,6 +810,7 @@ endif(WIN32)
add_library(pywrap_tensorflow_internal SHARED
${pywrap_tensorflow_internal_src}
$<TARGET_OBJECTS:tf_c>
$<TARGET_OBJECTS:tf_c_python_api>
$<TARGET_OBJECTS:tf_core_lib>
$<TARGET_OBJECTS:tf_core_cpu>
$<TARGET_OBJECTS:tf_core_framework>

View File

@ -2787,6 +2787,7 @@ tf_py_wrap_cc(
":tf_session_helper",
"//tensorflow/c:c_api",
"//tensorflow/c:checkpoint_reader",
"//tensorflow/c:python_api",
"//tensorflow/c:tf_status_helper",
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
"//tensorflow/core/distributed_runtime/rpc:grpc_session",

View File

@ -17,6 +17,7 @@ limitations under the License.
%{
#include "tensorflow/c/python_api.h"
#include "tensorflow/python/client/tf_session_helper.h"
#include "tensorflow/core/framework/session_state.h"
#include "tensorflow/core/lib/core/errors.h"
@ -269,6 +270,7 @@ bool PyTensorListToVector(PyObject* py_tensor_list,
%rename("_TF_NewSessionOptions") TF_NewSessionOptions;
%include "tensorflow/c/c_api.h"
%include "tensorflow/c/python_api.h"
%ignoreall
%insert("python") %{

View File

@ -1527,8 +1527,9 @@ class Operation(object):
TypeError: if op is not an Operation.
ValueError: if op is from a different graph.
"""
assert not self._graph._c_graph, ( # pylint: disable=protected-access
"Operation._add_control_input doesn't work with C API")
if _USE_C_API:
c_api.AddControlInput(self._graph._c_graph, self._c_op, op._c_op) # pylint: disable=protected-access
else:
self._add_control_inputs([op])
# Methods below are used when building the NodeDef and Graph proto.

View File

@ -384,6 +384,15 @@ class OperationTest(test_util.TensorFlowTestCase):
self.assertIsInstance(x, dtypes.DType)
self.assertEqual([dtypes.string, dtypes.double], l)
# TODO(skyewm): test adding cycles, other error cases
@test_util.enable_c_api
def testAddControlInput(self):
with ops.Graph().as_default():
x = constant_op.constant(1).op
y = constant_op.constant(2).op
y._add_control_input(x) # pylint: disable=protected-access
self.assertEqual(y.control_inputs, [x])
class CreateOpTest(test_util.TensorFlowTestCase):
@ -1062,9 +1071,8 @@ class ComparisonTest(test_util.TensorFlowTestCase):
class ControlDependenciesTest(test_util.TensorFlowTestCase):
@test_util.enable_c_api
def testBasic(self):
ops._USE_C_API = True # pylint: disable=protected-access
try:
g = ops.Graph()
with g.as_default():
# Creating unregistered ops with _apply_op() doesn't work with the C API
@ -1083,8 +1091,6 @@ class ControlDependenciesTest(test_util.TensorFlowTestCase):
self.assertEqual(d.op.control_inputs, [a.op])
# e should be dominated by c.
self.assertEqual(e.op.control_inputs, [])
finally:
ops._USE_C_API = False # pylint: disable=protected-access
def testBasicWithConversion(self):
g = ops.Graph()

View File

@ -227,6 +227,19 @@ def NCHWToNHWC(input_tensor):
return [input_tensor[a] for a in new_axes[ndims]]
# TODO(skyewm): remove this eventually
# pylint: disable=protected-access
def _use_c_api_wrapper(fn, use_c_api, *args, **kwargs):
prev_value = ops._USE_C_API
ops._USE_C_API = use_c_api
try:
with ops.Graph().as_default():
fn(*args, **kwargs)
finally:
ops._USE_C_API = prev_value
# pylint: disable=protected-access
# TODO(skyewm): remove this eventually
def disable_c_api(fn):
"""Decorator for disabling the C API on a test.
@ -240,17 +253,23 @@ def disable_c_api(fn):
Returns:
The wrapped function
"""
# pylint: disable=protected-access
def disable_c_api_wrapper(*args, **kwargs):
prev_value = ops._USE_C_API
ops._USE_C_API = False
try:
with ops.Graph().as_default():
fn(*args, **kwargs)
finally:
ops._USE_C_API = prev_value
# pylint: disable=protected-access
return disable_c_api_wrapper
return lambda *args, **kwargs: _use_c_api_wrapper(fn, False, *args, **kwargs)
# TODO(skyewm): remove this eventually
def enable_c_api(fn):
"""Decorator for enabling the C API on a test.
Note this enables the C API after running the test class's setup/teardown
methods.
Args:
fn: the function to be wrapped
Returns:
The wrapped function
"""
return lambda *args, **kwargs: _use_c_api_wrapper(fn, True, *args, **kwargs)
class TensorFlowTestCase(googletest.TestCase):