Add PythonTensorConverter class, which can be used in c++ to efficiently convert PyObjects to tensors.
PiperOrigin-RevId: 339155091 Change-Id: Icad20253f523e3ac3685d8d34c52e940f3889b57
This commit is contained in:
parent
08e92a07fe
commit
3ac0818dea
tensorflow
@ -110,6 +110,7 @@ filegroup(
|
||||
"abstract_function.h",
|
||||
"abstract_operation.h",
|
||||
"abstract_tensor_handle.h",
|
||||
"c_api.h",
|
||||
"c_api_experimental.h",
|
||||
"c_api_internal.h",
|
||||
"c_api_unified_experimental.h",
|
||||
|
@ -1777,6 +1777,79 @@ tf_py_test(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "python_tensor_converter",
|
||||
srcs = ["framework/python_tensor_converter.cc"],
|
||||
hdrs = ["framework/python_tensor_converter.h"],
|
||||
deps = [
|
||||
":cpp_python_util",
|
||||
":safe_pyobject_ptr",
|
||||
"//tensorflow/c/eager:c_api",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/python/eager:pywrap_tfe_lib",
|
||||
"//third_party/python_runtime:headers", # buildcleaner: keep
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
# Note: this target is only used by python_tensor_converter_test.
|
||||
tf_python_pybind_extension(
|
||||
name = "_pywrap_python_tensor_converter",
|
||||
srcs = ["framework/python_tensor_converter_wrapper.cc"],
|
||||
hdrs = [
|
||||
"framework/python_tensor_converter.h",
|
||||
"lib/core/numpy.h",
|
||||
"//tensorflow/c:headers",
|
||||
"//tensorflow/c/eager:pywrap_required_hdrs",
|
||||
"//tensorflow/c/experimental/ops:pywrap_required_hdrs",
|
||||
"//tensorflow/core/common_runtime/eager:pywrap_required_hdrs",
|
||||
"//tensorflow/core/distributed_runtime:pywrap_required_hdrs",
|
||||
"//tensorflow/core/distributed_runtime/eager:pywrap_required_hdrs",
|
||||
"//tensorflow/python/eager:pywrap_required_hdrs",
|
||||
],
|
||||
module_name = "_pywrap_python_tensor_converter",
|
||||
deps = [
|
||||
":safe_pyobject_ptr_required_hdrs",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
"@com_google_absl//absl/hash",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@pybind11",
|
||||
"//third_party/python_runtime:headers", # buildcleaner: keep
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core/common_runtime:core_cpu_headers_lib",
|
||||
"//tensorflow/core:lib_headers_for_pybind",
|
||||
"//third_party/py/numpy:headers",
|
||||
"//tensorflow/c:pywrap_required_hdrs",
|
||||
"@com_google_absl//absl/types:span",
|
||||
] + if_static(
|
||||
extra_deps = [
|
||||
"//tensorflow/core/protobuf:eager_service_proto_cc",
|
||||
"//tensorflow/core/protobuf:master_proto_cc",
|
||||
"//tensorflow/core/protobuf:worker_proto_cc",
|
||||
],
|
||||
otherwise = [
|
||||
"//tensorflow/core/protobuf:eager_service_proto_cc_headers_only",
|
||||
"//tensorflow/core/protobuf:master_proto_cc_headers_only",
|
||||
"//tensorflow/core/protobuf:worker_proto_cc_headers_only",
|
||||
],
|
||||
),
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "python_tensor_converter_test",
|
||||
srcs = ["framework/python_tensor_converter_test.py"],
|
||||
python_version = "PY3",
|
||||
tags = ["no_pip"],
|
||||
deps = [
|
||||
":_pywrap_python_tensor_converter",
|
||||
":client_testlib",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "framework_ops", # "ops" is already the name of a deprecated target
|
||||
srcs = ["framework/ops.py"],
|
||||
@ -6026,6 +6099,7 @@ pywrap_tensorflow_macro(
|
||||
":pybind11_proto",
|
||||
":python_api_dispatcher",
|
||||
":python_op_gen",
|
||||
":python_tensor_converter",
|
||||
":safe_pyobject_ptr",
|
||||
":tf_session_helper",
|
||||
"//third_party/python_runtime:headers",
|
||||
@ -6096,6 +6170,7 @@ filegroup(
|
||||
":py_exception_registry", # py_exception_registry
|
||||
":py_func_lib", # py_func
|
||||
":python_api_dispatcher", # python_api_dispatcher
|
||||
":python_tensor_converter", # python_tensor_converter
|
||||
":python_op_gen", # python_op_gen
|
||||
":safe_ptr", # checkpoint_reader
|
||||
"//tensorflow/c:checkpoint_reader", # checkpoint_reader
|
||||
|
@ -72,6 +72,7 @@ cc_library(
|
||||
"//tensorflow/python:py_seq_tensor",
|
||||
"//tensorflow/python:py_util",
|
||||
"//tensorflow/python:safe_ptr",
|
||||
"//tensorflow/python:safe_pyobject_ptr",
|
||||
"//tensorflow/python:stack_trace",
|
||||
"//third_party/py/numpy:headers",
|
||||
"//third_party/python_runtime:headers",
|
||||
|
136
tensorflow/python/framework/python_tensor_converter.cc
Normal file
136
tensorflow/python/framework/python_tensor_converter.cc
Normal file
@ -0,0 +1,136 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/python/framework/python_tensor_converter.h"
|
||||
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "tensorflow/python/eager/pywrap_tensor.h"
|
||||
#include "tensorflow/python/eager/pywrap_tfe.h"
|
||||
#include "tensorflow/python/util/util.h"
|
||||
|
||||
#if PY_MAJOR_VERSION < 3
|
||||
// Python 2.x:
|
||||
#define PY_INT_AS_LONG(x) (PyInt_AsLong(x))
|
||||
#define PY_STRING_INTERN_FROM_STRING(x) (PyString_InternFromString(x))
|
||||
#else
|
||||
// Python 3.x:
|
||||
#define PY_INT_AS_LONG(x) (PyLong_AsLong(x))
|
||||
#define PY_STRING_INTERN_FROM_STRING(x) (PyUnicode_InternFromString(x))
|
||||
#endif
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
// Returns `tensor.dtype._type_enum` as a DataType enum. Assumes that `tensor`
|
||||
// is a python `Tensor` object.
|
||||
//
|
||||
// On error: sets a python AttributeError exception and returns DT_INVALID.
|
||||
DataType DataTypeForTensor(PyObject* tensor) {
|
||||
static PyObject* dtype_attr = PY_STRING_INTERN_FROM_STRING("dtype");
|
||||
static PyObject* type_enum_attr = PY_STRING_INTERN_FROM_STRING("_type_enum");
|
||||
|
||||
Safe_PyObjectPtr py_dtype(PyObject_GetAttr(tensor, dtype_attr));
|
||||
if (!py_dtype) return DT_INVALID;
|
||||
|
||||
Safe_PyObjectPtr enum_field(PyObject_GetAttr(py_dtype.get(), type_enum_attr));
|
||||
if (!enum_field) return DT_INVALID;
|
||||
|
||||
DataType result = static_cast<DataType>(PY_INT_AS_LONG(enum_field.get()));
|
||||
return result;
|
||||
}
|
||||
|
||||
// Check that actual_dtype == expected_dtype. If not, set an exception and
|
||||
// return false. (If expected_dtype is DT_INVALID, then instead simply update
|
||||
// its value to `actual_dtype` and return true.)
|
||||
bool CheckDType(DataType actual_dtype, DataType& expected_dtype) {
|
||||
if (expected_dtype == DT_INVALID) {
|
||||
expected_dtype = actual_dtype; // set output parameter.
|
||||
} else if (expected_dtype != actual_dtype) {
|
||||
PyErr_SetString(PyExc_TypeError,
|
||||
absl::StrCat("Expected ", DataType_Name(expected_dtype),
|
||||
" but got ", DataType_Name(actual_dtype))
|
||||
.c_str());
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
Safe_PyObjectPtr PythonTensorConverter::Convert(PyObject* src, DataType& dtype,
|
||||
bool* used_fallback) const {
|
||||
// First, try converting `src` to a Tensor without calling back into Python.
|
||||
if (ctx_) { // Eager mode
|
||||
// TODO(b/164980194): Handle resource variables as well. (See
|
||||
// ConvertToTensor function in pywrap_tfe_src.cc).
|
||||
if (EagerTensor_CheckExact(src)) {
|
||||
// `src` is already an eager tensor; check its type, and return it as-is.
|
||||
if (!CheckDType(PyEagerTensor_Dtype(src), dtype)) return nullptr;
|
||||
Py_INCREF(src);
|
||||
return Safe_PyObjectPtr(src);
|
||||
} else {
|
||||
TFE_TensorHandle* handle =
|
||||
tensorflow::ConvertToEagerTensor(ctx_, src, dtype, device_name_);
|
||||
if (handle) {
|
||||
Safe_PyObjectPtr result(EagerTensorFromHandle(handle));
|
||||
if (!CheckDType(PyEagerTensor_Dtype(result.get()), dtype)) {
|
||||
return nullptr;
|
||||
}
|
||||
return result;
|
||||
} else {
|
||||
PyErr_Clear();
|
||||
}
|
||||
}
|
||||
} else { // Graph mode
|
||||
if (swig::IsTensor(src)) {
|
||||
DataType src_dtype = DataTypeForTensor(src);
|
||||
if (src_dtype == DT_INVALID) return nullptr;
|
||||
if (!CheckDType(src_dtype, dtype)) return nullptr;
|
||||
Py_INCREF(src);
|
||||
return Safe_PyObjectPtr(src);
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback: use the Python tf.convert_to_tensor function.
|
||||
// Currently this is used:
|
||||
//
|
||||
// * In Eager mode: for anything that's not already an Eager tensor, or
|
||||
// handled by `tensorflow::ConvertToEagerTensor`. (At time of writing
|
||||
// for this comment, ConvertToEagerTensor handles simple values like ints,
|
||||
// nested lists of simple values, and numpy arrays.)
|
||||
// * In graph mode: for anything that's not already a tensor.
|
||||
//
|
||||
// TODO(b/164980194) Reduce/eliminate cases where fallback is used.
|
||||
if (used_fallback) *used_fallback = true;
|
||||
static PyObject* convert_to_tensor =
|
||||
swig::GetRegisteredPyObject("tf.convert_to_tensor");
|
||||
if (!convert_to_tensor) return nullptr;
|
||||
|
||||
Safe_PyObjectPtr args(PyTuple_New(dtype == DT_INVALID ? 1 : 2));
|
||||
Safe_PyObjectPtr kwargs(PyDict_New());
|
||||
Py_INCREF(src);
|
||||
PyTuple_SetItem(args.get(), 0, src);
|
||||
if (dtype != DT_INVALID) {
|
||||
PyTuple_SetItem(args.get(), 1, PyLong_FromLong(dtype));
|
||||
}
|
||||
PyDict_SetItemString(kwargs.get(), "ctx", py_eager_context_);
|
||||
Safe_PyObjectPtr result(
|
||||
PyObject_Call(convert_to_tensor, args.get(), kwargs.get()));
|
||||
if (!result) return nullptr;
|
||||
dtype = DataTypeForTensor(result.get()); // set output parameter.
|
||||
if (dtype == DT_INVALID) return nullptr;
|
||||
return result;
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
76
tensorflow/python/framework/python_tensor_converter.h
Normal file
76
tensorflow/python/framework/python_tensor_converter.h
Normal file
@ -0,0 +1,76 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_PYTHON_FRAMEWORK_PYTHON_TENSOR_CONVERTER_H_
|
||||
#define TENSORFLOW_PYTHON_FRAMEWORK_PYTHON_TENSOR_CONVERTER_H_
|
||||
|
||||
#include <Python.h>
|
||||
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/python/lib/core/safe_pyobject_ptr.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Converts PyObject* values to Tensors.
|
||||
//
|
||||
// This converter attempts to convert values as efficiently as possible; but
|
||||
// it has fallback paths to handle any PyObject* value for which tensor
|
||||
// conversion is defined.
|
||||
class PythonTensorConverter {
|
||||
public:
|
||||
// Constructs a new PythonTensorConverter.
|
||||
//
|
||||
// Note: the arguments to this constructor may change in the future, as
|
||||
// we move more of python tensor conversion from the Python layer to the
|
||||
// c++ layer.
|
||||
//
|
||||
// Args:
|
||||
// py_eager_context: the value of context.context() from eager/context.py.
|
||||
// ctx: The c++ eager context, or nullptr in graph mode.
|
||||
// device_name: The current device name.
|
||||
//
|
||||
// All three argument values must remain alive until `this` is deleted.
|
||||
PythonTensorConverter(PyObject* py_eager_context, TFE_Context* ctx,
|
||||
const char* device_name)
|
||||
: py_eager_context_(py_eager_context),
|
||||
ctx_(ctx),
|
||||
device_name_(device_name) {}
|
||||
|
||||
// Converts `src` to a tensor (if it's not already one), and returns a new
|
||||
// reference to the converted value.
|
||||
//
|
||||
// Args:
|
||||
// src: The object that should be converted to a Tensor.
|
||||
// dtype: The requested dtype. Use `DT_INVALID` if the dtype should be
|
||||
// inferred from the `src` value (in which case `dtype` will be updated
|
||||
// in-place to be the actual dtype of the converted value).
|
||||
// used_fallback: Output parameter used to record whether the conversion
|
||||
// was done by falling back to the Python `tf.convert_to_tensor()`
|
||||
// function. This is for testing/logging purposes only. May be null.
|
||||
//
|
||||
// If `src` can't be converted to a tensor with the requested dtype, sets a
|
||||
// Python exception and returns nullptr.
|
||||
Safe_PyObjectPtr Convert(PyObject* src, DataType& dtype,
|
||||
bool* used_fallback = nullptr) const;
|
||||
|
||||
private:
|
||||
PyObject* py_eager_context_;
|
||||
TFE_Context* ctx_;
|
||||
const char* device_name_;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_PYTHON_FRAMEWORK_PYTHON_TENSOR_CONVERTER_H_
|
208
tensorflow/python/framework/python_tensor_converter_test.py
Normal file
208
tensorflow/python/framework/python_tensor_converter_test.py
Normal file
@ -0,0 +1,208 @@
|
||||
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for tensorflow.python.framework.python_tensor_converter."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from absl.testing import parameterized
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.core.framework import types_pb2
|
||||
from tensorflow.python import _pywrap_python_tensor_converter
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import indexed_slices
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.platform import googletest
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class PythonTensorConverterTest(test_util.TensorFlowTestCase,
|
||||
parameterized.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
context.ensure_initialized()
|
||||
super(PythonTensorConverterTest, self).setUp()
|
||||
|
||||
def makePythonTensorConverter(self):
|
||||
return _pywrap_python_tensor_converter.PythonTensorConverter(
|
||||
context.context())
|
||||
|
||||
#=============================================================================
|
||||
# Convert int to tensor.
|
||||
|
||||
def testConvertIntWithInferredDType(self):
|
||||
converter = self.makePythonTensorConverter()
|
||||
result, dtype, used_fallback = converter.Convert(12, types_pb2.DT_INVALID)
|
||||
self.assertIsInstance(result, ops.Tensor)
|
||||
self.assertAllEqual(result, 12)
|
||||
self.assertEqual(dtype, types_pb2.DT_INT32)
|
||||
self.assertEqual(used_fallback, not context.executing_eagerly())
|
||||
|
||||
def testConvertIntWithExplicitDtype(self):
|
||||
converter = self.makePythonTensorConverter()
|
||||
result, dtype, used_fallback = converter.Convert(12, types_pb2.DT_INT64)
|
||||
self.assertIsInstance(result, ops.Tensor)
|
||||
self.assertAllEqual(result, 12)
|
||||
self.assertEqual(dtype, types_pb2.DT_INT64)
|
||||
self.assertEqual(used_fallback, not context.executing_eagerly())
|
||||
|
||||
def testConvertIntWithIncompatibleDtype(self):
|
||||
converter = self.makePythonTensorConverter()
|
||||
with self.assertRaisesRegex(
|
||||
TypeError, "Expected string, got 3 of type 'int' instead."
|
||||
"|Cannot convert 3 to EagerTensor of dtype string"):
|
||||
converter.Convert(3, types_pb2.DT_STRING)
|
||||
|
||||
#=============================================================================
|
||||
# Convert tensor to tensor.
|
||||
|
||||
def testConvertTensorWithInferredDType(self):
|
||||
converter = self.makePythonTensorConverter()
|
||||
result, dtype, used_fallback = converter.Convert(
|
||||
constant_op.constant([1, 2, 3]), types_pb2.DT_INVALID)
|
||||
self.assertIsInstance(result, ops.Tensor)
|
||||
self.assertAllEqual(result, [1, 2, 3])
|
||||
self.assertEqual(dtype, types_pb2.DT_INT32)
|
||||
self.assertFalse(used_fallback)
|
||||
|
||||
def testConvertTensorWithExplicitDtype(self):
|
||||
converter = self.makePythonTensorConverter()
|
||||
result, dtype, used_fallback = converter.Convert(
|
||||
constant_op.constant([1, 2, 3], dtypes.int64), types_pb2.DT_INT64)
|
||||
self.assertIsInstance(result, ops.Tensor)
|
||||
self.assertAllEqual(result, [1, 2, 3])
|
||||
self.assertEqual(dtype, types_pb2.DT_INT64)
|
||||
self.assertFalse(used_fallback)
|
||||
|
||||
def testConvertTensorWithIncorrectDtype(self):
|
||||
converter = self.makePythonTensorConverter()
|
||||
with self.assertRaises((TypeError, ValueError)):
|
||||
converter.Convert(
|
||||
constant_op.constant([1, 2, 3], dtypes.int32), types_pb2.DT_INT64)
|
||||
|
||||
#=============================================================================
|
||||
# Convert list to tensor.
|
||||
|
||||
def testConvertListWithInferredDType(self):
|
||||
converter = self.makePythonTensorConverter()
|
||||
result, dtype, used_fallback = converter.Convert([[1, 2, 3], [4, 5, 6]],
|
||||
types_pb2.DT_INVALID)
|
||||
self.assertIsInstance(result, ops.Tensor)
|
||||
self.assertAllEqual(result, [[1, 2, 3], [4, 5, 6]])
|
||||
self.assertEqual(dtype, types_pb2.DT_INT32)
|
||||
self.assertEqual(used_fallback, not context.executing_eagerly())
|
||||
|
||||
def testConvertListWithExplicitDtype(self):
|
||||
converter = self.makePythonTensorConverter()
|
||||
result, dtype, used_fallback = converter.Convert([[1, 2, 3], [4, 5, 6]],
|
||||
types_pb2.DT_INT64)
|
||||
self.assertIsInstance(result, ops.Tensor)
|
||||
self.assertAllEqual(result, [[1, 2, 3], [4, 5, 6]])
|
||||
self.assertEqual(dtype, types_pb2.DT_INT64)
|
||||
self.assertEqual(used_fallback, not context.executing_eagerly())
|
||||
|
||||
def testConvertListWithIncompatibleDtype(self):
|
||||
converter = self.makePythonTensorConverter()
|
||||
with self.assertRaisesRegex(
|
||||
TypeError, "Expected string, got .* of type 'int' instead."
|
||||
"|Cannot convert .* to EagerTensor of dtype string"):
|
||||
converter.Convert([[1, 2, 3], [4, 5, 6]], types_pb2.DT_STRING)
|
||||
|
||||
def testConvertListWithInconsistentDtype(self):
|
||||
converter = self.makePythonTensorConverter()
|
||||
with self.assertRaisesRegex(
|
||||
(TypeError, ValueError),
|
||||
"Can't convert Python sequence with mixed types to Tensor."
|
||||
"|Failed to convert"):
|
||||
converter.Convert([[1, 2], ["a", "b"]], types_pb2.DT_INVALID)
|
||||
|
||||
#=============================================================================
|
||||
# Convert np.array to tensor.
|
||||
|
||||
def testConvertNumpyArrayWithInferredDType(self):
|
||||
converter = self.makePythonTensorConverter()
|
||||
x = np.array([[1, 2, 3], [4, 5, 6]], np.int32)
|
||||
result, dtype, used_fallback = converter.Convert(x, types_pb2.DT_INVALID)
|
||||
self.assertIsInstance(result, ops.Tensor)
|
||||
self.assertAllEqual(result, [[1, 2, 3], [4, 5, 6]])
|
||||
self.assertEqual(dtype, types_pb2.DT_INT32)
|
||||
self.assertEqual(used_fallback, not context.executing_eagerly())
|
||||
|
||||
def testConvertNumpyArrayWithExplicitDtype(self):
|
||||
converter = self.makePythonTensorConverter()
|
||||
x = np.array([[1, 2, 3], [4, 5, 6]], np.int32)
|
||||
result, dtype, used_fallback = converter.Convert(x, types_pb2.DT_INT64)
|
||||
self.assertIsInstance(result, ops.Tensor)
|
||||
self.assertAllEqual(result, [[1, 2, 3], [4, 5, 6]])
|
||||
self.assertEqual(dtype, types_pb2.DT_INT64)
|
||||
self.assertEqual(used_fallback, not context.executing_eagerly())
|
||||
|
||||
def testConvertNumpyArrayWithIncompatibleDtype(self):
|
||||
converter = self.makePythonTensorConverter()
|
||||
x = np.array([[1, 2, 3], [4, 5, 6]], np.int32)
|
||||
with self.assertRaises((ValueError, TypeError)):
|
||||
converter.Convert(x, types_pb2.DT_STRING)
|
||||
|
||||
def testConvertNumpyArrayWithUnsupportedDtype(self):
|
||||
converter = self.makePythonTensorConverter()
|
||||
x = np.array([[1, 2], ["a", "b"]], np.object)
|
||||
with self.assertRaises((ValueError, TypeError)):
|
||||
converter.Convert(x, types_pb2.DT_INVALID)
|
||||
|
||||
#=============================================================================
|
||||
# Convert IndexedSlices to tensor.
|
||||
|
||||
def testConvertIndexedSlicesWithInferredDType(self):
|
||||
converter = self.makePythonTensorConverter()
|
||||
x = indexed_slices.IndexedSlices(
|
||||
constant_op.constant([[1, 2, 3]], dtypes.int32, name="x_values"),
|
||||
constant_op.constant([1], dtypes.int64, name="x_indices"),
|
||||
constant_op.constant([3, 3], dtypes.int64, name="x_shape"))
|
||||
result, dtype, used_fallback = converter.Convert(x, types_pb2.DT_INVALID)
|
||||
self.assertIsInstance(result, ops.Tensor)
|
||||
self.assertAllEqual(result, [[0, 0, 0], [1, 2, 3], [0, 0, 0]])
|
||||
self.assertEqual(dtype, types_pb2.DT_INT32)
|
||||
self.assertTrue(used_fallback)
|
||||
|
||||
def testConvertIndexedSlicesWithExplicitDtype(self):
|
||||
converter = self.makePythonTensorConverter()
|
||||
x = indexed_slices.IndexedSlices(
|
||||
constant_op.constant([[1, 2, 3]], dtypes.int32, name="x_values"),
|
||||
constant_op.constant([1], dtypes.int64, name="x_indices"),
|
||||
constant_op.constant([3, 3], dtypes.int64, name="x_shape"))
|
||||
result, dtype, used_fallback = converter.Convert(x, types_pb2.DT_INT32)
|
||||
self.assertIsInstance(result, ops.Tensor)
|
||||
self.assertAllEqual(result, [[0, 0, 0], [1, 2, 3], [0, 0, 0]])
|
||||
self.assertEqual(dtype, types_pb2.DT_INT32)
|
||||
self.assertTrue(used_fallback)
|
||||
|
||||
def testConvertIndexedSlicesWithIncorrectDtype(self):
|
||||
converter = self.makePythonTensorConverter()
|
||||
x = indexed_slices.IndexedSlices(
|
||||
constant_op.constant([[1, 2, 3]], dtypes.int32, name="x_values"),
|
||||
constant_op.constant([1], dtypes.int64, name="x_indices"),
|
||||
constant_op.constant([3, 3], dtypes.int64, name="x_shape"))
|
||||
with self.assertRaises((ValueError, TypeError)):
|
||||
converter.Convert(x, types_pb2.DT_FLOAT)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
googletest.main()
|
120
tensorflow/python/framework/python_tensor_converter_wrapper.cc
Normal file
120
tensorflow/python/framework/python_tensor_converter_wrapper.cc
Normal file
@ -0,0 +1,120 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
// Note: This library is only used by python_tensor_converter_test. It is
|
||||
// not meant to be used in other circumstances.
|
||||
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "pybind11/pytypes.h"
|
||||
#include "pybind11/stl.h"
|
||||
#include "tensorflow/python/eager/pywrap_tfe.h"
|
||||
#include "tensorflow/python/framework/python_tensor_converter.h"
|
||||
|
||||
#if PY_MAJOR_VERSION < 3
|
||||
// Python 2.x:
|
||||
#define PY_STRING_INTERN_FROM_STRING(x) (PyString_InternFromString(x))
|
||||
#define PY_INT_AS_LONG(x) (PyInt_AsLong(x))
|
||||
#define PY_INT_FROM_LONG(x) (PyInt_FromLong(x))
|
||||
#else
|
||||
// Python 3.x:
|
||||
#define PY_INT_AS_LONG(x) (PyLong_AsLong(x))
|
||||
#define PY_STRING_INTERN_FROM_STRING(x) (PyUnicode_InternFromString(x))
|
||||
#define PY_INT_FROM_LONG(x) (PyLong_FromLong(x))
|
||||
#endif
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
Safe_PyObjectPtr GetAttr_ThreadLocalData(PyObject* eager_context) {
|
||||
static PyObject* attr = PY_STRING_INTERN_FROM_STRING("_thread_local_data");
|
||||
return Safe_PyObjectPtr(PyObject_GetAttr(eager_context, attr));
|
||||
}
|
||||
|
||||
Safe_PyObjectPtr GetAttr_ContextHandle(PyObject* eager_context) {
|
||||
static PyObject* attr = PY_STRING_INTERN_FROM_STRING("_context_handle");
|
||||
return Safe_PyObjectPtr(PyObject_GetAttr(eager_context, attr));
|
||||
}
|
||||
|
||||
Safe_PyObjectPtr GetAttr_IsEager(PyObject* tld) {
|
||||
static PyObject* attr = PY_STRING_INTERN_FROM_STRING("is_eager");
|
||||
return Safe_PyObjectPtr(PyObject_GetAttr(tld, attr));
|
||||
}
|
||||
|
||||
Safe_PyObjectPtr GetAttr_DeviceName(PyObject* tld) {
|
||||
static PyObject* attr = PY_STRING_INTERN_FROM_STRING("device_name");
|
||||
return Safe_PyObjectPtr(PyObject_GetAttr(tld, attr));
|
||||
}
|
||||
|
||||
Safe_PyObjectPtr GetAttr_TypeEnum(PyObject* dtype) {
|
||||
static PyObject* attr = PY_STRING_INTERN_FROM_STRING("_type_enum");
|
||||
return Safe_PyObjectPtr(PyObject_GetAttr(dtype, attr));
|
||||
}
|
||||
|
||||
PythonTensorConverter MakePythonTensorConverter(py::handle py_eager_context) {
|
||||
Safe_PyObjectPtr tld = GetAttr_ThreadLocalData(py_eager_context.ptr());
|
||||
if (!tld) throw py::error_already_set();
|
||||
|
||||
Safe_PyObjectPtr py_is_eager = GetAttr_IsEager(tld.get());
|
||||
if (!py_is_eager) throw py::error_already_set();
|
||||
bool is_eager = PyObject_IsTrue(py_is_eager.get());
|
||||
|
||||
// Initialize the eager context, if necessary.
|
||||
TFE_Context* ctx = nullptr;
|
||||
const char* device_name = nullptr;
|
||||
if (is_eager) {
|
||||
Safe_PyObjectPtr context_handle =
|
||||
GetAttr_ContextHandle(py_eager_context.ptr());
|
||||
if (!context_handle) throw py::error_already_set();
|
||||
if (context_handle.get() == Py_None) {
|
||||
throw std::runtime_error("Error retrieving context handle.");
|
||||
}
|
||||
Safe_PyObjectPtr py_device_name = GetAttr_DeviceName(tld.get());
|
||||
if (!py_device_name) {
|
||||
throw std::runtime_error("Error retrieving device name.");
|
||||
}
|
||||
device_name = TFE_GetPythonString(py_device_name.get());
|
||||
ctx = reinterpret_cast<TFE_Context*>(
|
||||
PyCapsule_GetPointer(context_handle.get(), nullptr));
|
||||
}
|
||||
|
||||
return PythonTensorConverter(py_eager_context.ptr(), ctx, device_name);
|
||||
}
|
||||
|
||||
py::handle Convert(tensorflow::PythonTensorConverter* self, py::handle obj,
|
||||
py::handle dtype) {
|
||||
DataType dtype_enum = static_cast<DataType>(PY_INT_AS_LONG(dtype.ptr()));
|
||||
bool used_fallback = false;
|
||||
Safe_PyObjectPtr converted =
|
||||
self->Convert(obj.ptr(), dtype_enum, &used_fallback);
|
||||
if (!converted) throw py::error_already_set();
|
||||
|
||||
PyObject* result = PyTuple_New(3);
|
||||
PyTuple_SET_ITEM(result, 0, converted.release());
|
||||
PyTuple_SET_ITEM(result, 1, PY_INT_FROM_LONG(dtype_enum));
|
||||
PyTuple_SET_ITEM(result, 2, used_fallback ? Py_True : Py_False);
|
||||
Py_INCREF(PyTuple_GET_ITEM(result, 1));
|
||||
Py_INCREF(PyTuple_GET_ITEM(result, 2));
|
||||
return result;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
||||
PYBIND11_MODULE(_pywrap_python_tensor_converter, m) {
|
||||
py::class_<tensorflow::PythonTensorConverter>(m, "PythonTensorConverter")
|
||||
.def(py::init(&tensorflow::MakePythonTensorConverter))
|
||||
.def("Convert", tensorflow::Convert);
|
||||
}
|
@ -400,3 +400,6 @@ tensorflow::TensorHandle::Tensor
|
||||
|
||||
[python_api_dispatcher] # python_api_dispatcher
|
||||
tensorflow::PythonAPIDispatcher
|
||||
|
||||
[python_tensor_converter] # python_tensor_converter
|
||||
tensorflow::PythonTensorConverter
|
||||
|
Loading…
Reference in New Issue
Block a user