Introduce Python-only extensions to the C API
Implements an incomplete version of Operation._add_control_input() using a new extension to make sure the plumbing works. This also adds header guards to c_api_internal.h, which were missing. For some reason the missing guards caused problems in the cmake build even though there doesn't appear to be any #include cycles. PiperOrigin-RevId: 161705859
This commit is contained in:
parent
4f54336348
commit
45a58d378e
@ -144,6 +144,19 @@ tf_custom_op_library(
|
|||||||
srcs = ["test_op.cc"],
|
srcs = ["test_op.cc"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# Python API target
|
||||||
|
|
||||||
|
tf_cuda_library(
|
||||||
|
name = "python_api",
|
||||||
|
srcs = ["python_api.cc"],
|
||||||
|
hdrs = ["python_api.h"],
|
||||||
|
visibility = ["//tensorflow/python:__pkg__"],
|
||||||
|
deps = [
|
||||||
|
":c_api_internal",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
# Google-internal targets.
|
# Google-internal targets.
|
||||||
|
|
||||||
|
@ -13,6 +13,9 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_C_C_API_INTERNAL_H_
|
||||||
|
#define TENSORFLOW_C_C_API_INTERNAL_H_
|
||||||
|
|
||||||
#include "tensorflow/c/c_api.h"
|
#include "tensorflow/c/c_api.h"
|
||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
@ -125,3 +128,5 @@ class TensorCApi {
|
|||||||
};
|
};
|
||||||
|
|
||||||
} // end namespace tensorflow
|
} // end namespace tensorflow
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_C_C_API_INTERNAL_H_
|
||||||
|
28
tensorflow/c/python_api.cc
Normal file
28
tensorflow/c/python_api.cc
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/c/python_api.h"
|
||||||
|
|
||||||
|
#include "tensorflow/c/c_api_internal.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
void AddControlInput(TF_Graph* graph, TF_Operation* op, TF_Operation* input) {
|
||||||
|
// TODO(skyewm): make sure cycles are prevented
|
||||||
|
mutex_lock l(graph->mu);
|
||||||
|
graph->graph.AddControlEdge(&input->node, &op->node);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
30
tensorflow/c/python_api.h
Normal file
30
tensorflow/c/python_api.h
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef THIRD_PARTY_TENSORFLOW_C_PYTHON_API_H_
|
||||||
|
#define THIRD_PARTY_TENSORFLOW_C_PYTHON_API_H_
|
||||||
|
|
||||||
|
#include "tensorflow/c/c_api.h"
|
||||||
|
|
||||||
|
// These functions can be removed without notice. They exist to facilitate some
|
||||||
|
// refactoring of graph construction code in the Python API.
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
void AddControlInput(TF_Graph* graph, TF_Operation* op, TF_Operation* input);
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
#endif // THIRD_PARTY_TENSORFLOW_C_PYTHON_API_H_
|
@ -26,3 +26,9 @@ set(tf_c_srcs
|
|||||||
|
|
||||||
add_library(tf_c OBJECT ${tf_c_srcs})
|
add_library(tf_c OBJECT ${tf_c_srcs})
|
||||||
add_dependencies(tf_c tf_cc_framework tf_core_lib tf_protos_cc)
|
add_dependencies(tf_c tf_cc_framework tf_core_lib tf_protos_cc)
|
||||||
|
|
||||||
|
add_library(tf_c_python_api OBJECT
|
||||||
|
"${tensorflow_source_dir}/tensorflow/c/python_api.cc"
|
||||||
|
"${tensorflow_source_dir}/tensorflow/c/python_api.h"
|
||||||
|
)
|
||||||
|
add_dependencies(tf_c_python_api tf_c tf_cc_framework tf_core_lib tf_protos_cc)
|
||||||
|
@ -761,6 +761,7 @@ if(WIN32)
|
|||||||
add_library(pywrap_tensorflow_internal_static STATIC
|
add_library(pywrap_tensorflow_internal_static STATIC
|
||||||
${pywrap_tensorflow_internal_src}
|
${pywrap_tensorflow_internal_src}
|
||||||
$<TARGET_OBJECTS:tf_c>
|
$<TARGET_OBJECTS:tf_c>
|
||||||
|
$<TARGET_OBJECTS:tf_c_python_api>
|
||||||
$<TARGET_OBJECTS:tf_core_lib>
|
$<TARGET_OBJECTS:tf_core_lib>
|
||||||
$<TARGET_OBJECTS:tf_core_cpu>
|
$<TARGET_OBJECTS:tf_core_cpu>
|
||||||
$<TARGET_OBJECTS:tf_core_framework>
|
$<TARGET_OBJECTS:tf_core_framework>
|
||||||
@ -809,6 +810,7 @@ endif(WIN32)
|
|||||||
add_library(pywrap_tensorflow_internal SHARED
|
add_library(pywrap_tensorflow_internal SHARED
|
||||||
${pywrap_tensorflow_internal_src}
|
${pywrap_tensorflow_internal_src}
|
||||||
$<TARGET_OBJECTS:tf_c>
|
$<TARGET_OBJECTS:tf_c>
|
||||||
|
$<TARGET_OBJECTS:tf_c_python_api>
|
||||||
$<TARGET_OBJECTS:tf_core_lib>
|
$<TARGET_OBJECTS:tf_core_lib>
|
||||||
$<TARGET_OBJECTS:tf_core_cpu>
|
$<TARGET_OBJECTS:tf_core_cpu>
|
||||||
$<TARGET_OBJECTS:tf_core_framework>
|
$<TARGET_OBJECTS:tf_core_framework>
|
||||||
|
@ -2787,6 +2787,7 @@ tf_py_wrap_cc(
|
|||||||
":tf_session_helper",
|
":tf_session_helper",
|
||||||
"//tensorflow/c:c_api",
|
"//tensorflow/c:c_api",
|
||||||
"//tensorflow/c:checkpoint_reader",
|
"//tensorflow/c:checkpoint_reader",
|
||||||
|
"//tensorflow/c:python_api",
|
||||||
"//tensorflow/c:tf_status_helper",
|
"//tensorflow/c:tf_status_helper",
|
||||||
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
|
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
|
||||||
"//tensorflow/core/distributed_runtime/rpc:grpc_session",
|
"//tensorflow/core/distributed_runtime/rpc:grpc_session",
|
||||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
|||||||
|
|
||||||
%{
|
%{
|
||||||
|
|
||||||
|
#include "tensorflow/c/python_api.h"
|
||||||
#include "tensorflow/python/client/tf_session_helper.h"
|
#include "tensorflow/python/client/tf_session_helper.h"
|
||||||
#include "tensorflow/core/framework/session_state.h"
|
#include "tensorflow/core/framework/session_state.h"
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
@ -269,6 +270,7 @@ bool PyTensorListToVector(PyObject* py_tensor_list,
|
|||||||
%rename("_TF_NewSessionOptions") TF_NewSessionOptions;
|
%rename("_TF_NewSessionOptions") TF_NewSessionOptions;
|
||||||
|
|
||||||
%include "tensorflow/c/c_api.h"
|
%include "tensorflow/c/c_api.h"
|
||||||
|
%include "tensorflow/c/python_api.h"
|
||||||
|
|
||||||
%ignoreall
|
%ignoreall
|
||||||
%insert("python") %{
|
%insert("python") %{
|
||||||
|
@ -1527,8 +1527,9 @@ class Operation(object):
|
|||||||
TypeError: if op is not an Operation.
|
TypeError: if op is not an Operation.
|
||||||
ValueError: if op is from a different graph.
|
ValueError: if op is from a different graph.
|
||||||
"""
|
"""
|
||||||
assert not self._graph._c_graph, ( # pylint: disable=protected-access
|
if _USE_C_API:
|
||||||
"Operation._add_control_input doesn't work with C API")
|
c_api.AddControlInput(self._graph._c_graph, self._c_op, op._c_op) # pylint: disable=protected-access
|
||||||
|
else:
|
||||||
self._add_control_inputs([op])
|
self._add_control_inputs([op])
|
||||||
|
|
||||||
# Methods below are used when building the NodeDef and Graph proto.
|
# Methods below are used when building the NodeDef and Graph proto.
|
||||||
|
@ -384,6 +384,15 @@ class OperationTest(test_util.TensorFlowTestCase):
|
|||||||
self.assertIsInstance(x, dtypes.DType)
|
self.assertIsInstance(x, dtypes.DType)
|
||||||
self.assertEqual([dtypes.string, dtypes.double], l)
|
self.assertEqual([dtypes.string, dtypes.double], l)
|
||||||
|
|
||||||
|
# TODO(skyewm): test adding cycles, other error cases
|
||||||
|
@test_util.enable_c_api
|
||||||
|
def testAddControlInput(self):
|
||||||
|
with ops.Graph().as_default():
|
||||||
|
x = constant_op.constant(1).op
|
||||||
|
y = constant_op.constant(2).op
|
||||||
|
y._add_control_input(x) # pylint: disable=protected-access
|
||||||
|
self.assertEqual(y.control_inputs, [x])
|
||||||
|
|
||||||
|
|
||||||
class CreateOpTest(test_util.TensorFlowTestCase):
|
class CreateOpTest(test_util.TensorFlowTestCase):
|
||||||
|
|
||||||
@ -1062,9 +1071,8 @@ class ComparisonTest(test_util.TensorFlowTestCase):
|
|||||||
|
|
||||||
class ControlDependenciesTest(test_util.TensorFlowTestCase):
|
class ControlDependenciesTest(test_util.TensorFlowTestCase):
|
||||||
|
|
||||||
|
@test_util.enable_c_api
|
||||||
def testBasic(self):
|
def testBasic(self):
|
||||||
ops._USE_C_API = True # pylint: disable=protected-access
|
|
||||||
try:
|
|
||||||
g = ops.Graph()
|
g = ops.Graph()
|
||||||
with g.as_default():
|
with g.as_default():
|
||||||
# Creating unregistered ops with _apply_op() doesn't work with the C API
|
# Creating unregistered ops with _apply_op() doesn't work with the C API
|
||||||
@ -1083,8 +1091,6 @@ class ControlDependenciesTest(test_util.TensorFlowTestCase):
|
|||||||
self.assertEqual(d.op.control_inputs, [a.op])
|
self.assertEqual(d.op.control_inputs, [a.op])
|
||||||
# e should be dominated by c.
|
# e should be dominated by c.
|
||||||
self.assertEqual(e.op.control_inputs, [])
|
self.assertEqual(e.op.control_inputs, [])
|
||||||
finally:
|
|
||||||
ops._USE_C_API = False # pylint: disable=protected-access
|
|
||||||
|
|
||||||
def testBasicWithConversion(self):
|
def testBasicWithConversion(self):
|
||||||
g = ops.Graph()
|
g = ops.Graph()
|
||||||
|
@ -227,6 +227,19 @@ def NCHWToNHWC(input_tensor):
|
|||||||
return [input_tensor[a] for a in new_axes[ndims]]
|
return [input_tensor[a] for a in new_axes[ndims]]
|
||||||
|
|
||||||
|
|
||||||
|
# TODO(skyewm): remove this eventually
|
||||||
|
# pylint: disable=protected-access
|
||||||
|
def _use_c_api_wrapper(fn, use_c_api, *args, **kwargs):
|
||||||
|
prev_value = ops._USE_C_API
|
||||||
|
ops._USE_C_API = use_c_api
|
||||||
|
try:
|
||||||
|
with ops.Graph().as_default():
|
||||||
|
fn(*args, **kwargs)
|
||||||
|
finally:
|
||||||
|
ops._USE_C_API = prev_value
|
||||||
|
# pylint: disable=protected-access
|
||||||
|
|
||||||
|
|
||||||
# TODO(skyewm): remove this eventually
|
# TODO(skyewm): remove this eventually
|
||||||
def disable_c_api(fn):
|
def disable_c_api(fn):
|
||||||
"""Decorator for disabling the C API on a test.
|
"""Decorator for disabling the C API on a test.
|
||||||
@ -240,17 +253,23 @@ def disable_c_api(fn):
|
|||||||
Returns:
|
Returns:
|
||||||
The wrapped function
|
The wrapped function
|
||||||
"""
|
"""
|
||||||
# pylint: disable=protected-access
|
return lambda *args, **kwargs: _use_c_api_wrapper(fn, False, *args, **kwargs)
|
||||||
def disable_c_api_wrapper(*args, **kwargs):
|
|
||||||
prev_value = ops._USE_C_API
|
|
||||||
ops._USE_C_API = False
|
# TODO(skyewm): remove this eventually
|
||||||
try:
|
def enable_c_api(fn):
|
||||||
with ops.Graph().as_default():
|
"""Decorator for enabling the C API on a test.
|
||||||
fn(*args, **kwargs)
|
|
||||||
finally:
|
Note this enables the C API after running the test class's setup/teardown
|
||||||
ops._USE_C_API = prev_value
|
methods.
|
||||||
# pylint: disable=protected-access
|
|
||||||
return disable_c_api_wrapper
|
Args:
|
||||||
|
fn: the function to be wrapped
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The wrapped function
|
||||||
|
"""
|
||||||
|
return lambda *args, **kwargs: _use_c_api_wrapper(fn, True, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
class TensorFlowTestCase(googletest.TestCase):
|
class TensorFlowTestCase(googletest.TestCase):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user