Move EagerTensor from python to C.
PiperOrigin-RevId: 170617321
This commit is contained in:
parent
da8349412f
commit
ff18944249
@ -842,6 +842,7 @@ set (pywrap_tensorflow_internal_src
|
||||
"${tensorflow_source_dir}/tensorflow/core/profiler/internal/print_model_analysis.h"
|
||||
"${tensorflow_source_dir}/tensorflow/core/profiler/internal/print_model_analysis.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/python/eager/pywrap_tfe.h"
|
||||
"${tensorflow_source_dir}/tensorflow/python/eager/pywrap_tensor.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/python/eager/pywrap_tfe_src.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/python/client/tf_session_helper.h"
|
||||
"${tensorflow_source_dir}/tensorflow/python/client/tf_session_helper.cc"
|
||||
|
@ -266,6 +266,7 @@ cc_library(
|
||||
hdrs = ["lib/core/safe_ptr.h"],
|
||||
deps = [
|
||||
"//tensorflow/c:c_api",
|
||||
"//tensorflow/c/eager:c_api",
|
||||
"//util/python:python_headers",
|
||||
],
|
||||
)
|
||||
|
@ -6,7 +6,10 @@ load("//tensorflow:tensorflow.bzl", "tf_cc_binary")
|
||||
|
||||
cc_library(
|
||||
name = "pywrap_tfe_lib",
|
||||
srcs = ["pywrap_tfe_src.cc"],
|
||||
srcs = [
|
||||
"pywrap_tensor.cc",
|
||||
"pywrap_tfe_src.cc",
|
||||
],
|
||||
hdrs = ["pywrap_tfe.h"],
|
||||
visibility = ["//tensorflow:internal"],
|
||||
deps = [
|
||||
@ -14,8 +17,10 @@ cc_library(
|
||||
"//tensorflow/c/eager:c_api",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/python:ndarray_tensor",
|
||||
"//tensorflow/python:ndarray_tensor_bridge",
|
||||
"//tensorflow/python:numpy_lib",
|
||||
"//tensorflow/python:py_seq_tensor",
|
||||
"//tensorflow/python:safe_ptr",
|
||||
"//util/python:python_headers",
|
||||
],
|
||||
)
|
||||
|
@ -37,6 +37,7 @@ from tensorflow.python.eager import backprop # pylint: disable=unused-import
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import function
|
||||
from tensorflow.python.eager import test
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import gen_math_ops
|
||||
@ -61,18 +62,41 @@ def benchmark_create_tensor(n):
|
||||
def label(s):
|
||||
return "{:20s}".format(s)
|
||||
|
||||
with timer(label("np.array([[3]])"), iters=n) as iters:
|
||||
with timer(label("np.array([[3.0]])"), iters=n) as iters:
|
||||
for _ in iters:
|
||||
np.array([[3]])
|
||||
|
||||
with timer(label("Tensor([[3]])"), iters=n) as iters:
|
||||
for _ in iters:
|
||||
ops.EagerTensor([[3]], context.context())
|
||||
np.array([[3.0]])
|
||||
|
||||
ctx = context.context()
|
||||
with timer(label("Tensor([[3]], ctx)"), iters=n) as iters:
|
||||
handle = ctx._handle
|
||||
device = ctx.device_name
|
||||
# May be warmup GPU.
|
||||
ops.EagerTensor([[3.0]], context=handle, device=device)
|
||||
|
||||
# float32
|
||||
dtype = dtypes.float32.as_datatype_enum
|
||||
three = [[3.0]]
|
||||
with timer(label("EagerTensor([[3.0]])"), iters=n) as iters:
|
||||
for _ in iters:
|
||||
ops.EagerTensor([[3]], ctx)
|
||||
ops.EagerTensor(three, context=handle, device=device, dtype=dtype)
|
||||
|
||||
np_3 = np.array([[3.0]], dtype=np.float32)
|
||||
with timer(label("EagerTensor(np.array([[3.0]]))"), iters=n) as iters:
|
||||
for _ in iters:
|
||||
ops.EagerTensor(np_3, context=handle, device=device, dtype=dtype)
|
||||
|
||||
# int32.
|
||||
# This is interesting since int32 will be kept on host memory for the GPU
|
||||
# case.
|
||||
dtype = dtypes.int32.as_datatype_enum
|
||||
three = [[3]]
|
||||
with timer(label("EagerTensor([[3]])"), iters=n) as iters:
|
||||
for _ in iters:
|
||||
ops.EagerTensor(three, context=handle, device=device, dtype=dtype)
|
||||
|
||||
np_3 = np.array([[3]], dtype=np.int32)
|
||||
with timer(label("EagerTensor(np.array([[3]]))"), iters=n) as iters:
|
||||
for _ in iters:
|
||||
ops.EagerTensor(np_3, context=handle, device=device, dtype=dtype)
|
||||
|
||||
|
||||
def benchmark_matmul(shape, n, use_gpu=False):
|
||||
@ -103,17 +127,16 @@ def benchmark_matmul(shape, n, use_gpu=False):
|
||||
for _ in iters:
|
||||
gen_math_ops._mat_mul(m, m, transpose_b=transpose_b)
|
||||
|
||||
inputs = [m, m]
|
||||
# pylint: disable=protected-access
|
||||
input_handles = [m._handle, m._handle]
|
||||
ctx_handle = context.context()._handle
|
||||
# pylint: enable=protected-access
|
||||
attrs = ("transpose_a", False, "transpose_b", transpose_b, "T",
|
||||
m.dtype.as_datatype_enum)
|
||||
with timer(label("TFE_Py_Execute"), iters=n) as iters:
|
||||
for _ in iters:
|
||||
pywrap_tensorflow.TFE_DeleteTensorHandle(
|
||||
pywrap_tensorflow.TFE_Py_Execute(ctx_handle, None, "MatMul",
|
||||
input_handles, attrs, 1)[0])
|
||||
pywrap_tensorflow.TFE_Py_Execute(ctx_handle, None, "MatMul",
|
||||
inputs, attrs, 1)
|
||||
|
||||
f = function.defun(math_ops.matmul)
|
||||
with timer(label("defun(tf.matmul)"), iters=n) as iters:
|
||||
@ -133,6 +156,8 @@ class BenchmarksTest(test_util.TensorFlowTestCase):
|
||||
|
||||
if context.context().num_gpus() > 0:
|
||||
print("---- RUNNING ON GPU NOW ----")
|
||||
with context.device("/device:GPU:0"):
|
||||
benchmark_create_tensor(FLAGS.iters or 30000)
|
||||
benchmark_matmul([2, 2], FLAGS.iters or 30000, use_gpu=True)
|
||||
benchmark_matmul([100, 28 * 28], FLAGS.iters or 1000, use_gpu=True)
|
||||
|
||||
|
@ -121,16 +121,6 @@ class Context(object):
|
||||
else:
|
||||
return devices
|
||||
|
||||
def __del__(self):
|
||||
try:
|
||||
if self._context_handle is not None:
|
||||
with errors.raise_exception_on_not_ok_status() as status:
|
||||
pywrap_tensorflow.TFE_DeleteContext(self._context_handle, status)
|
||||
except (AttributeError, TypeError):
|
||||
# Sometimes deletion during program shutdown throws exception as other
|
||||
# modules are no longer available.
|
||||
pass
|
||||
|
||||
def __str__(self):
|
||||
if self._context_handle is None:
|
||||
return "Eager TensorFlow Context. Devices currently uninitialized."
|
||||
|
@ -19,6 +19,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import threading
|
||||
|
||||
from tensorflow.core.protobuf import config_pb2
|
||||
from tensorflow.python import pywrap_tensorflow
|
||||
from tensorflow.python.eager import context
|
||||
@ -138,7 +139,7 @@ class TFETest(test_util.TensorFlowTestCase):
|
||||
x = x.as_cpu_tensor()
|
||||
|
||||
# Invalid device
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
with self.assertRaises(RuntimeError):
|
||||
x.as_gpu_tensor(context.context().num_gpus() + 1)
|
||||
|
||||
def testNumpyForceCPU(self):
|
||||
@ -153,7 +154,7 @@ class TFETest(test_util.TensorFlowTestCase):
|
||||
ta = constant_op.constant([[1, 2], [3, 4]])
|
||||
tb = ta.as_cpu_tensor()
|
||||
|
||||
self.assertNotEqual(ta._handle, tb._handle)
|
||||
self.assertNotEqual(id(ta), id(tb))
|
||||
self.assertAllEqual(ta.numpy(), tb.numpy())
|
||||
|
||||
def testRegisterExceptionClass(self):
|
||||
|
@ -53,32 +53,27 @@ def execute(op_name, num_outputs, inputs, attrs, ctx, name=None):
|
||||
Raises:
|
||||
An exception on error.
|
||||
"""
|
||||
# TODO(apassos) move this to convert_to_tensor
|
||||
# pylint: disable=protected-access
|
||||
input_handles = [c._handle for c in inputs]
|
||||
device_name = ctx.device_name
|
||||
# pylint: disable=protected-access
|
||||
try:
|
||||
outh = pywrap_tensorflow.TFE_Py_Execute(ctx._handle, device_name,
|
||||
op_name, input_handles, attrs,
|
||||
num_outputs)
|
||||
tensors = pywrap_tensorflow.TFE_Py_Execute(ctx._handle, device_name,
|
||||
op_name, inputs, attrs,
|
||||
num_outputs)
|
||||
except core._NotOkStatusException as e:
|
||||
if name is not None:
|
||||
message = e.message + " name: " + name
|
||||
else:
|
||||
message = e.message
|
||||
six.raise_from(core._status_to_exception(e.code, message), None)
|
||||
# pylint: enable=protected-access
|
||||
|
||||
tensors = [ops._tensor_from_handle(x) for x in outh] # pylint: disable=protected-access
|
||||
# TODO(alive, cais): Use the execution callback mechanism.
|
||||
if core.active_trace() is not None:
|
||||
for t in tensors:
|
||||
# pylint: disable=protected-access
|
||||
core.active_trace().record_tensor(op_name,
|
||||
ops.tensor_id(t),
|
||||
t.device,
|
||||
t.shape.num_elements())
|
||||
# pylint: enable=protected-access
|
||||
# pylint: enable=protected-access
|
||||
|
||||
# TODO(cais): Optimize this, perhaps by replacing this execute function with
|
||||
# a different one when there are execution callback(s).
|
||||
|
@ -162,7 +162,7 @@ def inf_nan_callback(op_type,
|
||||
# TODO(cais): Consider moving this into execute.py.
|
||||
# pylint: disable=protected-access
|
||||
pywrap_tensorflow.TFE_Py_Execute(
|
||||
ctx._handle, output.device, "CheckNumerics", [output._handle],
|
||||
ctx._handle, output.device, "CheckNumerics", [output],
|
||||
check_numerics_op_attrs, 1)
|
||||
# pylint: enable=protected-access
|
||||
except core._NotOkStatusException: # pylint: disable=protected-access
|
||||
|
@ -33,7 +33,7 @@ from tensorflow.python.ops import random_ops
|
||||
from tensorflow.python.ops import sparse_ops
|
||||
|
||||
|
||||
class TargetTest(test_util.TensorFlowTestCase):
|
||||
class OpsTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def testExecuteBasic(self):
|
||||
three = constant_op.constant(3)
|
||||
|
646
tensorflow/python/eager/pywrap_tensor.cc
Normal file
646
tensorflow/python/eager/pywrap_tensor.cc
Normal file
@ -0,0 +1,646 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <stdlib.h>
|
||||
|
||||
#include "tensorflow/python/lib/core/ndarray_tensor_bridge.h"
|
||||
#include "tensorflow/python/lib/core/numpy.h"
|
||||
#include "tensorflow/python/lib/core/py_seq_tensor.h"
|
||||
#include "tensorflow/python/lib/core/safe_ptr.h"
|
||||
|
||||
#include "tensorflow/python/eager/pywrap_tfe.h"
|
||||
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
#include "tensorflow/python/lib/core/ndarray_tensor.h"
|
||||
|
||||
namespace {
|
||||
|
||||
TFE_Context* GetContext(PyObject* ctx) {
|
||||
TFE_Context* context =
|
||||
reinterpret_cast<TFE_Context*>(PyCapsule_GetPointer(ctx, nullptr));
|
||||
if (context == nullptr) {
|
||||
PyErr_SetString(PyExc_TypeError,
|
||||
tensorflow::strings::StrCat(
|
||||
"Expecting a PyCapsule encoded context handle. Got ",
|
||||
Py_TYPE(ctx)->tp_name)
|
||||
.c_str());
|
||||
}
|
||||
return context;
|
||||
}
|
||||
|
||||
// Convert a Python numpy.ndarray object to a TFE_TensorHandle.
|
||||
// The two may share underlying storage so changes to one may reflect in the
|
||||
// other.
|
||||
TFE_TensorHandle* NumpyToTensorHandle(PyObject* obj) {
|
||||
tensorflow::Tensor t;
|
||||
auto cppstatus = tensorflow::NdarrayToTensor(obj, &t);
|
||||
if (cppstatus.ok()) {
|
||||
return TFE_NewTensorHandle(t);
|
||||
} else {
|
||||
PyErr_SetString(PyExc_ValueError,
|
||||
tensorflow::strings::StrCat(
|
||||
"Failed to convert numpy ndarray to a Tensor (",
|
||||
cppstatus.error_message(), ").")
|
||||
.c_str());
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
// Casts data referred to by `handle` from type `src_type_enum` to type
|
||||
// `dst_type_enum`.
|
||||
TFE_TensorHandle* EagerCast(TFE_Context* ctx, TFE_TensorHandle* handle,
|
||||
TF_DataType src_type_enum,
|
||||
TF_DataType dst_type_enum, TF_Status* out_status) {
|
||||
if (ctx == nullptr) return nullptr;
|
||||
const char* op_name = "Cast";
|
||||
const char* device_name = "/job:localhost/replica:0/task:0/device:CPU:0";
|
||||
TFE_Op* op = TFE_NewOp(ctx, op_name, out_status);
|
||||
#define RETURN_ERROR \
|
||||
{ \
|
||||
TFE_DeleteOp(op); \
|
||||
return nullptr; \
|
||||
}
|
||||
if (TF_GetCode(out_status) != TF_OK) RETURN_ERROR
|
||||
TFE_OpSetDevice(op, device_name, out_status);
|
||||
if (TF_GetCode(out_status) != TF_OK) RETURN_ERROR
|
||||
TFE_OpAddInput(op, handle, out_status);
|
||||
if (TF_GetCode(out_status) != TF_OK) RETURN_ERROR
|
||||
TFE_OpSetAttrType(op, "SrcT", src_type_enum);
|
||||
TFE_OpSetAttrType(op, "DstT", dst_type_enum);
|
||||
TFE_TensorHandle* output = nullptr;
|
||||
int num_outputs = 1;
|
||||
TFE_Execute(op, &output, &num_outputs, out_status);
|
||||
if (TF_GetCode(out_status) != TF_OK || num_outputs != 1 ||
|
||||
output == nullptr) {
|
||||
if (output != nullptr) {
|
||||
TFE_DeleteTensorHandle(output);
|
||||
}
|
||||
RETURN_ERROR
|
||||
}
|
||||
TFE_DeleteOp(op);
|
||||
return output;
|
||||
#undef RETURN_ERROR
|
||||
}
|
||||
|
||||
TFE_TensorHandle* CopyToDevice(TFE_TensorHandle* handle, PyObject* ctx,
|
||||
PyObject* dev) {
|
||||
const char* device = "";
|
||||
if (dev != nullptr && dev != Py_None) {
|
||||
device = PyBytes_AsString(dev);
|
||||
#if PY_MAJOR_VERSION >= 3
|
||||
if (device == nullptr) {
|
||||
PyErr_Clear();
|
||||
device = PyUnicode_AsUTF8(dev);
|
||||
}
|
||||
#endif
|
||||
if (device == nullptr) {
|
||||
PyErr_SetString(PyExc_TypeError,
|
||||
"Error parsing device argument to CopyToDevice");
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
TFE_Context* context = GetContext(ctx);
|
||||
if (context == nullptr) { // PyErr already set by GetContext
|
||||
return nullptr;
|
||||
}
|
||||
auto status = tensorflow::make_safe(TF_NewStatus());
|
||||
TFE_TensorHandle* new_handle =
|
||||
TFE_TensorHandleCopyToDevice(handle, context, device, status.get());
|
||||
if (TF_GetCode(status.get()) != TF_OK) {
|
||||
PyErr_SetString(
|
||||
PyExc_RuntimeError,
|
||||
tensorflow::strings::StrCat("Error copying tensor to device: ", device,
|
||||
". ", TF_Message(status.get()))
|
||||
.c_str());
|
||||
return nullptr;
|
||||
}
|
||||
return new_handle;
|
||||
}
|
||||
|
||||
// Helper function to convert `v` to an int and store it in `*out`. Returns true
|
||||
// on success, false otherwise.
|
||||
// Note that we assume that v is a python int (not long) representing a
|
||||
// TF_DataType value.
|
||||
bool PyIntToDataType(PyObject* v, int* out) {
|
||||
#if PY_MAJOR_VERSION < 3
|
||||
if (PyInt_Check(v)) {
|
||||
*out = PyInt_AS_LONG(v);
|
||||
return true;
|
||||
}
|
||||
#else
|
||||
if (PyLong_Check(v)) {
|
||||
*out = PyLong_AsLong(v);
|
||||
return true;
|
||||
}
|
||||
#endif
|
||||
return false;
|
||||
}
|
||||
|
||||
// Helper function to create a python integer from TF_DataType.
|
||||
PyObject* PyIntFromDataType(TF_DataType l) {
|
||||
#if PY_MAJOR_VERSION < 3
|
||||
return PyInt_FromLong(l);
|
||||
#else
|
||||
return PyLong_FromLong(l);
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
extern "C" {
|
||||
|
||||
static const int kMaxEagerTensorParentSize = 32;
|
||||
|
||||
// TODO(agarwal): store context handle in EagerTensor.
|
||||
typedef struct EagerTensor {
|
||||
PyObject_HEAD;
|
||||
// Note that we leave kMaxEagerTensorParentSize bytes here for use by the
|
||||
// parent class. The parent class is set at runtime, so we don't know the
|
||||
// exact size at compile time.
|
||||
char unused[kMaxEagerTensorParentSize];
|
||||
TFE_TensorHandle* handle;
|
||||
int64_t id;
|
||||
// This mirrors tensorflow.core.framework.ops.Tensor._handle_data Which will
|
||||
// be None for tensors of type other than DT_REOSURCE. For DT_RESOURCE
|
||||
// tensors, this will contain a serialized HandleData proto with shape
|
||||
// inference metadata about shapes and dtypes of resources accessible from
|
||||
// this handle.
|
||||
// Note that we assume that handle_data cannot participate in reference
|
||||
// cycles, and hence don't provide GC support for it.
|
||||
PyObject* handle_data;
|
||||
|
||||
// This stores `_keras_mask` object and is set by Tensorflow layers.
|
||||
PyObject* keras_mask;
|
||||
} EagerTensor;
|
||||
|
||||
// tp_init for EagerTensor.
|
||||
int EagerTensor_init(EagerTensor* self, PyObject* args, PyObject* kwds) {
|
||||
self->id = get_uid();
|
||||
self->handle = nullptr;
|
||||
Py_INCREF(Py_None);
|
||||
self->handle_data = Py_None;
|
||||
Py_INCREF(Py_None);
|
||||
self->keras_mask = Py_None;
|
||||
PyObject* value;
|
||||
PyObject* context = nullptr;
|
||||
PyObject* device = nullptr;
|
||||
PyObject* dtype = Py_None;
|
||||
const char* kwlist[] = {"value", "context", "device", "dtype", nullptr};
|
||||
if (!PyArg_ParseTupleAndKeywords(args, kwds, "OOO|O",
|
||||
const_cast<char**>(kwlist), &value, &context,
|
||||
&device, &dtype)) {
|
||||
return -1;
|
||||
}
|
||||
// Extract dtype
|
||||
int desired_dtype = -1;
|
||||
if (dtype != Py_None) {
|
||||
if (!PyIntToDataType(dtype, &desired_dtype)) {
|
||||
PyErr_SetString(PyExc_TypeError,
|
||||
tensorflow::strings::StrCat(
|
||||
"Expecting a DataType value for dtype. Got ",
|
||||
Py_TYPE(dtype)->tp_name)
|
||||
.c_str());
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
tensorflow::Safe_TFE_TensorHandlePtr handle =
|
||||
tensorflow::make_safe(static_cast<TFE_TensorHandle*>(nullptr));
|
||||
PyErr_Clear();
|
||||
if (PyArray_Check(value)) {
|
||||
int desired_np_dtype = -1;
|
||||
if (desired_dtype >= 0) {
|
||||
if (!tensorflow::TF_DataType_to_PyArray_TYPE(
|
||||
static_cast<TF_DataType>(desired_dtype), &desired_np_dtype)
|
||||
.ok()) {
|
||||
PyErr_SetString(PyExc_TypeError,
|
||||
tensorflow::strings::StrCat(
|
||||
"Invalid dtype argument value ", desired_dtype)
|
||||
.c_str());
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
PyArrayObject* array = reinterpret_cast<PyArrayObject*>(value);
|
||||
int current_np_dtype = PyArray_TYPE(array);
|
||||
auto safe_value = tensorflow::make_safe(static_cast<PyObject*>(nullptr));
|
||||
if ((desired_np_dtype >= 0 && desired_np_dtype != current_np_dtype) ||
|
||||
!PyArray_ISCARRAY(array)) {
|
||||
int new_dtype =
|
||||
desired_np_dtype >= 0 ? desired_np_dtype : current_np_dtype;
|
||||
safe_value = tensorflow::make_safe(
|
||||
PyArray_FromAny(value, PyArray_DescrFromType(new_dtype), 0, 0,
|
||||
NPY_ARRAY_CARRAY | NPY_ARRAY_FORCECAST, nullptr));
|
||||
if (PyErr_Occurred()) return -1;
|
||||
if (safe_value == nullptr) {
|
||||
PyErr_SetString(PyExc_ValueError, "Error while casting a numpy value");
|
||||
return -1;
|
||||
}
|
||||
value = safe_value.get();
|
||||
}
|
||||
handle = tensorflow::make_safe(NumpyToTensorHandle(value));
|
||||
} else {
|
||||
tensorflow::Tensor t;
|
||||
// TODO(josh11b): Have PySeqToTensor set python errors instead of
|
||||
// returning Status.
|
||||
auto cppstatus = tensorflow::PySeqToTensor(value, dtype, &t);
|
||||
if (!cppstatus.ok()) {
|
||||
PyErr_SetString(PyExc_ValueError, cppstatus.error_message().c_str());
|
||||
return -1;
|
||||
}
|
||||
handle = tensorflow::make_safe(TFE_NewTensorHandle(t));
|
||||
}
|
||||
if (PyErr_Occurred()) return -1;
|
||||
if (handle == nullptr) {
|
||||
PyErr_SetString(PyExc_ValueError, "Error while creating an EagerTensor");
|
||||
return -1;
|
||||
}
|
||||
TF_DataType handle_dtype = TFE_TensorHandleDataType(handle.get());
|
||||
if (desired_dtype >= 0 && desired_dtype != handle_dtype) {
|
||||
auto out_status = tensorflow::make_safe(TF_NewStatus());
|
||||
handle = tensorflow::make_safe(
|
||||
EagerCast(GetContext(context), handle.get(), handle_dtype,
|
||||
static_cast<TF_DataType>(desired_dtype), out_status.get()));
|
||||
if (TF_GetCode(out_status.get()) != TF_OK) {
|
||||
PyErr_SetString(
|
||||
PyExc_ValueError,
|
||||
tensorflow::strings::StrCat("Error while casting from DataType ",
|
||||
handle_dtype, " to ", desired_dtype, ". ",
|
||||
TF_Message(out_status.get()))
|
||||
.c_str());
|
||||
return -1;
|
||||
}
|
||||
handle_dtype = TFE_TensorHandleDataType(handle.get());
|
||||
}
|
||||
|
||||
// Almost all TensorFlow kernels for GPU devices keep int32 tensors in host
|
||||
// memory. We approximate the same behavior for eager execution - keeping
|
||||
// int32 tensors in host memory.
|
||||
//
|
||||
// We do so to preclude the need for callers into such kernels from having to
|
||||
// explicitly place the int32 tensors in host memory. For example, without
|
||||
// this, one needed:
|
||||
//
|
||||
// with tf.device('/gpu:0'):
|
||||
// ...// code here
|
||||
// with tf.device('/cpu:0'):
|
||||
// shape = tf.constant(...)
|
||||
// y = tf.random_uniform(shape)
|
||||
//
|
||||
// Without the CPU device block, tfe.ops.random_uniform would fail since the
|
||||
// kernel expects the shape in host memory.
|
||||
//
|
||||
// With this support, we simplify the code:
|
||||
//
|
||||
// with tf.device('/gpu:0'):
|
||||
// y = tf.random_uniform(...)
|
||||
//
|
||||
// The approximation is not exact there are GPU kernels which do not require
|
||||
// host memory for int32 tensors. This will lead to a discrepancy between
|
||||
// eager and graph execution.
|
||||
// TODO(ashankar): Fix this.
|
||||
if (handle_dtype != TF_INT32) {
|
||||
// Note that this is a shallow copy and will share the underlying buffer
|
||||
// if copying to the same device.
|
||||
handle = tensorflow::make_safe(CopyToDevice(handle.get(), context, device));
|
||||
if (handle == nullptr) return -1;
|
||||
}
|
||||
self->handle = handle.release();
|
||||
return 0;
|
||||
}
|
||||
|
||||
// tp_dealloc for EagerTensor.
|
||||
void EagerTensor_dealloc(EagerTensor* self) {
|
||||
Py_DECREF(self->handle_data);
|
||||
Py_DECREF(self->keras_mask);
|
||||
TFE_DeleteTensorHandle(self->handle);
|
||||
self->handle = nullptr;
|
||||
PyObject* id = PyLong_FromLongLong(self->id);
|
||||
PyObject* func = PyObject_GetAttrString(reinterpret_cast<PyObject*>(self),
|
||||
"_delete_trace");
|
||||
Py_TYPE(self)->tp_free(self);
|
||||
self = nullptr;
|
||||
// Note that we run `func` after calling `tp_free`. Otherwise calling that
|
||||
// function can potentially trigger garbage collection that observes `self`
|
||||
// in this half deleted state and crashes.
|
||||
// Note that `func` is a staticmethod and does not need `self` to be around
|
||||
// for running.
|
||||
// We clear (and later restore) any errors that have already been set. Else
|
||||
// these erorrs may appear randomly as part of the function execution.
|
||||
PyObject *a, *b, *c;
|
||||
PyErr_Fetch(&a, &b, &c);
|
||||
PyObject_CallFunctionObjArgs(func, id, nullptr);
|
||||
PyErr_Restore(a, b, c);
|
||||
Py_DECREF(func);
|
||||
Py_DECREF(id);
|
||||
}
|
||||
|
||||
// Getter for `_id`.
|
||||
static PyObject* EagerTensor_getid(EagerTensor* self, void* closure) {
|
||||
return PyLong_FromLongLong(self->id);
|
||||
}
|
||||
|
||||
// Getter for `_datatype_enum`.
|
||||
static PyObject* EagerTensor_datatype_enum(EagerTensor* self) {
|
||||
return PyIntFromDataType(TFE_TensorHandleDataType(self->handle));
|
||||
}
|
||||
|
||||
// Getter for `_shape_tuple`.
|
||||
static PyObject* EagerTensor_shape_tuple(EagerTensor* self) {
|
||||
auto handle = self->handle;
|
||||
int n = TFE_TensorHandleNumDims(handle);
|
||||
PyObject* shape = PyTuple_New(n);
|
||||
if (PyErr_Occurred()) return nullptr;
|
||||
for (int i = 0; i < n; ++i) {
|
||||
PyObject* dim = PyLong_FromLongLong(TFE_TensorHandleDim(handle, i));
|
||||
if (dim == nullptr || PyTuple_SetItem(shape, i, dim) != 0) {
|
||||
Py_DECREF(shape);
|
||||
if (dim != nullptr) Py_DECREF(dim);
|
||||
PyErr_SetString(PyExc_RuntimeError, "Error while creating shape");
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
return shape;
|
||||
}
|
||||
|
||||
static PyObject* EagerTensor_tensor_handle(EagerTensor* self, void* unused) {
|
||||
Py_INCREF(self->handle_data);
|
||||
return self->handle_data;
|
||||
}
|
||||
|
||||
static int EagerTensor_settensor_handle(EagerTensor* self, PyObject* value,
|
||||
void* unused) {
|
||||
Py_DECREF(self->handle_data);
|
||||
Py_INCREF(value);
|
||||
self->handle_data = value;
|
||||
return 0;
|
||||
}
|
||||
|
||||
static PyObject* EagerTensor_keras_mask(EagerTensor* self, void* unused) {
|
||||
Py_INCREF(self->keras_mask);
|
||||
return self->keras_mask;
|
||||
}
|
||||
|
||||
static int EagerTensor_setkeras_mask(EagerTensor* self, PyObject* value,
|
||||
void* unused) {
|
||||
Py_DECREF(self->keras_mask);
|
||||
Py_INCREF(value);
|
||||
self->keras_mask = value;
|
||||
return 0;
|
||||
}
|
||||
// Function `_copy_to_device`.
|
||||
static PyObject* EagerTensor_copy_to_device(EagerTensor* self, PyObject* args,
|
||||
PyObject* kwds) {
|
||||
const char* kwlist[] = {"context", "device", nullptr};
|
||||
PyObject* ctx = nullptr;
|
||||
PyObject* dev = nullptr;
|
||||
if (!PyArg_ParseTupleAndKeywords(args, kwds, "OO", const_cast<char**>(kwlist),
|
||||
&ctx, &dev) ||
|
||||
!ctx || !dev) {
|
||||
return nullptr;
|
||||
}
|
||||
auto handle = CopyToDevice(self->handle, ctx, dev);
|
||||
return EagerTensorFromHandle(handle);
|
||||
}
|
||||
|
||||
// Function `_numpy`.
|
||||
// Convert an EagerTensor to a Python numpy.ndarray object.
|
||||
// The two may share underlying storage so changes to one may reflect in the
|
||||
// other.
|
||||
// Note that if `self` is not on CPU, we raise an Exception.
|
||||
static PyObject* EagerTensor_numpy(EagerTensor* self) {
|
||||
auto status = tensorflow::make_safe(TF_NewStatus());
|
||||
const tensorflow::Tensor* t =
|
||||
TFE_TensorHandleUnderlyingTensorInHostMemory(self->handle, status.get());
|
||||
if (TF_GetCode(status.get()) != TF_OK) {
|
||||
PyErr_SetString(PyExc_RuntimeError, TF_Message(status.get()));
|
||||
return nullptr;
|
||||
}
|
||||
PyObject* ret = nullptr;
|
||||
auto cppstatus = tensorflow::TensorToNdarray(*t, &ret);
|
||||
if (MaybeRaiseExceptionFromStatus(cppstatus, PyExc_RuntimeError)) {
|
||||
Py_XDECREF(ret);
|
||||
return nullptr;
|
||||
} else {
|
||||
return ret;
|
||||
}
|
||||
}
|
||||
|
||||
// Getter `device`.
|
||||
static PyObject* EagerTensor_device(EagerTensor* self) {
|
||||
#if PY_MAJOR_VERSION >= 3
|
||||
return PyUnicode_FromString(TFE_TensorHandleDeviceName(self->handle));
|
||||
#else
|
||||
return PyBytes_FromString(TFE_TensorHandleDeviceName(self->handle));
|
||||
#endif
|
||||
}
|
||||
|
||||
static PyGetSetDef EagerTensor_getseters[] = {
|
||||
{const_cast<char*>("_id"), (getter)EagerTensor_getid, nullptr,
|
||||
const_cast<char*>("_id"), nullptr},
|
||||
{const_cast<char*>("device"), (getter)EagerTensor_device, nullptr,
|
||||
const_cast<char*>("device"), nullptr},
|
||||
{const_cast<char*>("_handle_data"), (getter)EagerTensor_tensor_handle,
|
||||
(setter)EagerTensor_settensor_handle, const_cast<char*>("_tensor_handle"),
|
||||
nullptr},
|
||||
{const_cast<char*>("_keras_mask"), (getter)EagerTensor_keras_mask,
|
||||
(setter)EagerTensor_setkeras_mask, const_cast<char*>("_keras_mask"),
|
||||
nullptr},
|
||||
{nullptr} /* Sentinel */
|
||||
};
|
||||
|
||||
static PyMethodDef EagerTensor_methods[] = {
|
||||
{"_numpy", (PyCFunction)EagerTensor_numpy, METH_NOARGS,
|
||||
PyDoc_STR("_numpy")},
|
||||
{"_datatype_enum", (PyCFunction)EagerTensor_datatype_enum, METH_NOARGS,
|
||||
PyDoc_STR("_datatype_enum")},
|
||||
{"_shape_tuple", (PyCFunction)EagerTensor_shape_tuple, METH_NOARGS,
|
||||
PyDoc_STR("_shape_tuple")},
|
||||
{"_copy_to_device", (PyCFunction)EagerTensor_copy_to_device,
|
||||
METH_VARARGS | METH_KEYWORDS, PyDoc_STR("_copy_to_device")},
|
||||
{nullptr, nullptr},
|
||||
};
|
||||
|
||||
// Note that here we are trying to dynamically create a new class as a subclass
|
||||
// of a "HEAPTYPE" class that is itself created in python code and passed in at
|
||||
// runtime. This is fairly atypical and undocumented.
|
||||
//
|
||||
// We use the following strategy for this. Unfortunately, we have to use
|
||||
// different approaches for python2.x vs python3.x
|
||||
// For python2.x, we create the class as a static type and set its tp_base to
|
||||
// the passed in type. Unfortunately setting tp_flags to include
|
||||
// Py_TPFLAGS_HEAPTYPE does not work by itself since it needs some more
|
||||
// initialization of the underlying PyHeapTypeObject and not doing that leads to
|
||||
// some random crashes especially during garbage collection.
|
||||
// python3.x explicitly disables a static subclass of a HEAPTYPE base class.
|
||||
// However it provides a new function, PyType_FromSpecWithBases, to create
|
||||
// types dynamically.
|
||||
|
||||
// Type object for EagerTensor. This is set by TFE_Py_InitEagerTensor.
|
||||
PyTypeObject* EagerTensorType = nullptr;
|
||||
|
||||
#if PY_MAJOR_VERSION >= 3
|
||||
static PyType_Slot EagerTensor_Type_slots[] = {
|
||||
Py_tp_dealloc,
|
||||
reinterpret_cast<void*>(EagerTensor_dealloc),
|
||||
Py_tp_methods,
|
||||
reinterpret_cast<void*>(EagerTensor_methods),
|
||||
Py_tp_getset,
|
||||
reinterpret_cast<void*>(EagerTensor_getseters),
|
||||
Py_tp_init,
|
||||
reinterpret_cast<void*>(EagerTensor_init),
|
||||
0,
|
||||
nullptr,
|
||||
};
|
||||
|
||||
PyType_Spec EagerTensor_Type_spec = {"EagerTensor", sizeof(EagerTensor), 0,
|
||||
Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HEAPTYPE,
|
||||
EagerTensor_Type_slots};
|
||||
#else
|
||||
// TODO(agarwal): support active_trace.
|
||||
static PyTypeObject _EagerTensorType = {
|
||||
// clang-format off
|
||||
PyVarObject_HEAD_INIT(nullptr, 0)
|
||||
// clang-format on
|
||||
"EagerTensor", /* tp_name */
|
||||
sizeof(EagerTensor), /* tp_basicsize */
|
||||
0, /* tp_itemsize */
|
||||
(destructor)EagerTensor_dealloc, /* tp_dealloc */
|
||||
nullptr, /* tp_print */
|
||||
nullptr, /* tp_getattr */
|
||||
nullptr, /* tp_setattr */
|
||||
nullptr, /* tp_compare */
|
||||
nullptr, /* tp_repr */
|
||||
nullptr, /* tp_as_number */
|
||||
nullptr, /* tp_as_sequence */
|
||||
nullptr, /* tp_as_mapping */
|
||||
nullptr, /* tp_hash */
|
||||
nullptr, /* tp_call */
|
||||
nullptr, /* tp_str */
|
||||
nullptr, /* tp_getattro */
|
||||
nullptr, /* tp_setattro */
|
||||
nullptr, /* tp_as_buffer */
|
||||
Py_TPFLAGS_DEFAULT, /* tp_flags */
|
||||
nullptr, /* tp_doc */
|
||||
nullptr, /* tp_traverse */
|
||||
nullptr, /* tp_clear */
|
||||
nullptr, /* tp_richcompare */
|
||||
0, /* tp_weaklistoffset */
|
||||
nullptr, /* tp_iter */
|
||||
nullptr, /* tp_iternext */
|
||||
EagerTensor_methods, /* tp_methods */
|
||||
nullptr, /* tp_members */
|
||||
EagerTensor_getseters, /* tp_getset */
|
||||
nullptr, /* tp_base */
|
||||
nullptr, /* tp_dict */
|
||||
nullptr, /* tp_descr_get */
|
||||
nullptr, /* tp_descr_set */
|
||||
0, /* tp_dictoffset */
|
||||
(initproc)EagerTensor_init, /* tp_init */
|
||||
nullptr, /* tp_alloc */
|
||||
nullptr, /* tp_new */
|
||||
};
|
||||
|
||||
#endif
|
||||
|
||||
} // extern "C"
|
||||
|
||||
bool EagerTensor_CheckExact(const PyObject* o) {
|
||||
return Py_TYPE(o) == EagerTensorType;
|
||||
}
|
||||
|
||||
TFE_TensorHandle* EagerTensorHandle(const PyObject* o) {
|
||||
return reinterpret_cast<const EagerTensor*>(o)->handle;
|
||||
}
|
||||
|
||||
PyObject* EagerTensorFromHandle(TFE_TensorHandle* handle) {
|
||||
if (handle == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
EagerTensor* t = reinterpret_cast<EagerTensor*>(
|
||||
EagerTensorType->tp_new(EagerTensorType, Py_None, Py_None));
|
||||
if (t != nullptr) {
|
||||
t->id = get_uid();
|
||||
Py_INCREF(Py_None);
|
||||
t->handle_data = Py_None;
|
||||
Py_INCREF(Py_None);
|
||||
t->keras_mask = Py_None;
|
||||
t->handle = handle;
|
||||
}
|
||||
return reinterpret_cast<PyObject*>(t);
|
||||
}
|
||||
|
||||
PyObject* TFE_Py_InitEagerTensor(PyObject* base_class) {
|
||||
if (!PyType_Check(base_class)) {
|
||||
PyErr_SetString(
|
||||
PyExc_TypeError,
|
||||
tensorflow::strings::StrCat(
|
||||
"Expecting a class definition for `base_class` passed to ",
|
||||
"TFE_InitEagerTensor. Got ", Py_TYPE(base_class)->tp_name)
|
||||
.c_str());
|
||||
return nullptr;
|
||||
}
|
||||
// Note that we allocated kMaxEagerTensorParentSize bytes of unused space in
|
||||
// EagerTensor to allow for the space usage of the base class.
|
||||
PyTypeObject* base_class_type = reinterpret_cast<PyTypeObject*>(base_class);
|
||||
if (base_class_type->tp_basicsize > kMaxEagerTensorParentSize) {
|
||||
PyErr_SetString(
|
||||
PyExc_TypeError,
|
||||
tensorflow::strings::StrCat(
|
||||
"Unable to create subclass EagerTensor from base class ",
|
||||
Py_TYPE(base_class)->tp_name,
|
||||
". Need its size to be <= ", kMaxEagerTensorParentSize)
|
||||
.c_str());
|
||||
return nullptr;
|
||||
}
|
||||
if (base_class_type->tp_itemsize != 0) {
|
||||
PyErr_SetString(
|
||||
PyExc_TypeError,
|
||||
tensorflow::strings::StrCat(
|
||||
"Unable to create subclass EagerTensor from base class ",
|
||||
Py_TYPE(base_class)->tp_name,
|
||||
" which supports variable length instances.")
|
||||
.c_str());
|
||||
return nullptr;
|
||||
}
|
||||
Py_INCREF(base_class);
|
||||
#if PY_MAJOR_VERSION >= 3
|
||||
PyObject* bases = PyTuple_New(1);
|
||||
PyTuple_SET_ITEM(bases, 0, base_class);
|
||||
EagerTensorType = reinterpret_cast<PyTypeObject*>(
|
||||
PyType_FromSpecWithBases(&EagerTensor_Type_spec, bases));
|
||||
if (PyErr_Occurred()) {
|
||||
return nullptr;
|
||||
}
|
||||
if (EagerTensorType == nullptr) {
|
||||
PyErr_SetString(PyExc_RuntimeError, "Error while creating EagerTensorType");
|
||||
return nullptr;
|
||||
}
|
||||
#else
|
||||
_EagerTensorType.tp_base = reinterpret_cast<PyTypeObject*>(base_class);
|
||||
|
||||
if (PyType_Ready(&_EagerTensorType) < 0) {
|
||||
if (PyErr_Occurred()) return nullptr;
|
||||
PyErr_SetString(PyExc_RuntimeError,
|
||||
"Error while creating EagerTensor type.");
|
||||
return nullptr;
|
||||
}
|
||||
EagerTensorType = &_EagerTensorType;
|
||||
Py_INCREF(EagerTensorType);
|
||||
#endif
|
||||
// We disable instance based attribute lookup. Its not clear if these
|
||||
// dictionaries are correctly initialized in the first place.
|
||||
EagerTensorType->tp_dictoffset = 0;
|
||||
return reinterpret_cast<PyObject*>(EagerTensorType);
|
||||
}
|
@ -17,6 +17,7 @@ limitations under the License.
|
||||
#define TENSORFLOW_PYTHON_EAGER_PYWRAP_TFE_H_
|
||||
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/lib/gtl/inlined_vector.h"
|
||||
#include <Python.h>
|
||||
|
||||
@ -44,38 +45,46 @@ void TFE_Py_Execute(TFE_Context* ctx, const char* device_name,
|
||||
PyObject* attrs, TFE_OutputTensorHandles* outputs,
|
||||
TF_Status* out_status);
|
||||
|
||||
// Convert a TFE_TensorHandle to a Python numpy.ndarray object.
|
||||
//
|
||||
// The two may share underlying storage so changes to one may reflect in the
|
||||
// other.
|
||||
PyObject* TFE_Py_TensorHandleToNumpy(TFE_TensorHandle* h, TF_Status* status);
|
||||
|
||||
// Convert a Python numpy.ndarray object to a TFE_TensorHandle.
|
||||
//
|
||||
// The two may share underlying storage so changes to one may reflect in the
|
||||
// other.
|
||||
TFE_TensorHandle* TFE_Py_NumpyToTensorHandle(PyObject* obj);
|
||||
|
||||
// Convert a Python sequence value to a TFE_TensorHandle.
|
||||
//
|
||||
// The dtype of the result is determined by the type of values found
|
||||
// in *obj, *dtype is the desired type but it is only considered a
|
||||
// hint. *dtype should be an integer representing the desired DataType
|
||||
// enum value, or Py_None. Unlike TFE_Py_NumpyToTensorHandle, this
|
||||
// always makes a copy. Returns nullptr and raises an exception on
|
||||
// error.
|
||||
// TODO(josh11b): Cast to dtype automatically.
|
||||
TFE_TensorHandle* TFE_Py_SequenceToTensorHandle(PyObject* obj, PyObject* dtype);
|
||||
|
||||
// Registers e as the Exception class for handling not ok Status. Returns
|
||||
// Py_None if registration succeeds, else throws a TypeError and returns NULL.
|
||||
PyObject* TFE_Py_RegisterExceptionClass(PyObject* e);
|
||||
|
||||
// Returns 0 if 'status' is TF_OK. Otherwise, raises an exception (using the
|
||||
// class registered via TFE_Py_RegisterExceptionClass) and returns -1.
|
||||
int TFE_Py_MaybeRaiseException(TF_Status* status);
|
||||
// Returns 0 if 'status' is TF_OK. Otherwise, raises an exception (using
|
||||
// `exception` if not nullptr, else using the class registered via
|
||||
// TFE_Py_RegisterExceptionClass), and returns -1.
|
||||
int MaybeRaiseExceptionFromTFStatus(TF_Status* status, PyObject* exception);
|
||||
|
||||
// Returns 0 if 'status' is ok. Otherwise, raises an exception (using
|
||||
// `exception` if not nullptr, else using the class registered via
|
||||
// TFE_Py_RegisterExceptionClass), and returns -1.
|
||||
int MaybeRaiseExceptionFromStatus(const tensorflow::Status& status,
|
||||
PyObject* exception);
|
||||
|
||||
// Returns the string associated with the passed-in python object.
|
||||
char* TFE_GetPythonString(PyObject* o);
|
||||
|
||||
// Returns a unique id on each call.
|
||||
int64_t get_uid();
|
||||
|
||||
// Wraps the output of get_uid as a Python Long object. Ownership is passed to
|
||||
// the caller.
|
||||
PyObject* TFE_Py_UID();
|
||||
|
||||
// Deleter for Context objects, called from the Capsule that owns it.
|
||||
void TFE_DeleteContextCapsule(PyObject* context);
|
||||
|
||||
// Returns true if o is an instance of EagerTensor, but not a subclass. Else
|
||||
// returns false.
|
||||
bool EagerTensor_CheckExact(const PyObject* o);
|
||||
|
||||
// Helper function to construct a new EagerTensor from a TFE_TensorHandle.
|
||||
PyObject* EagerTensorFromHandle(TFE_TensorHandle* handle);
|
||||
|
||||
// Extracts the handle inside EagerTensor object `o`. Returns nullptr on error.
|
||||
TFE_TensorHandle* EagerTensorHandle(const PyObject* o);
|
||||
|
||||
// Creates the `EagerTensor` class by subclassing `base_class` and returns the
|
||||
// newly created type, or nullptr on error.
|
||||
PyObject* TFE_Py_InitEagerTensor(PyObject* base_class);
|
||||
|
||||
#endif // TENSORFLOW_PYTHON_EAGER_PYWRAP_TFE_H_
|
||||
|
@ -13,16 +13,12 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// Must be included first.
|
||||
#include "tensorflow/python/lib/core/numpy.h"
|
||||
|
||||
#include "tensorflow/python/eager/pywrap_tfe.h"
|
||||
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
#include "tensorflow/python/lib/core/ndarray_tensor.h"
|
||||
#include "tensorflow/python/lib/core/py_seq_tensor.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
using tensorflow::string;
|
||||
|
||||
@ -320,6 +316,14 @@ void SetOpAttrs(TFE_Context* ctx, TFE_Op* op, PyObject* attrs,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Python subclass of Exception that is created on not ok Status.
|
||||
tensorflow::mutex exception_class_mutex(tensorflow::LINKER_INITIALIZED);
|
||||
PyObject* exception_class GUARDED_BY(exception_class_mutex) = nullptr;
|
||||
|
||||
static tensorflow::mutex _uid_mutex(tensorflow::LINKER_INITIALIZED);
|
||||
static tensorflow::int64 _uid GUARDED_BY(_uid_mutex) = 0;
|
||||
|
||||
} // namespace
|
||||
|
||||
void TFE_Py_Execute(TFE_Context* ctx, const char* device_name,
|
||||
@ -352,65 +356,6 @@ void TFE_Py_Execute(TFE_Context* ctx, const char* device_name,
|
||||
TFE_DeleteOp(op);
|
||||
}
|
||||
|
||||
PyObject* TFE_Py_TensorHandleToNumpy(TFE_TensorHandle* h, TF_Status* status) {
|
||||
const tensorflow::Tensor* t =
|
||||
TFE_TensorHandleUnderlyingTensorInHostMemory(h, status);
|
||||
if (TF_GetCode(status) != TF_OK) {
|
||||
Py_RETURN_NONE;
|
||||
}
|
||||
PyObject* ret = nullptr;
|
||||
auto cppstatus = tensorflow::TensorToNdarray(*t, &ret);
|
||||
if (!cppstatus.ok()) {
|
||||
TF_SetStatus(status, TF_Code(cppstatus.code()),
|
||||
cppstatus.error_message().c_str());
|
||||
}
|
||||
if (ret != nullptr) return ret;
|
||||
Py_RETURN_NONE;
|
||||
}
|
||||
|
||||
namespace {
|
||||
// Python subclass of Exception that is created on not ok Status.
|
||||
tensorflow::mutex exception_class_mutex(tensorflow::LINKER_INITIALIZED);
|
||||
PyObject* exception_class GUARDED_BY(exception_class_mutex) = nullptr;
|
||||
|
||||
void PyRaiseException(TF_Code error_code, const char* msg) {
|
||||
tensorflow::mutex_lock l(exception_class_mutex);
|
||||
if (exception_class != nullptr) {
|
||||
PyErr_SetObject(exception_class, Py_BuildValue("si", msg, error_code));
|
||||
} else {
|
||||
PyErr_SetString(PyExc_RuntimeError, msg);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
TFE_TensorHandle* TFE_Py_NumpyToTensorHandle(PyObject* obj) {
|
||||
tensorflow::Tensor t;
|
||||
auto cppstatus = tensorflow::NdarrayToTensor(obj, &t);
|
||||
if (cppstatus.ok()) {
|
||||
return TFE_NewTensorHandle(t);
|
||||
} else {
|
||||
PyRaiseException(TF_INVALID_ARGUMENT,
|
||||
tensorflow::strings::StrCat(
|
||||
"failed to convert numpy ndarray to a Tensor (",
|
||||
cppstatus.error_message(), ")")
|
||||
.c_str());
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
TFE_TensorHandle* TFE_Py_SequenceToTensorHandle(PyObject* obj,
|
||||
PyObject* dtype) {
|
||||
tensorflow::Tensor t;
|
||||
auto cppstatus = tensorflow::PySeqToTensor(obj, dtype, &t);
|
||||
if (cppstatus.ok()) {
|
||||
return TFE_NewTensorHandle(t);
|
||||
} else {
|
||||
PyRaiseException(TF_INVALID_ARGUMENT, cppstatus.error_message().c_str());
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
PyObject* TFE_Py_RegisterExceptionClass(PyObject* e) {
|
||||
tensorflow::mutex_lock l(exception_class_mutex);
|
||||
if (exception_class != nullptr) {
|
||||
@ -429,9 +374,39 @@ PyObject* TFE_Py_RegisterExceptionClass(PyObject* e) {
|
||||
}
|
||||
}
|
||||
|
||||
int TFE_Py_MaybeRaiseException(TF_Status* status) {
|
||||
int MaybeRaiseExceptionFromTFStatus(TF_Status* status, PyObject* exception) {
|
||||
if (TF_GetCode(status) == TF_OK) return 0;
|
||||
PyRaiseException(TF_GetCode(status), TF_Message(status));
|
||||
const char* msg = TF_Message(status);
|
||||
if (exception == nullptr) {
|
||||
tensorflow::mutex_lock l(exception_class_mutex);
|
||||
if (exception_class != nullptr) {
|
||||
PyErr_SetObject(exception_class,
|
||||
Py_BuildValue("si", msg, TF_GetCode(status)));
|
||||
return -1;
|
||||
} else {
|
||||
exception = PyExc_RuntimeError;
|
||||
}
|
||||
}
|
||||
// May be update already set exception.
|
||||
PyErr_SetString(exception, msg);
|
||||
return -1;
|
||||
}
|
||||
|
||||
int MaybeRaiseExceptionFromStatus(const tensorflow::Status& status,
|
||||
PyObject* exception) {
|
||||
if (status.ok()) return 0;
|
||||
const char* msg = status.error_message().c_str();
|
||||
if (exception == nullptr) {
|
||||
tensorflow::mutex_lock l(exception_class_mutex);
|
||||
if (exception_class != nullptr) {
|
||||
PyErr_SetObject(exception_class, Py_BuildValue("si", msg, status.code()));
|
||||
return -1;
|
||||
} else {
|
||||
exception = PyExc_RuntimeError;
|
||||
}
|
||||
}
|
||||
// May be update already set exception.
|
||||
PyErr_SetString(exception, msg);
|
||||
return -1;
|
||||
}
|
||||
|
||||
@ -446,3 +421,18 @@ char* TFE_GetPythonString(PyObject* o) {
|
||||
#endif
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
int64_t get_uid() {
|
||||
tensorflow::mutex_lock l(_uid_mutex);
|
||||
return _uid++;
|
||||
}
|
||||
|
||||
PyObject* TFE_Py_UID() { return PyLong_FromLongLong(get_uid()); }
|
||||
|
||||
void TFE_DeleteContextCapsule(PyObject* context) {
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_Context* ctx =
|
||||
reinterpret_cast<TFE_Context*>(PyCapsule_GetPointer(context, nullptr));
|
||||
TFE_DeleteContext(ctx, status);
|
||||
TF_DeleteStatus(status);
|
||||
}
|
||||
|
@ -135,9 +135,9 @@ class Tape(object):
|
||||
# adding an explicit stack if this ever gets out of hand
|
||||
self._delete_tensor_id(tensor_id)
|
||||
|
||||
def delete_trace(self, tensor):
|
||||
def delete_trace(self, tensor_id):
|
||||
"""Deletes any trace we have for this tensor."""
|
||||
self._delete_tensor_id(tid(tensor))
|
||||
self._delete_tensor_id(tensor_id)
|
||||
|
||||
def export(self):
|
||||
"""Exports the internal state of this tape.
|
||||
@ -237,10 +237,10 @@ def record_operation(op_type, output_tensors, input_tensors, side_outputs,
|
||||
backward_function)
|
||||
|
||||
|
||||
def delete_trace(tensor):
|
||||
def delete_trace(tensor_id):
|
||||
"""Deletes traces for this Tensor from all tapes in the stack."""
|
||||
for t in _tape_stack.stack:
|
||||
t.delete_trace(tensor)
|
||||
t.delete_trace(tensor_id)
|
||||
|
||||
|
||||
def top_tape_watched_tensors():
|
||||
|
@ -21,26 +21,90 @@ from __future__ import print_function
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import core
|
||||
from tensorflow.python.eager import test
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_util
|
||||
|
||||
|
||||
def _create_tensor(value, device=None, dtype=None):
|
||||
ctx = context.context()
|
||||
if device is None:
|
||||
device = ctx.device_name
|
||||
if dtype is not None:
|
||||
dtype = dtype.as_datatype_enum
|
||||
try:
|
||||
return ops.EagerTensor(
|
||||
value, context=ctx._handle, device=device, dtype=dtype)
|
||||
except core._NotOkStatusException as e: # pylint: disable=protected-access
|
||||
raise core._status_to_exception(e.code, e.message)
|
||||
|
||||
|
||||
class TFETensorTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def testScalarTensor(self):
|
||||
t = constant_op.constant(3)
|
||||
self.assertEqual(t.numpy(), constant_op.constant(np.array(3)).numpy())
|
||||
t = _create_tensor(3, dtype=dtypes.int32)
|
||||
self.assertEqual(t.numpy(), _create_tensor(np.array(3)).numpy())
|
||||
self.assertEqual(dtypes.int32, t.dtype)
|
||||
self.assertEqual(0, t.shape.ndims)
|
||||
self.assertAllEqual([], t.shape.as_list())
|
||||
self.assertIn("tf.Tensor", str(t))
|
||||
self.assertIn("tf.Tensor", repr(t))
|
||||
|
||||
def testBadConstructorArgs(self):
|
||||
ctx = context.context()
|
||||
handle = ctx._handle
|
||||
device = ctx.device_name
|
||||
# Missing context.
|
||||
with self.assertRaisesRegexp(
|
||||
TypeError, r"Required argument 'context' \(pos 2\) not found"):
|
||||
ops.EagerTensor(1, device=device)
|
||||
# Missing device.
|
||||
with self.assertRaisesRegexp(
|
||||
TypeError, r"Required argument 'device' \(pos 3\) not found"):
|
||||
ops.EagerTensor(1, context=handle)
|
||||
# Bad dtype type.
|
||||
with self.assertRaisesRegexp(TypeError,
|
||||
"Expecting a DataType value for dtype. Got"):
|
||||
ops.EagerTensor(1, context=handle, device=device, dtype="1")
|
||||
# Following errors happen when trying to copy to GPU.
|
||||
if not context.context().num_gpus():
|
||||
self.skipTest("No GPUs found")
|
||||
with ops.device("/device:GPU:0"):
|
||||
device = ctx.device_name
|
||||
# Bad context.
|
||||
with self.assertRaisesRegexp(
|
||||
TypeError, "Expecting a PyCapsule encoded context handle. Got"):
|
||||
ops.EagerTensor(1.0, context=1, device=device)
|
||||
# Bad device.
|
||||
with self.assertRaisesRegexp(
|
||||
TypeError, "Error parsing device argument to CopyToDevice"):
|
||||
ops.EagerTensor(1.0, context=handle, device=1)
|
||||
|
||||
def testNumpyValue(self):
|
||||
values = np.array([3.0])
|
||||
t = _create_tensor(values)
|
||||
self.assertAllEqual(values, t.numpy())
|
||||
|
||||
def testNumpyValueWithCast(self):
|
||||
values = np.array([3.0], dtype=np.float32)
|
||||
t = _create_tensor(values, dtype=dtypes.float64)
|
||||
self.assertAllEqual(values, t.numpy())
|
||||
ctx = context.context()
|
||||
# Bad dtype value.
|
||||
with self.assertRaisesRegexp(TypeError, "Invalid dtype argument value"):
|
||||
ops.EagerTensor(
|
||||
values, context=ctx._handle, device=ctx.device_name, dtype=12345)
|
||||
|
||||
def testNumpyOrderHandling(self):
|
||||
n = np.array([[1, 2], [3, 4]], order="F")
|
||||
t = _create_tensor(n)
|
||||
self.assertAllEqual([[1, 2], [3, 4]], t.numpy())
|
||||
|
||||
def testTensorAndNumpyMatrix(self):
|
||||
expected = np.array([[1.0, 2.0], [3.0, 4.0]], np.float32)
|
||||
actual = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
|
||||
actual = _create_tensor([[1.0, 2.0], [3.0, 4.0]])
|
||||
self.assertAllEqual(expected, actual.numpy())
|
||||
self.assertEqual(np.float32, actual.numpy().dtype)
|
||||
self.assertEqual(dtypes.float32, actual.dtype)
|
||||
@ -48,56 +112,50 @@ class TFETensorTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def testFloatDowncast(self):
|
||||
# Unless explicitly specified, float64->float32
|
||||
t = constant_op.constant(3.0)
|
||||
t = _create_tensor(3.0)
|
||||
self.assertEqual(dtypes.float32, t.dtype)
|
||||
t = constant_op.constant(3.0, dtype=dtypes.float64)
|
||||
t = _create_tensor(3.0, dtype=dtypes.float64)
|
||||
self.assertEqual(dtypes.float64, t.dtype)
|
||||
|
||||
def testBool(self):
|
||||
t = constant_op.constant(False)
|
||||
t = _create_tensor(False)
|
||||
if t:
|
||||
self.assertFalse(True)
|
||||
|
||||
def testIntDowncast(self):
|
||||
t = constant_op.constant(3)
|
||||
t = _create_tensor(3)
|
||||
self.assertEqual(dtypes.int32, t.dtype)
|
||||
t = constant_op.constant(3, dtype=dtypes.int64)
|
||||
t = _create_tensor(3, dtype=dtypes.int64)
|
||||
self.assertEqual(dtypes.int64, t.dtype)
|
||||
t = constant_op.constant(2**33)
|
||||
t = _create_tensor(2**33)
|
||||
self.assertEqual(dtypes.int64, t.dtype)
|
||||
|
||||
def testTensorCreationFailure(self):
|
||||
with self.assertRaises(Exception):
|
||||
with self.assertRaises(ValueError):
|
||||
# Should fail because the each row of the Python object has a different
|
||||
# number of columns.
|
||||
self.assertEqual(None, constant_op.constant([[1], [1, 2]]))
|
||||
|
||||
def testNumpyOrderHandling(self):
|
||||
n = np.array([[1, 2], [3, 4]], order="F")
|
||||
t = constant_op.constant(n)
|
||||
self.assertAllEqual([[1, 2], [3, 4]], t.numpy())
|
||||
self.assertEqual(None, _create_tensor([[1], [1, 2]]))
|
||||
|
||||
def testMultiLineTensorStr(self):
|
||||
t = constant_op.constant(np.eye(3))
|
||||
t = _create_tensor(np.eye(3))
|
||||
tensor_str = str(t)
|
||||
self.assertIn("shape=%s, dtype=%s" % (t.shape, t.dtype.name), tensor_str)
|
||||
self.assertIn(str(t.numpy()), tensor_str)
|
||||
|
||||
def testMultiLineTensorRepr(self):
|
||||
t = constant_op.constant(np.eye(3))
|
||||
t = _create_tensor(np.eye(3))
|
||||
tensor_repr = repr(t)
|
||||
self.assertTrue(tensor_repr.startswith("<"))
|
||||
self.assertTrue(tensor_repr.endswith(">"))
|
||||
self.assertIn(
|
||||
"id=%d, shape=%s, dtype=%s, numpy=\n%r" % (
|
||||
t._id, t.shape, t.dtype.name, t.numpy()), tensor_repr)
|
||||
self.assertIn("id=%d, shape=%s, dtype=%s, numpy=\n%r" %
|
||||
(t._id, t.shape, t.dtype.name, t.numpy()), tensor_repr)
|
||||
|
||||
def testTensorStrReprObeyNumpyPrintOptions(self):
|
||||
orig_threshold = np.get_printoptions()["threshold"]
|
||||
orig_edgeitems = np.get_printoptions()["edgeitems"]
|
||||
np.set_printoptions(threshold=2, edgeitems=1)
|
||||
|
||||
t = constant_op.constant(np.arange(10, dtype=np.int32))
|
||||
t = _create_tensor(np.arange(10, dtype=np.int32))
|
||||
self.assertIn("[0 ..., 9]", str(t))
|
||||
self.assertIn("[0, ..., 9]", repr(t))
|
||||
|
||||
@ -105,30 +163,30 @@ class TFETensorTest(test_util.TensorFlowTestCase):
|
||||
np.set_printoptions(threshold=orig_threshold, edgeitems=orig_edgeitems)
|
||||
|
||||
def testZeroDimTensorStr(self):
|
||||
t = constant_op.constant(42)
|
||||
t = _create_tensor(42)
|
||||
self.assertIn("42, shape=(), dtype=int32", str(t))
|
||||
|
||||
def testZeroDimTensorRepr(self):
|
||||
t = constant_op.constant(42)
|
||||
t = _create_tensor(42)
|
||||
self.assertTrue(repr(t).startswith("<"))
|
||||
self.assertTrue(repr(t).endswith(">"))
|
||||
self.assertIn("id=%d, shape=(), dtype=int32, numpy=42" % t._id, repr(t))
|
||||
|
||||
def testZeroSizeTensorStr(self):
|
||||
t = constant_op.constant(np.zeros(0, dtype=np.float32))
|
||||
t = _create_tensor(np.zeros(0, dtype=np.float32))
|
||||
self.assertIn("[], shape=(0,), dtype=float32", str(t))
|
||||
|
||||
def testZeroSizeTensorRepr(self):
|
||||
t = constant_op.constant(np.zeros(0, dtype=np.float32))
|
||||
t = _create_tensor(np.zeros(0, dtype=np.float32))
|
||||
self.assertTrue(repr(t).startswith("<"))
|
||||
self.assertTrue(repr(t).endswith(">"))
|
||||
self.assertIn(
|
||||
"id=%d, shape=(0,), dtype=float32, numpy=%r" % (t._id, t.numpy()),
|
||||
repr(t))
|
||||
self.assertIn("id=%d, shape=(0,), dtype=float32, numpy=%r" % (t._id,
|
||||
t.numpy()),
|
||||
repr(t))
|
||||
|
||||
def testStringTensor(self):
|
||||
t_np_orig = np.array([[b"a", b"ab"], [b"abc", b"abcd"]])
|
||||
t = constant_op.constant(t_np_orig)
|
||||
t = _create_tensor(t_np_orig)
|
||||
t_np = t.numpy()
|
||||
self.assertTrue(np.all(t_np == t_np_orig), "%s vs %s" % (t_np, t_np_orig))
|
||||
|
||||
@ -137,9 +195,8 @@ class TFETensorTest(test_util.TensorFlowTestCase):
|
||||
self.skipTest("No GPUs found")
|
||||
with ops.device("/device:GPU:0"):
|
||||
with self.assertRaisesRegexp(
|
||||
errors.InvalidArgumentError,
|
||||
"Can't copy Tensor with type string to device"):
|
||||
constant_op.constant("test string")
|
||||
RuntimeError, "Can't copy Tensor with type string to device"):
|
||||
_create_tensor("test string")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -84,26 +84,46 @@ def _eager_identity(tensor, ctx):
|
||||
return result
|
||||
|
||||
|
||||
def convert_to_eager_tensor(t, ctx, dtype=None):
|
||||
"""Converts the given `value` to an `EagerTensor`."""
|
||||
if isinstance(t, ops.EagerTensor):
|
||||
if dtype is not None and t.dtype != dtype:
|
||||
raise TypeError("Expected tensor with type %r not %r" % (dtype, t.dtype))
|
||||
return t
|
||||
if isinstance(t, (int, float)):
|
||||
def convert_to_eager_tensor(value, ctx, dtype=None):
|
||||
"""Converts the given `value` to an `EagerTensor`.
|
||||
|
||||
Note that this function could return cached copies of created constants for
|
||||
performance reasons.
|
||||
|
||||
Args:
|
||||
value: value to convert to EagerTensor.
|
||||
ctx: value of context.context().
|
||||
dtype: optional desired dtype of the converted EagerTensor.
|
||||
|
||||
Returns:
|
||||
EagerTensor created from value.
|
||||
|
||||
Raises:
|
||||
TypeError: if `dtype` is not compatible with the type of t.
|
||||
"""
|
||||
if isinstance(value, ops.EagerTensor):
|
||||
if dtype is not None and value.dtype != dtype:
|
||||
raise TypeError("Expected tensor with type %r not %r" % (
|
||||
dtype, value.dtype))
|
||||
return value
|
||||
if dtype is not None:
|
||||
dtype = dtype.as_datatype_enum
|
||||
device = ctx.device_name
|
||||
handle = ctx._handle # pylint: disable=protected-access
|
||||
if isinstance(value, (int, float)):
|
||||
# Use a scalar cache. This will put each scalar of each type only once on
|
||||
# each device. Scalars don't use much device memory but copying scalars can
|
||||
# trigger memcpys which are slow.
|
||||
device = ctx.device_name
|
||||
cache_key = device, t, dtype, type(t)
|
||||
cache_key = device, value, dtype, type(value)
|
||||
scalar_cache = ctx.scalar_cache()
|
||||
tensor = scalar_cache.get(cache_key, None)
|
||||
if tensor is not None:
|
||||
return tensor
|
||||
value = ops.EagerTensor(t, ctx, dtype=dtype)
|
||||
scalar_cache[cache_key] = value
|
||||
return value
|
||||
return ops.EagerTensor(t, ctx, dtype=dtype)
|
||||
t = ops.EagerTensor(value, context=handle, device=device, dtype=dtype)
|
||||
scalar_cache[cache_key] = t
|
||||
return t
|
||||
else:
|
||||
return ops.EagerTensor(value, context=handle, device=device, dtype=dtype)
|
||||
|
||||
|
||||
def constant(value, dtype=None, shape=None, name="Const", verify_shape=False):
|
||||
@ -152,13 +172,13 @@ def constant(value, dtype=None, shape=None, name="Const", verify_shape=False):
|
||||
A Constant Tensor.
|
||||
|
||||
Raises:
|
||||
TypeError if shape is incorrectly specified or unsupported.
|
||||
TypeError: if shape is incorrectly specified or unsupported.
|
||||
"""
|
||||
ctx = context.context()
|
||||
if not ctx.in_graph_mode():
|
||||
if shape is None:
|
||||
return convert_to_eager_tensor(value, ctx, dtype)
|
||||
t = convert_to_eager_tensor(value, ctx, dtype)
|
||||
if shape is None:
|
||||
return t
|
||||
shape = tensor_shape.as_shape(shape)
|
||||
if shape == t.shape:
|
||||
return t
|
||||
|
@ -25,10 +25,9 @@ import re
|
||||
import sys
|
||||
import threading
|
||||
|
||||
import numpy as np
|
||||
|
||||
import six
|
||||
from six.moves import xrange # pylint: disable=redefined-builtin
|
||||
|
||||
from tensorflow.core.framework import attr_value_pb2
|
||||
from tensorflow.core.framework import function_pb2
|
||||
from tensorflow.core.framework import graph_pb2
|
||||
@ -75,10 +74,6 @@ def tensor_id(tensor):
|
||||
return tensor._id # pylint: disable=protected-access
|
||||
|
||||
|
||||
def _in_gpu_device(ctx):
|
||||
return "GPU" == ctx.device_spec.device_type
|
||||
|
||||
|
||||
@tf_contextlib.contextmanager
|
||||
def _null_contextmanager():
|
||||
yield
|
||||
@ -171,16 +166,9 @@ def register_dense_tensor_like_type(tensor_type):
|
||||
_TENSOR_LIKE_TYPES = tuple(list(_TENSOR_LIKE_TYPES) + [tensor_type])
|
||||
|
||||
|
||||
_uid_counter = 0
|
||||
_uid_lock = threading.Lock()
|
||||
|
||||
|
||||
def uid():
|
||||
"""A unique (within this program execution) integer."""
|
||||
with _uid_lock:
|
||||
global _uid_counter
|
||||
_uid_counter += 1
|
||||
return _uid_counter
|
||||
return c_api.TFE_Py_UID()
|
||||
|
||||
|
||||
# NOTE(ebrevdo): Do not subclass this. If you do, I will break you on purpose.
|
||||
@ -584,127 +572,18 @@ class Tensor(_TensorLike):
|
||||
return ret
|
||||
|
||||
|
||||
def _eager_cast(tensor_handle, src_type_enum, dest_type_enum, ctx):
|
||||
"""Cast tensor_handle from src_type_enum to dest_type_enum."""
|
||||
# pylint: disable=protected-access
|
||||
try:
|
||||
out_handle, = c_api.TFE_Py_Execute(
|
||||
ctx._handle, b"/job:localhost/replica:0/task:0/device:CPU:0", b"Cast",
|
||||
[tensor_handle], (b"SrcT", src_type_enum, b"DstT", dest_type_enum), 1)
|
||||
except core._NotOkStatusException as e:
|
||||
six.raise_from(core._status_to_exception(e.code, e.message), None)
|
||||
# pylint: enable=protected-access
|
||||
# TODO(josh11b): Should we support tracing or post_execution_callbacks here?
|
||||
return out_handle
|
||||
# TODO(agarwal): consider getting rid of this.
|
||||
class _EagerTensorBase(Tensor):
|
||||
"""Base class for EagerTensor."""
|
||||
|
||||
@staticmethod
|
||||
def _delete_trace(tid):
|
||||
"""Helper function to be called by __del__ of the subclass."""
|
||||
tape.delete_trace(tid)
|
||||
|
||||
# TODO(agarwal): rename to TensorHandle.
|
||||
class EagerTensor(Tensor):
|
||||
"""A TensorFlow Eager Tensor."""
|
||||
|
||||
def __init__(self, value, ctx, dtype=None): # pylint: disable=super-init-not-called
|
||||
"""Creates a Tensor object from a Python object or numpy array.
|
||||
|
||||
May share storage with the numpy array, in which case changes to the numpy
|
||||
object will reflect
|
||||
in the Tensor.
|
||||
|
||||
Arguments:
|
||||
value: A numpy.array or a Python object to create a Tensor for.
|
||||
ctx: The value of context.context().
|
||||
dtype: TensorFlow dtype for the returned Tensor. If None, one will be
|
||||
automatically selected.
|
||||
"""
|
||||
# TODO(ashankar): Evaluate if we can and perhaps share code with
|
||||
# tf.constant defined in
|
||||
# https://www.tensorflow.org/code/tensorflow/python/framework/constant_op.py
|
||||
self._id = uid()
|
||||
# pylint: disable=protected-access
|
||||
if isinstance(value, np.ndarray):
|
||||
if dtype is not None:
|
||||
npt = dtype.as_numpy_dtype
|
||||
if npt != value.dtype:
|
||||
value = value.astype(npt)
|
||||
try:
|
||||
value = np.asarray(value, order="C")
|
||||
self._handle = c_api.TFE_Py_NumpyToTensorHandle(value)
|
||||
except core._NotOkStatusException as e:
|
||||
six.raise_from(core._status_to_exception(e.code, e.message), None)
|
||||
dtype = dtypes.as_dtype(c_api.TFE_TensorHandleDataType(self._handle))
|
||||
else:
|
||||
dtype_enum = None if dtype is None else dtype.as_datatype_enum
|
||||
try:
|
||||
self._handle = c_api.TFE_Py_SequenceToTensorHandle(value, dtype_enum)
|
||||
except core._NotOkStatusException as e:
|
||||
six.raise_from(core._status_to_exception(e.code, e.message), None)
|
||||
|
||||
dtype_enum = c_api.TFE_TensorHandleDataType(self._handle)
|
||||
dtype_actual = dtypes.as_dtype(dtype_enum)
|
||||
if dtype is not None and dtype != dtype_actual:
|
||||
self._handle = _eager_cast(self._handle, dtype_enum,
|
||||
dtype.as_datatype_enum, ctx)
|
||||
else:
|
||||
dtype = dtype_actual
|
||||
# pylint: enable=protected-access
|
||||
|
||||
# Almost all TensorFlow kernels for GPU devices keep int32 tensors in host
|
||||
# memory. This change approximates the same behavior for eager execution -
|
||||
# keeping int32 tensors in host memory.
|
||||
#
|
||||
# We do so to preclude the need for callers into such kernels from having to
|
||||
# explicitly place the int32 tensors in host memory. For example, prior to
|
||||
# this change one needed:
|
||||
#
|
||||
# with tf.device('/gpu:0'):
|
||||
# ... # code here
|
||||
# with tf.device('/cpu:0'):
|
||||
# shape = tf.constant(...)
|
||||
# y = tf.random_uniform(shape)
|
||||
#
|
||||
# Without the CPU device block tfe.ops.random_uniform would fail since the
|
||||
# kernel expects the shape in host memory.
|
||||
#
|
||||
# After this change, we simplify the code:
|
||||
#
|
||||
# with tf.device('/gpu:0'):
|
||||
# y = tf.random_uniform(...)
|
||||
#
|
||||
# The approximation is not exact there are GPU kernels which do not
|
||||
# require host memory for int32 tensors. This will lead to a discrepancy
|
||||
# between eager and graph execution.
|
||||
# TODO(ashankar): Fix this.
|
||||
if _in_gpu_device(ctx) and dtype != dtypes.int32:
|
||||
# pylint: disable=protected-access
|
||||
device_name = ctx.device_name
|
||||
with errors.raise_exception_on_not_ok_status() as status:
|
||||
self._handle = c_api.TFE_TensorHandleCopyToDevice(
|
||||
self._handle, ctx._handle, device_name, status)
|
||||
# pylint: enable=protected-access
|
||||
|
||||
self._dtype = dtype
|
||||
|
||||
# This mirrors tensorflow.core.framework.ops.Tensor._handle_data Which will
|
||||
# be None for tensors of type other than DT_REOSURCE. For DT_RESOURCE
|
||||
# tensors, this will contain a serialized HandleData proto with shape
|
||||
# inference metadata about shapes and dtypes of resources accessible from
|
||||
# this handle.
|
||||
self._handle_data = None
|
||||
if core.active_trace() is not None:
|
||||
core.active_trace().record_tensor("MANUAL",
|
||||
tensor_id(self), self.device,
|
||||
self.shape.num_elements())
|
||||
|
||||
def __del__(self):
|
||||
try:
|
||||
tape.delete_trace(self)
|
||||
if c_api is not None and c_api.TFE_DeleteTensorHandle is not None:
|
||||
c_api.TFE_DeleteTensorHandle(self._handle)
|
||||
if core.active_trace() is not None:
|
||||
core.active_trace().delete_tensor(tensor_id(self))
|
||||
except (AttributeError, TypeError):
|
||||
# Sometimes deletion during program shutdown throws exception as other
|
||||
# modules are no longer available.
|
||||
pass
|
||||
@property
|
||||
def dtype(self):
|
||||
return dtypes.as_dtype(self._datatype_enum())
|
||||
|
||||
def _numpy_text(self, is_repr=False):
|
||||
if self.dtype.is_numpy_compatible:
|
||||
@ -715,19 +594,6 @@ class EagerTensor(Tensor):
|
||||
numpy_text = "\n" + numpy_text
|
||||
return numpy_text
|
||||
|
||||
def __str__(self):
|
||||
return "tf.Tensor(%s, shape=%s, dtype=%s)" % (self._numpy_text(),
|
||||
self.shape,
|
||||
self.dtype.name)
|
||||
|
||||
def __repr__(self):
|
||||
return "<tf.Tensor: id=%s, shape=%s, dtype=%s, numpy=%s>" % (
|
||||
self._id, self.shape, self.dtype.name, self._numpy_text(is_repr=True))
|
||||
|
||||
@staticmethod
|
||||
def _override_operator(name, func):
|
||||
setattr(EagerTensor, name, func)
|
||||
|
||||
def numpy(self):
|
||||
"""Returns a numpy array with the same contents as the Tensor.
|
||||
|
||||
@ -742,69 +608,13 @@ class EagerTensor(Tensor):
|
||||
A numpy array that may share memory with the Tensor object. Any changes
|
||||
to one may be reflected in the other.
|
||||
"""
|
||||
# TODO(ashankar): This with status business seems expensive. Profile/avoid?
|
||||
cpu = self.as_cpu_tensor()
|
||||
with errors.raise_exception_on_not_ok_status() as status:
|
||||
return c_api.TFE_Py_TensorHandleToNumpy(cpu._handle, status) # pylint: disable=protected-access
|
||||
return self.as_cpu_tensor()._numpy() # pylint: disable=protected-access
|
||||
|
||||
def _copy(self, ctx=None, device_name=None):
|
||||
"""Copies tensor to dest device."""
|
||||
# pylint: disable=protected-access
|
||||
# Creates a new tensor on the dest device.
|
||||
if ctx is None:
|
||||
ctx = context.context()
|
||||
if device_name is None:
|
||||
device_name = ctx.device_name
|
||||
with errors.raise_exception_on_not_ok_status() as status:
|
||||
h = c_api.TFE_TensorHandleCopyToDevice(self._handle, ctx._handle,
|
||||
device_name, status)
|
||||
new_tensor = _tensor_from_handle(h)
|
||||
if core.active_trace() is not None:
|
||||
core.active_trace().record_tensor("COPY",
|
||||
tensor_id(new_tensor),
|
||||
new_tensor.device,
|
||||
new_tensor.shape.num_elements())
|
||||
def _numpy(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
# Record the copy on tape and define backprop copy as well.
|
||||
if not context.in_graph_mode():
|
||||
self_device = self.device
|
||||
def grad_fun(dresult):
|
||||
with errors.raise_exception_on_not_ok_status() as status:
|
||||
grad_h = c_api.TFE_TensorHandleCopyToDevice(
|
||||
dresult._handle, ctx._handle, self_device, status)
|
||||
return _tensor_from_handle(grad_h)
|
||||
tape.record_operation("_copy", [new_tensor], [self], [], grad_fun)
|
||||
return new_tensor
|
||||
# pylint: enable=protected-access
|
||||
|
||||
def _dup(self):
|
||||
return self._copy(device_name=self.device)
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return c_api.TFE_TensorHandleDeviceName(self._handle)
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return self._dtype
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
"""The shape of this Tensor as a TensorShape object."""
|
||||
n = c_api.TFE_TensorHandleNumDims(self._handle)
|
||||
# As of May 2017, TFE_TensorHandle objects were always backed by concrete
|
||||
# tensors (which have a valid, known shape). There were vague plans to
|
||||
# change this so that the Tensor class can also represent Tensors that have
|
||||
# not yet been computed.
|
||||
# If that happens, handle that (e.g., if n < 0: return tensor_shape(None))
|
||||
# and also handle -1s returned by TFE_TensorHandleDim.
|
||||
assert n >= 0, "See comment in source code"
|
||||
return tensor_shape.TensorShape(
|
||||
[c_api.TFE_TensorHandleDim(self._handle, x) for x in range(n)])
|
||||
|
||||
def get_shape(self):
|
||||
"""Alias of Tensor.shape."""
|
||||
return self.shape
|
||||
def _datatype_enum(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
def _shape_tuple(self):
|
||||
"""The shape of this Tensor, as a tuple.
|
||||
@ -819,15 +629,62 @@ class EagerTensor(Tensor):
|
||||
Returns:
|
||||
tuple with the shape.
|
||||
"""
|
||||
n = c_api.TFE_TensorHandleNumDims(self._handle)
|
||||
# As of May 2017, TFE_TensorHandle objects were always backed by concrete
|
||||
# tensors (which have a valid, known shape). There were vague plans to
|
||||
# change this so that the Tensor class can also represent Tensors that have
|
||||
# not yet been computed.
|
||||
# If that happens, handle that (e.g., if n < 0: return tensor_shape(None))
|
||||
# and also handle -1s returned by TFE_TensorHandleDim.
|
||||
assert n >= 0, "See comment in source code"
|
||||
return tuple(c_api.TFE_TensorHandleDim(self._handle, x) for x in range(n))
|
||||
raise NotImplementedError()
|
||||
|
||||
def _copy_to_device(self, context, device): # pylint: disable=redefined-outer-name
|
||||
raise NotImplementedError()
|
||||
|
||||
def __str__(self):
|
||||
return "tf.Tensor(%s, shape=%s, dtype=%s)" % (self._numpy_text(),
|
||||
self.shape,
|
||||
self.dtype.name)
|
||||
|
||||
def __repr__(self):
|
||||
return "<tf.Tensor: id=%s, shape=%s, dtype=%s, numpy=%s>" % (
|
||||
self._id, self.shape, self.dtype.name, self._numpy_text(is_repr=True))
|
||||
|
||||
@staticmethod
|
||||
def _override_operator(name, func):
|
||||
setattr(_EagerTensorBase, name, func)
|
||||
|
||||
def _copy(self, ctx=None, device_name=None):
|
||||
"""Copies tensor to dest device."""
|
||||
# pylint: disable=protected-access
|
||||
# Creates a new tensor on the dest device.
|
||||
if ctx is None:
|
||||
ctx = context.context()
|
||||
if device_name is None:
|
||||
device_name = ctx.device_name
|
||||
# pylint: disable=protected-access
|
||||
try:
|
||||
new_tensor = self._copy_to_device(context=ctx._handle, device=device_name)
|
||||
except core._NotOkStatusException as e:
|
||||
six.raise_from(core._status_to_exception(e.code, e.message), None)
|
||||
if core.active_trace() is not None:
|
||||
core.active_trace().record_tensor("COPY",
|
||||
tensor_id(new_tensor),
|
||||
new_tensor.device,
|
||||
new_tensor.shape.num_elements())
|
||||
|
||||
# Record the copy on tape and define backprop copy as well.
|
||||
if not context.in_graph_mode():
|
||||
self_device = self.device
|
||||
def grad_fun(dresult):
|
||||
return dresult._copy(device_name=self_device)
|
||||
tape.record_operation("_copy", [new_tensor], [self], [], grad_fun)
|
||||
return new_tensor
|
||||
# pylint: enable=protected-access
|
||||
|
||||
def _dup(self):
|
||||
return self._copy(device_name=self.device)
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
return tensor_shape.TensorShape(self._shape_tuple())
|
||||
|
||||
def get_shape(self):
|
||||
"""Alias of Tensor.shape."""
|
||||
return self.shape
|
||||
|
||||
def _shape_as_list(self):
|
||||
"""The shape of the tensor as a list."""
|
||||
@ -899,35 +756,9 @@ class EagerTensor(Tensor):
|
||||
raise NotImplementedError("eval not supported for Eager Tensors.")
|
||||
|
||||
|
||||
def _tensor_from_handle(handle):
|
||||
"""'Private' constructor for the Tensor object.
|
||||
|
||||
The existence of a 'handle' is an implementation detail that should be hidden
|
||||
from users of this module. Functions within this module do need to create a
|
||||
Tensor object from a handle though.
|
||||
|
||||
One option would be to have an __init__(self, handle) method on the
|
||||
Tensor class, but that would make the existence and use of a handle
|
||||
'public'.
|
||||
|
||||
Instead, this function avoids exposing a Tensor.__init__ that understands
|
||||
handles and yet allows functions within this module to create Tensor
|
||||
objects from a handle.
|
||||
|
||||
Arguments:
|
||||
handle: A valid TFE_TensorHandle object.
|
||||
|
||||
Returns:
|
||||
A Tensor object.
|
||||
"""
|
||||
# pylint: disable=protected-access
|
||||
t = EagerTensor.__new__(EagerTensor)
|
||||
t._id = uid()
|
||||
t._handle = handle
|
||||
t._dtype = dtypes.as_dtype(c_api.TFE_TensorHandleDataType(handle))
|
||||
t._handle_data = None
|
||||
return t
|
||||
# pylint: enable=protected-access
|
||||
# This call creates an EagerTensor class, as a subclass of _EagerTensorBase, and
|
||||
# registers it with the current module.
|
||||
EagerTensor = c_api.TFE_Py_InitEagerTensor(_EagerTensorBase)
|
||||
|
||||
|
||||
def _TensorTensorConversionFunction(t, dtype=None, name=None, as_ref=False):
|
||||
|
@ -298,9 +298,12 @@ class OperationTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def testConvertToTensorEager(self):
|
||||
with context.eager_mode():
|
||||
t = ops.EagerTensor(1, context.context())
|
||||
t = constant_op.constant(1)
|
||||
self.assertTrue(isinstance(t, ops.EagerTensor))
|
||||
converted = ops.convert_to_tensor(t)
|
||||
self.assertTrue(isinstance(converted, ops.EagerTensor))
|
||||
converted = ops.convert_to_tensor(1)
|
||||
self.assertTrue(isinstance(converted, ops.EagerTensor))
|
||||
|
||||
def testConvertToTensorNestedTuple(self):
|
||||
with self.test_session():
|
||||
|
@ -103,8 +103,7 @@ class ConstantTest(test.TestCase):
|
||||
|
||||
# This integer is larger than all non-infinite numbers representable
|
||||
# by a double, raises an exception.
|
||||
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
|
||||
"out-of-range integer"):
|
||||
with self.assertRaisesRegexp(ValueError, "out-of-range integer"):
|
||||
constant_op.constant(10**310, dtypes_lib.float64)
|
||||
|
||||
def testInt32(self):
|
||||
@ -126,8 +125,7 @@ class ConstantTest(test.TestCase):
|
||||
self.assertAllClose(np.array(orig), tf_ans.numpy())
|
||||
|
||||
# Out of range for an int64
|
||||
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
|
||||
"out-of-range integer"):
|
||||
with self.assertRaisesRegexp(ValueError, "out-of-range integer"):
|
||||
constant_op.constant([2**72])
|
||||
|
||||
def testComplex64(self):
|
||||
@ -240,14 +238,13 @@ class ConstantTest(test.TestCase):
|
||||
self._testAll((x, 1))
|
||||
|
||||
def testSparseValuesRaiseErrors(self):
|
||||
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
|
||||
"non-rectangular Python sequence"):
|
||||
with self.assertRaisesRegexp(ValueError, "non-rectangular Python sequence"):
|
||||
constant_op.constant([[1, 2], [3]], dtype=dtypes_lib.int32)
|
||||
|
||||
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, None):
|
||||
with self.assertRaisesRegexp(ValueError, None):
|
||||
constant_op.constant([[1, 2], [3]])
|
||||
|
||||
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, None):
|
||||
with self.assertRaisesRegexp(ValueError, None):
|
||||
constant_op.constant([[1, 2], [3], [4, 5]])
|
||||
|
||||
|
||||
|
@ -128,7 +128,7 @@ class VariableScopeTest(test.TestCase):
|
||||
with self.assertRaises(TypeError):
|
||||
variable_scope.get_variable("x4", initializer={})
|
||||
else:
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
with self.assertRaises(ValueError):
|
||||
variable_scope.get_variable("x4", initializer={})
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes()
|
||||
|
@ -30,4 +30,11 @@ Safe_TF_TensorPtr make_safe(TF_Tensor* tensor) {
|
||||
return Safe_TF_TensorPtr(tensor, TF_DeleteTensor);
|
||||
}
|
||||
|
||||
Safe_TFE_TensorHandlePtr make_safe(TFE_TensorHandle* handle) {
|
||||
return Safe_TFE_TensorHandlePtr(handle, TFE_DeleteTensorHandle);
|
||||
}
|
||||
|
||||
Safe_TF_StatusPtr make_safe(TF_Status* status) {
|
||||
return Safe_TF_StatusPtr(status, TF_DeleteStatus);
|
||||
}
|
||||
} // namespace tensorflow
|
||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
||||
#include <Python.h>
|
||||
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
@ -36,6 +37,21 @@ typedef void (*TF_DeleteTensor_type)(TF_Tensor*);
|
||||
typedef std::unique_ptr<TF_Tensor, TF_DeleteTensor_type> Safe_TF_TensorPtr;
|
||||
Safe_TF_TensorPtr make_safe(TF_Tensor* tensor);
|
||||
|
||||
// Safe containers for an owned TFE_TensorHandle. On destruction, the handle
|
||||
// will be deleted by TFE_DeleteTensorHandle. Note: can't use
|
||||
// decltype(&TFE_DeleteTensorHandle) due to SWIG
|
||||
typedef void (*TFE_DeleteTensorHandle_type)(TFE_TensorHandle*);
|
||||
typedef std::unique_ptr<TFE_TensorHandle, TFE_DeleteTensorHandle_type>
|
||||
Safe_TFE_TensorHandlePtr;
|
||||
Safe_TFE_TensorHandlePtr make_safe(TFE_TensorHandle* handle);
|
||||
|
||||
// Safe containers for an owned TF_Status. On destruction, the handle
|
||||
// will be deleted by TF_DeleteStatus. Note: can't use
|
||||
// decltype(&TF_DeleteStatus) due to SWIG
|
||||
typedef void (*TF_DeleteStatus_type)(TF_Status*);
|
||||
typedef std::unique_ptr<TF_Status, TF_DeleteStatus_type> Safe_TF_StatusPtr;
|
||||
Safe_TF_StatusPtr make_safe(TF_Status* status);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // THIRD_PARTY_TENSORFLOW_PYTHON_LIB_CORE_SAFE_PTR_H_
|
||||
|
@ -15,24 +15,16 @@ limitations under the License.
|
||||
|
||||
%ignore "";
|
||||
|
||||
%rename("%s") TFE_Py_RegisterExceptionClass;
|
||||
%rename("%s") TFE_Py_NumpyToTensorHandle;
|
||||
%rename("%s") TFE_Py_SequenceToTensorHandle;
|
||||
%rename("%s") TFE_Py_AllEqualInt64;
|
||||
%rename("%s") TFE_NewContext;
|
||||
%rename("%s") TFE_DeleteContext;
|
||||
%rename("%s") TFE_ContextListDevices;
|
||||
%rename("%s") TFE_TensorHandleDataType;
|
||||
%rename("%s") TFE_TensorHandleNumDims;
|
||||
%rename("%s") TFE_DeleteTensorHandle;
|
||||
%rename("%s") TFE_Py_Execute;
|
||||
%rename("%s") TFE_ContextAddFunctionDef;
|
||||
%rename("%s") TFE_TensorHandleDim;
|
||||
%rename("%s") TFE_TensorHandleDeviceName;
|
||||
%rename("%s") TFE_TensorHandleCopyToDevice;
|
||||
%rename("%s") TFE_NewOp;
|
||||
%rename("%s") TFE_Py_TensorHandleToNumpy;
|
||||
%rename("%s") TFE_OpGetAttrType;
|
||||
%rename("%s") TFE_Py_InitEagerTensor;
|
||||
%rename("%s") TFE_Py_RegisterExceptionClass;
|
||||
%rename("%s") TFE_Py_Execute;
|
||||
%rename("%s") TFE_Py_UID;
|
||||
|
||||
|
||||
%{
|
||||
@ -79,6 +71,18 @@ limitations under the License.
|
||||
$1 = TFE_GetPythonString($input);
|
||||
}
|
||||
|
||||
%typemap(in) (TFE_Context*) {
|
||||
$1 = (TFE_Context*)PyCapsule_GetPointer($input, nullptr);
|
||||
|
||||
}
|
||||
%typemap(out) (TFE_Context*) {
|
||||
if ($1 == nullptr) {
|
||||
SWIG_fail;
|
||||
} else {
|
||||
$result = PyCapsule_New($1, nullptr, TFE_DeleteContextCapsule);
|
||||
}
|
||||
}
|
||||
|
||||
%include "tensorflow/c/eager/c_api.h"
|
||||
|
||||
%typemap(in) TFE_InputTensorHandles* inputs (TFE_InputTensorHandles temp) {
|
||||
@ -95,15 +99,13 @@ limitations under the License.
|
||||
if (!elem) {
|
||||
SWIG_fail;
|
||||
}
|
||||
void* thp = nullptr;
|
||||
int res = SWIG_ConvertPtr(elem, &thp,
|
||||
$descriptor(TFE_TensorHandle*), 0 | 0);
|
||||
if (!SWIG_IsOK(res)) {
|
||||
SWIG_exception_fail(SWIG_ArgError(res),
|
||||
if (EagerTensor_CheckExact(elem)) {
|
||||
(*$1)[i] = EagerTensorHandle(elem);
|
||||
} else {
|
||||
SWIG_exception_fail(SWIG_TypeError,
|
||||
"provided list of inputs contains objects other "
|
||||
"than 'TFE_TensorHandle*'");
|
||||
"than 'EagerTensor'");
|
||||
}
|
||||
(*$1)[i] = reinterpret_cast<TFE_TensorHandle*>(thp);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -129,45 +131,32 @@ limitations under the License.
|
||||
}
|
||||
|
||||
%typemap(argout) (TFE_OutputTensorHandles* outputs, TF_Status* out_status) {
|
||||
if (TFE_Py_MaybeRaiseException($2)) {
|
||||
if (MaybeRaiseExceptionFromTFStatus($2, nullptr)) {
|
||||
SWIG_fail;
|
||||
} else {
|
||||
int num_outputs = $1->size();
|
||||
$result = PyList_New(num_outputs);
|
||||
for (int i = 0; i < num_outputs; ++i) {
|
||||
PyList_SetItem($result, i, SWIG_NewPointerObj(SWIG_as_voidptr($1->at(i)),
|
||||
$descriptor(TFE_TensorHandle*),
|
||||
0 | 0));
|
||||
PyObject *output;
|
||||
output = EagerTensorFromHandle($1->at(i));
|
||||
PyList_SetItem($result, i, output);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Note that we need to use a typemap for TFE_TensorHandle* so that we can call
|
||||
// SWIG_fail in case the value is nullptr. Otherwise SWIG will wrap the
|
||||
// nullptr and return it to python as an opaque object, and python does not
|
||||
// know that it needs to check if an Exception has been raised.
|
||||
// TODO(agarwal): check if we can get rid of this typemap.
|
||||
%typemap(out) (TFE_TensorHandle*) {
|
||||
if ($1 == nullptr) {
|
||||
SWIG_fail;
|
||||
} else {
|
||||
$result = SWIG_NewPointerObj(SWIG_as_voidptr($1),
|
||||
$descriptor(TFE_TensorHandle*), 0 | 0);
|
||||
}
|
||||
}
|
||||
|
||||
%include "tensorflow/python/eager/pywrap_tfe.h"
|
||||
|
||||
|
||||
// Clear all typemaps127
|
||||
// Clear all typemaps.
|
||||
%typemap(out) TF_DataType;
|
||||
%typemap(out) int64_t;
|
||||
%typemap(out) TF_AttrType;
|
||||
%typemap(in, numinputs=0) TF_Status *out_status;
|
||||
%typemap(argout) unsigned char* is_list;
|
||||
%typemap(in) TFE_InputTensorHandles* inputs (TFE_InputTensorHandles temp);
|
||||
%typemap(in) (TFE_Context*);
|
||||
%typemap(out) (TFE_Context*);
|
||||
%typemap(in) TFE_OutputTensorHandles* outputs (TFE_OutputTensorHandles temp);
|
||||
%typemap(in, numinputs=0) TF_Status *out_status;
|
||||
%typemap(freearg) (TF_Status* out_status);
|
||||
%typemap(argout) (TFE_OutputTensorHandles* outputs, TF_Status* out_status);
|
||||
%typemap(out) (TFE_TensorHandle*);
|
||||
|
Loading…
x
Reference in New Issue
Block a user