Move EagerTensor from python to C.
PiperOrigin-RevId: 170617321
This commit is contained in:
parent
da8349412f
commit
ff18944249
@ -842,6 +842,7 @@ set (pywrap_tensorflow_internal_src
|
|||||||
"${tensorflow_source_dir}/tensorflow/core/profiler/internal/print_model_analysis.h"
|
"${tensorflow_source_dir}/tensorflow/core/profiler/internal/print_model_analysis.h"
|
||||||
"${tensorflow_source_dir}/tensorflow/core/profiler/internal/print_model_analysis.cc"
|
"${tensorflow_source_dir}/tensorflow/core/profiler/internal/print_model_analysis.cc"
|
||||||
"${tensorflow_source_dir}/tensorflow/python/eager/pywrap_tfe.h"
|
"${tensorflow_source_dir}/tensorflow/python/eager/pywrap_tfe.h"
|
||||||
|
"${tensorflow_source_dir}/tensorflow/python/eager/pywrap_tensor.cc"
|
||||||
"${tensorflow_source_dir}/tensorflow/python/eager/pywrap_tfe_src.cc"
|
"${tensorflow_source_dir}/tensorflow/python/eager/pywrap_tfe_src.cc"
|
||||||
"${tensorflow_source_dir}/tensorflow/python/client/tf_session_helper.h"
|
"${tensorflow_source_dir}/tensorflow/python/client/tf_session_helper.h"
|
||||||
"${tensorflow_source_dir}/tensorflow/python/client/tf_session_helper.cc"
|
"${tensorflow_source_dir}/tensorflow/python/client/tf_session_helper.cc"
|
||||||
|
@ -266,6 +266,7 @@ cc_library(
|
|||||||
hdrs = ["lib/core/safe_ptr.h"],
|
hdrs = ["lib/core/safe_ptr.h"],
|
||||||
deps = [
|
deps = [
|
||||||
"//tensorflow/c:c_api",
|
"//tensorflow/c:c_api",
|
||||||
|
"//tensorflow/c/eager:c_api",
|
||||||
"//util/python:python_headers",
|
"//util/python:python_headers",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -6,7 +6,10 @@ load("//tensorflow:tensorflow.bzl", "tf_cc_binary")
|
|||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "pywrap_tfe_lib",
|
name = "pywrap_tfe_lib",
|
||||||
srcs = ["pywrap_tfe_src.cc"],
|
srcs = [
|
||||||
|
"pywrap_tensor.cc",
|
||||||
|
"pywrap_tfe_src.cc",
|
||||||
|
],
|
||||||
hdrs = ["pywrap_tfe.h"],
|
hdrs = ["pywrap_tfe.h"],
|
||||||
visibility = ["//tensorflow:internal"],
|
visibility = ["//tensorflow:internal"],
|
||||||
deps = [
|
deps = [
|
||||||
@ -14,8 +17,10 @@ cc_library(
|
|||||||
"//tensorflow/c/eager:c_api",
|
"//tensorflow/c/eager:c_api",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/python:ndarray_tensor",
|
"//tensorflow/python:ndarray_tensor",
|
||||||
|
"//tensorflow/python:ndarray_tensor_bridge",
|
||||||
"//tensorflow/python:numpy_lib",
|
"//tensorflow/python:numpy_lib",
|
||||||
"//tensorflow/python:py_seq_tensor",
|
"//tensorflow/python:py_seq_tensor",
|
||||||
|
"//tensorflow/python:safe_ptr",
|
||||||
"//util/python:python_headers",
|
"//util/python:python_headers",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -37,6 +37,7 @@ from tensorflow.python.eager import backprop # pylint: disable=unused-import
|
|||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.eager import function
|
from tensorflow.python.eager import function
|
||||||
from tensorflow.python.eager import test
|
from tensorflow.python.eager import test
|
||||||
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import test_util
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.ops import gen_math_ops
|
from tensorflow.python.ops import gen_math_ops
|
||||||
@ -61,18 +62,41 @@ def benchmark_create_tensor(n):
|
|||||||
def label(s):
|
def label(s):
|
||||||
return "{:20s}".format(s)
|
return "{:20s}".format(s)
|
||||||
|
|
||||||
with timer(label("np.array([[3]])"), iters=n) as iters:
|
with timer(label("np.array([[3.0]])"), iters=n) as iters:
|
||||||
for _ in iters:
|
for _ in iters:
|
||||||
np.array([[3]])
|
np.array([[3.0]])
|
||||||
|
|
||||||
with timer(label("Tensor([[3]])"), iters=n) as iters:
|
|
||||||
for _ in iters:
|
|
||||||
ops.EagerTensor([[3]], context.context())
|
|
||||||
|
|
||||||
ctx = context.context()
|
ctx = context.context()
|
||||||
with timer(label("Tensor([[3]], ctx)"), iters=n) as iters:
|
handle = ctx._handle
|
||||||
|
device = ctx.device_name
|
||||||
|
# May be warmup GPU.
|
||||||
|
ops.EagerTensor([[3.0]], context=handle, device=device)
|
||||||
|
|
||||||
|
# float32
|
||||||
|
dtype = dtypes.float32.as_datatype_enum
|
||||||
|
three = [[3.0]]
|
||||||
|
with timer(label("EagerTensor([[3.0]])"), iters=n) as iters:
|
||||||
for _ in iters:
|
for _ in iters:
|
||||||
ops.EagerTensor([[3]], ctx)
|
ops.EagerTensor(three, context=handle, device=device, dtype=dtype)
|
||||||
|
|
||||||
|
np_3 = np.array([[3.0]], dtype=np.float32)
|
||||||
|
with timer(label("EagerTensor(np.array([[3.0]]))"), iters=n) as iters:
|
||||||
|
for _ in iters:
|
||||||
|
ops.EagerTensor(np_3, context=handle, device=device, dtype=dtype)
|
||||||
|
|
||||||
|
# int32.
|
||||||
|
# This is interesting since int32 will be kept on host memory for the GPU
|
||||||
|
# case.
|
||||||
|
dtype = dtypes.int32.as_datatype_enum
|
||||||
|
three = [[3]]
|
||||||
|
with timer(label("EagerTensor([[3]])"), iters=n) as iters:
|
||||||
|
for _ in iters:
|
||||||
|
ops.EagerTensor(three, context=handle, device=device, dtype=dtype)
|
||||||
|
|
||||||
|
np_3 = np.array([[3]], dtype=np.int32)
|
||||||
|
with timer(label("EagerTensor(np.array([[3]]))"), iters=n) as iters:
|
||||||
|
for _ in iters:
|
||||||
|
ops.EagerTensor(np_3, context=handle, device=device, dtype=dtype)
|
||||||
|
|
||||||
|
|
||||||
def benchmark_matmul(shape, n, use_gpu=False):
|
def benchmark_matmul(shape, n, use_gpu=False):
|
||||||
@ -103,17 +127,16 @@ def benchmark_matmul(shape, n, use_gpu=False):
|
|||||||
for _ in iters:
|
for _ in iters:
|
||||||
gen_math_ops._mat_mul(m, m, transpose_b=transpose_b)
|
gen_math_ops._mat_mul(m, m, transpose_b=transpose_b)
|
||||||
|
|
||||||
|
inputs = [m, m]
|
||||||
# pylint: disable=protected-access
|
# pylint: disable=protected-access
|
||||||
input_handles = [m._handle, m._handle]
|
|
||||||
ctx_handle = context.context()._handle
|
ctx_handle = context.context()._handle
|
||||||
# pylint: enable=protected-access
|
# pylint: enable=protected-access
|
||||||
attrs = ("transpose_a", False, "transpose_b", transpose_b, "T",
|
attrs = ("transpose_a", False, "transpose_b", transpose_b, "T",
|
||||||
m.dtype.as_datatype_enum)
|
m.dtype.as_datatype_enum)
|
||||||
with timer(label("TFE_Py_Execute"), iters=n) as iters:
|
with timer(label("TFE_Py_Execute"), iters=n) as iters:
|
||||||
for _ in iters:
|
for _ in iters:
|
||||||
pywrap_tensorflow.TFE_DeleteTensorHandle(
|
pywrap_tensorflow.TFE_Py_Execute(ctx_handle, None, "MatMul",
|
||||||
pywrap_tensorflow.TFE_Py_Execute(ctx_handle, None, "MatMul",
|
inputs, attrs, 1)
|
||||||
input_handles, attrs, 1)[0])
|
|
||||||
|
|
||||||
f = function.defun(math_ops.matmul)
|
f = function.defun(math_ops.matmul)
|
||||||
with timer(label("defun(tf.matmul)"), iters=n) as iters:
|
with timer(label("defun(tf.matmul)"), iters=n) as iters:
|
||||||
@ -133,6 +156,8 @@ class BenchmarksTest(test_util.TensorFlowTestCase):
|
|||||||
|
|
||||||
if context.context().num_gpus() > 0:
|
if context.context().num_gpus() > 0:
|
||||||
print("---- RUNNING ON GPU NOW ----")
|
print("---- RUNNING ON GPU NOW ----")
|
||||||
|
with context.device("/device:GPU:0"):
|
||||||
|
benchmark_create_tensor(FLAGS.iters or 30000)
|
||||||
benchmark_matmul([2, 2], FLAGS.iters or 30000, use_gpu=True)
|
benchmark_matmul([2, 2], FLAGS.iters or 30000, use_gpu=True)
|
||||||
benchmark_matmul([100, 28 * 28], FLAGS.iters or 1000, use_gpu=True)
|
benchmark_matmul([100, 28 * 28], FLAGS.iters or 1000, use_gpu=True)
|
||||||
|
|
||||||
|
@ -121,16 +121,6 @@ class Context(object):
|
|||||||
else:
|
else:
|
||||||
return devices
|
return devices
|
||||||
|
|
||||||
def __del__(self):
|
|
||||||
try:
|
|
||||||
if self._context_handle is not None:
|
|
||||||
with errors.raise_exception_on_not_ok_status() as status:
|
|
||||||
pywrap_tensorflow.TFE_DeleteContext(self._context_handle, status)
|
|
||||||
except (AttributeError, TypeError):
|
|
||||||
# Sometimes deletion during program shutdown throws exception as other
|
|
||||||
# modules are no longer available.
|
|
||||||
pass
|
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
if self._context_handle is None:
|
if self._context_handle is None:
|
||||||
return "Eager TensorFlow Context. Devices currently uninitialized."
|
return "Eager TensorFlow Context. Devices currently uninitialized."
|
||||||
|
@ -19,6 +19,7 @@ from __future__ import division
|
|||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
import threading
|
import threading
|
||||||
|
|
||||||
from tensorflow.core.protobuf import config_pb2
|
from tensorflow.core.protobuf import config_pb2
|
||||||
from tensorflow.python import pywrap_tensorflow
|
from tensorflow.python import pywrap_tensorflow
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
@ -138,7 +139,7 @@ class TFETest(test_util.TensorFlowTestCase):
|
|||||||
x = x.as_cpu_tensor()
|
x = x.as_cpu_tensor()
|
||||||
|
|
||||||
# Invalid device
|
# Invalid device
|
||||||
with self.assertRaises(errors.InvalidArgumentError):
|
with self.assertRaises(RuntimeError):
|
||||||
x.as_gpu_tensor(context.context().num_gpus() + 1)
|
x.as_gpu_tensor(context.context().num_gpus() + 1)
|
||||||
|
|
||||||
def testNumpyForceCPU(self):
|
def testNumpyForceCPU(self):
|
||||||
@ -153,7 +154,7 @@ class TFETest(test_util.TensorFlowTestCase):
|
|||||||
ta = constant_op.constant([[1, 2], [3, 4]])
|
ta = constant_op.constant([[1, 2], [3, 4]])
|
||||||
tb = ta.as_cpu_tensor()
|
tb = ta.as_cpu_tensor()
|
||||||
|
|
||||||
self.assertNotEqual(ta._handle, tb._handle)
|
self.assertNotEqual(id(ta), id(tb))
|
||||||
self.assertAllEqual(ta.numpy(), tb.numpy())
|
self.assertAllEqual(ta.numpy(), tb.numpy())
|
||||||
|
|
||||||
def testRegisterExceptionClass(self):
|
def testRegisterExceptionClass(self):
|
||||||
|
@ -53,32 +53,27 @@ def execute(op_name, num_outputs, inputs, attrs, ctx, name=None):
|
|||||||
Raises:
|
Raises:
|
||||||
An exception on error.
|
An exception on error.
|
||||||
"""
|
"""
|
||||||
# TODO(apassos) move this to convert_to_tensor
|
|
||||||
# pylint: disable=protected-access
|
|
||||||
input_handles = [c._handle for c in inputs]
|
|
||||||
device_name = ctx.device_name
|
device_name = ctx.device_name
|
||||||
|
# pylint: disable=protected-access
|
||||||
try:
|
try:
|
||||||
outh = pywrap_tensorflow.TFE_Py_Execute(ctx._handle, device_name,
|
tensors = pywrap_tensorflow.TFE_Py_Execute(ctx._handle, device_name,
|
||||||
op_name, input_handles, attrs,
|
op_name, inputs, attrs,
|
||||||
num_outputs)
|
num_outputs)
|
||||||
except core._NotOkStatusException as e:
|
except core._NotOkStatusException as e:
|
||||||
if name is not None:
|
if name is not None:
|
||||||
message = e.message + " name: " + name
|
message = e.message + " name: " + name
|
||||||
else:
|
else:
|
||||||
message = e.message
|
message = e.message
|
||||||
six.raise_from(core._status_to_exception(e.code, message), None)
|
six.raise_from(core._status_to_exception(e.code, message), None)
|
||||||
# pylint: enable=protected-access
|
|
||||||
|
|
||||||
tensors = [ops._tensor_from_handle(x) for x in outh] # pylint: disable=protected-access
|
|
||||||
# TODO(alive, cais): Use the execution callback mechanism.
|
# TODO(alive, cais): Use the execution callback mechanism.
|
||||||
if core.active_trace() is not None:
|
if core.active_trace() is not None:
|
||||||
for t in tensors:
|
for t in tensors:
|
||||||
# pylint: disable=protected-access
|
|
||||||
core.active_trace().record_tensor(op_name,
|
core.active_trace().record_tensor(op_name,
|
||||||
ops.tensor_id(t),
|
ops.tensor_id(t),
|
||||||
t.device,
|
t.device,
|
||||||
t.shape.num_elements())
|
t.shape.num_elements())
|
||||||
# pylint: enable=protected-access
|
# pylint: enable=protected-access
|
||||||
|
|
||||||
# TODO(cais): Optimize this, perhaps by replacing this execute function with
|
# TODO(cais): Optimize this, perhaps by replacing this execute function with
|
||||||
# a different one when there are execution callback(s).
|
# a different one when there are execution callback(s).
|
||||||
|
@ -162,7 +162,7 @@ def inf_nan_callback(op_type,
|
|||||||
# TODO(cais): Consider moving this into execute.py.
|
# TODO(cais): Consider moving this into execute.py.
|
||||||
# pylint: disable=protected-access
|
# pylint: disable=protected-access
|
||||||
pywrap_tensorflow.TFE_Py_Execute(
|
pywrap_tensorflow.TFE_Py_Execute(
|
||||||
ctx._handle, output.device, "CheckNumerics", [output._handle],
|
ctx._handle, output.device, "CheckNumerics", [output],
|
||||||
check_numerics_op_attrs, 1)
|
check_numerics_op_attrs, 1)
|
||||||
# pylint: enable=protected-access
|
# pylint: enable=protected-access
|
||||||
except core._NotOkStatusException: # pylint: disable=protected-access
|
except core._NotOkStatusException: # pylint: disable=protected-access
|
||||||
|
@ -33,7 +33,7 @@ from tensorflow.python.ops import random_ops
|
|||||||
from tensorflow.python.ops import sparse_ops
|
from tensorflow.python.ops import sparse_ops
|
||||||
|
|
||||||
|
|
||||||
class TargetTest(test_util.TensorFlowTestCase):
|
class OpsTest(test_util.TensorFlowTestCase):
|
||||||
|
|
||||||
def testExecuteBasic(self):
|
def testExecuteBasic(self):
|
||||||
three = constant_op.constant(3)
|
three = constant_op.constant(3)
|
||||||
|
646
tensorflow/python/eager/pywrap_tensor.cc
Normal file
646
tensorflow/python/eager/pywrap_tensor.cc
Normal file
@ -0,0 +1,646 @@
|
|||||||
|
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include <stdlib.h>
|
||||||
|
|
||||||
|
#include "tensorflow/python/lib/core/ndarray_tensor_bridge.h"
|
||||||
|
#include "tensorflow/python/lib/core/numpy.h"
|
||||||
|
#include "tensorflow/python/lib/core/py_seq_tensor.h"
|
||||||
|
#include "tensorflow/python/lib/core/safe_ptr.h"
|
||||||
|
|
||||||
|
#include "tensorflow/python/eager/pywrap_tfe.h"
|
||||||
|
|
||||||
|
#include "tensorflow/c/c_api.h"
|
||||||
|
#include "tensorflow/core/lib/strings/strcat.h"
|
||||||
|
#include "tensorflow/python/lib/core/ndarray_tensor.h"
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
TFE_Context* GetContext(PyObject* ctx) {
|
||||||
|
TFE_Context* context =
|
||||||
|
reinterpret_cast<TFE_Context*>(PyCapsule_GetPointer(ctx, nullptr));
|
||||||
|
if (context == nullptr) {
|
||||||
|
PyErr_SetString(PyExc_TypeError,
|
||||||
|
tensorflow::strings::StrCat(
|
||||||
|
"Expecting a PyCapsule encoded context handle. Got ",
|
||||||
|
Py_TYPE(ctx)->tp_name)
|
||||||
|
.c_str());
|
||||||
|
}
|
||||||
|
return context;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert a Python numpy.ndarray object to a TFE_TensorHandle.
|
||||||
|
// The two may share underlying storage so changes to one may reflect in the
|
||||||
|
// other.
|
||||||
|
TFE_TensorHandle* NumpyToTensorHandle(PyObject* obj) {
|
||||||
|
tensorflow::Tensor t;
|
||||||
|
auto cppstatus = tensorflow::NdarrayToTensor(obj, &t);
|
||||||
|
if (cppstatus.ok()) {
|
||||||
|
return TFE_NewTensorHandle(t);
|
||||||
|
} else {
|
||||||
|
PyErr_SetString(PyExc_ValueError,
|
||||||
|
tensorflow::strings::StrCat(
|
||||||
|
"Failed to convert numpy ndarray to a Tensor (",
|
||||||
|
cppstatus.error_message(), ").")
|
||||||
|
.c_str());
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Casts data referred to by `handle` from type `src_type_enum` to type
|
||||||
|
// `dst_type_enum`.
|
||||||
|
TFE_TensorHandle* EagerCast(TFE_Context* ctx, TFE_TensorHandle* handle,
|
||||||
|
TF_DataType src_type_enum,
|
||||||
|
TF_DataType dst_type_enum, TF_Status* out_status) {
|
||||||
|
if (ctx == nullptr) return nullptr;
|
||||||
|
const char* op_name = "Cast";
|
||||||
|
const char* device_name = "/job:localhost/replica:0/task:0/device:CPU:0";
|
||||||
|
TFE_Op* op = TFE_NewOp(ctx, op_name, out_status);
|
||||||
|
#define RETURN_ERROR \
|
||||||
|
{ \
|
||||||
|
TFE_DeleteOp(op); \
|
||||||
|
return nullptr; \
|
||||||
|
}
|
||||||
|
if (TF_GetCode(out_status) != TF_OK) RETURN_ERROR
|
||||||
|
TFE_OpSetDevice(op, device_name, out_status);
|
||||||
|
if (TF_GetCode(out_status) != TF_OK) RETURN_ERROR
|
||||||
|
TFE_OpAddInput(op, handle, out_status);
|
||||||
|
if (TF_GetCode(out_status) != TF_OK) RETURN_ERROR
|
||||||
|
TFE_OpSetAttrType(op, "SrcT", src_type_enum);
|
||||||
|
TFE_OpSetAttrType(op, "DstT", dst_type_enum);
|
||||||
|
TFE_TensorHandle* output = nullptr;
|
||||||
|
int num_outputs = 1;
|
||||||
|
TFE_Execute(op, &output, &num_outputs, out_status);
|
||||||
|
if (TF_GetCode(out_status) != TF_OK || num_outputs != 1 ||
|
||||||
|
output == nullptr) {
|
||||||
|
if (output != nullptr) {
|
||||||
|
TFE_DeleteTensorHandle(output);
|
||||||
|
}
|
||||||
|
RETURN_ERROR
|
||||||
|
}
|
||||||
|
TFE_DeleteOp(op);
|
||||||
|
return output;
|
||||||
|
#undef RETURN_ERROR
|
||||||
|
}
|
||||||
|
|
||||||
|
TFE_TensorHandle* CopyToDevice(TFE_TensorHandle* handle, PyObject* ctx,
|
||||||
|
PyObject* dev) {
|
||||||
|
const char* device = "";
|
||||||
|
if (dev != nullptr && dev != Py_None) {
|
||||||
|
device = PyBytes_AsString(dev);
|
||||||
|
#if PY_MAJOR_VERSION >= 3
|
||||||
|
if (device == nullptr) {
|
||||||
|
PyErr_Clear();
|
||||||
|
device = PyUnicode_AsUTF8(dev);
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
if (device == nullptr) {
|
||||||
|
PyErr_SetString(PyExc_TypeError,
|
||||||
|
"Error parsing device argument to CopyToDevice");
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
TFE_Context* context = GetContext(ctx);
|
||||||
|
if (context == nullptr) { // PyErr already set by GetContext
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
auto status = tensorflow::make_safe(TF_NewStatus());
|
||||||
|
TFE_TensorHandle* new_handle =
|
||||||
|
TFE_TensorHandleCopyToDevice(handle, context, device, status.get());
|
||||||
|
if (TF_GetCode(status.get()) != TF_OK) {
|
||||||
|
PyErr_SetString(
|
||||||
|
PyExc_RuntimeError,
|
||||||
|
tensorflow::strings::StrCat("Error copying tensor to device: ", device,
|
||||||
|
". ", TF_Message(status.get()))
|
||||||
|
.c_str());
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
return new_handle;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper function to convert `v` to an int and store it in `*out`. Returns true
|
||||||
|
// on success, false otherwise.
|
||||||
|
// Note that we assume that v is a python int (not long) representing a
|
||||||
|
// TF_DataType value.
|
||||||
|
bool PyIntToDataType(PyObject* v, int* out) {
|
||||||
|
#if PY_MAJOR_VERSION < 3
|
||||||
|
if (PyInt_Check(v)) {
|
||||||
|
*out = PyInt_AS_LONG(v);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
#else
|
||||||
|
if (PyLong_Check(v)) {
|
||||||
|
*out = PyLong_AsLong(v);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper function to create a python integer from TF_DataType.
|
||||||
|
PyObject* PyIntFromDataType(TF_DataType l) {
|
||||||
|
#if PY_MAJOR_VERSION < 3
|
||||||
|
return PyInt_FromLong(l);
|
||||||
|
#else
|
||||||
|
return PyLong_FromLong(l);
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
extern "C" {
|
||||||
|
|
||||||
|
static const int kMaxEagerTensorParentSize = 32;
|
||||||
|
|
||||||
|
// TODO(agarwal): store context handle in EagerTensor.
|
||||||
|
typedef struct EagerTensor {
|
||||||
|
PyObject_HEAD;
|
||||||
|
// Note that we leave kMaxEagerTensorParentSize bytes here for use by the
|
||||||
|
// parent class. The parent class is set at runtime, so we don't know the
|
||||||
|
// exact size at compile time.
|
||||||
|
char unused[kMaxEagerTensorParentSize];
|
||||||
|
TFE_TensorHandle* handle;
|
||||||
|
int64_t id;
|
||||||
|
// This mirrors tensorflow.core.framework.ops.Tensor._handle_data Which will
|
||||||
|
// be None for tensors of type other than DT_REOSURCE. For DT_RESOURCE
|
||||||
|
// tensors, this will contain a serialized HandleData proto with shape
|
||||||
|
// inference metadata about shapes and dtypes of resources accessible from
|
||||||
|
// this handle.
|
||||||
|
// Note that we assume that handle_data cannot participate in reference
|
||||||
|
// cycles, and hence don't provide GC support for it.
|
||||||
|
PyObject* handle_data;
|
||||||
|
|
||||||
|
// This stores `_keras_mask` object and is set by Tensorflow layers.
|
||||||
|
PyObject* keras_mask;
|
||||||
|
} EagerTensor;
|
||||||
|
|
||||||
|
// tp_init for EagerTensor.
|
||||||
|
int EagerTensor_init(EagerTensor* self, PyObject* args, PyObject* kwds) {
|
||||||
|
self->id = get_uid();
|
||||||
|
self->handle = nullptr;
|
||||||
|
Py_INCREF(Py_None);
|
||||||
|
self->handle_data = Py_None;
|
||||||
|
Py_INCREF(Py_None);
|
||||||
|
self->keras_mask = Py_None;
|
||||||
|
PyObject* value;
|
||||||
|
PyObject* context = nullptr;
|
||||||
|
PyObject* device = nullptr;
|
||||||
|
PyObject* dtype = Py_None;
|
||||||
|
const char* kwlist[] = {"value", "context", "device", "dtype", nullptr};
|
||||||
|
if (!PyArg_ParseTupleAndKeywords(args, kwds, "OOO|O",
|
||||||
|
const_cast<char**>(kwlist), &value, &context,
|
||||||
|
&device, &dtype)) {
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
// Extract dtype
|
||||||
|
int desired_dtype = -1;
|
||||||
|
if (dtype != Py_None) {
|
||||||
|
if (!PyIntToDataType(dtype, &desired_dtype)) {
|
||||||
|
PyErr_SetString(PyExc_TypeError,
|
||||||
|
tensorflow::strings::StrCat(
|
||||||
|
"Expecting a DataType value for dtype. Got ",
|
||||||
|
Py_TYPE(dtype)->tp_name)
|
||||||
|
.c_str());
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
tensorflow::Safe_TFE_TensorHandlePtr handle =
|
||||||
|
tensorflow::make_safe(static_cast<TFE_TensorHandle*>(nullptr));
|
||||||
|
PyErr_Clear();
|
||||||
|
if (PyArray_Check(value)) {
|
||||||
|
int desired_np_dtype = -1;
|
||||||
|
if (desired_dtype >= 0) {
|
||||||
|
if (!tensorflow::TF_DataType_to_PyArray_TYPE(
|
||||||
|
static_cast<TF_DataType>(desired_dtype), &desired_np_dtype)
|
||||||
|
.ok()) {
|
||||||
|
PyErr_SetString(PyExc_TypeError,
|
||||||
|
tensorflow::strings::StrCat(
|
||||||
|
"Invalid dtype argument value ", desired_dtype)
|
||||||
|
.c_str());
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
PyArrayObject* array = reinterpret_cast<PyArrayObject*>(value);
|
||||||
|
int current_np_dtype = PyArray_TYPE(array);
|
||||||
|
auto safe_value = tensorflow::make_safe(static_cast<PyObject*>(nullptr));
|
||||||
|
if ((desired_np_dtype >= 0 && desired_np_dtype != current_np_dtype) ||
|
||||||
|
!PyArray_ISCARRAY(array)) {
|
||||||
|
int new_dtype =
|
||||||
|
desired_np_dtype >= 0 ? desired_np_dtype : current_np_dtype;
|
||||||
|
safe_value = tensorflow::make_safe(
|
||||||
|
PyArray_FromAny(value, PyArray_DescrFromType(new_dtype), 0, 0,
|
||||||
|
NPY_ARRAY_CARRAY | NPY_ARRAY_FORCECAST, nullptr));
|
||||||
|
if (PyErr_Occurred()) return -1;
|
||||||
|
if (safe_value == nullptr) {
|
||||||
|
PyErr_SetString(PyExc_ValueError, "Error while casting a numpy value");
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
value = safe_value.get();
|
||||||
|
}
|
||||||
|
handle = tensorflow::make_safe(NumpyToTensorHandle(value));
|
||||||
|
} else {
|
||||||
|
tensorflow::Tensor t;
|
||||||
|
// TODO(josh11b): Have PySeqToTensor set python errors instead of
|
||||||
|
// returning Status.
|
||||||
|
auto cppstatus = tensorflow::PySeqToTensor(value, dtype, &t);
|
||||||
|
if (!cppstatus.ok()) {
|
||||||
|
PyErr_SetString(PyExc_ValueError, cppstatus.error_message().c_str());
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
handle = tensorflow::make_safe(TFE_NewTensorHandle(t));
|
||||||
|
}
|
||||||
|
if (PyErr_Occurred()) return -1;
|
||||||
|
if (handle == nullptr) {
|
||||||
|
PyErr_SetString(PyExc_ValueError, "Error while creating an EagerTensor");
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
TF_DataType handle_dtype = TFE_TensorHandleDataType(handle.get());
|
||||||
|
if (desired_dtype >= 0 && desired_dtype != handle_dtype) {
|
||||||
|
auto out_status = tensorflow::make_safe(TF_NewStatus());
|
||||||
|
handle = tensorflow::make_safe(
|
||||||
|
EagerCast(GetContext(context), handle.get(), handle_dtype,
|
||||||
|
static_cast<TF_DataType>(desired_dtype), out_status.get()));
|
||||||
|
if (TF_GetCode(out_status.get()) != TF_OK) {
|
||||||
|
PyErr_SetString(
|
||||||
|
PyExc_ValueError,
|
||||||
|
tensorflow::strings::StrCat("Error while casting from DataType ",
|
||||||
|
handle_dtype, " to ", desired_dtype, ". ",
|
||||||
|
TF_Message(out_status.get()))
|
||||||
|
.c_str());
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
handle_dtype = TFE_TensorHandleDataType(handle.get());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Almost all TensorFlow kernels for GPU devices keep int32 tensors in host
|
||||||
|
// memory. We approximate the same behavior for eager execution - keeping
|
||||||
|
// int32 tensors in host memory.
|
||||||
|
//
|
||||||
|
// We do so to preclude the need for callers into such kernels from having to
|
||||||
|
// explicitly place the int32 tensors in host memory. For example, without
|
||||||
|
// this, one needed:
|
||||||
|
//
|
||||||
|
// with tf.device('/gpu:0'):
|
||||||
|
// ...// code here
|
||||||
|
// with tf.device('/cpu:0'):
|
||||||
|
// shape = tf.constant(...)
|
||||||
|
// y = tf.random_uniform(shape)
|
||||||
|
//
|
||||||
|
// Without the CPU device block, tfe.ops.random_uniform would fail since the
|
||||||
|
// kernel expects the shape in host memory.
|
||||||
|
//
|
||||||
|
// With this support, we simplify the code:
|
||||||
|
//
|
||||||
|
// with tf.device('/gpu:0'):
|
||||||
|
// y = tf.random_uniform(...)
|
||||||
|
//
|
||||||
|
// The approximation is not exact there are GPU kernels which do not require
|
||||||
|
// host memory for int32 tensors. This will lead to a discrepancy between
|
||||||
|
// eager and graph execution.
|
||||||
|
// TODO(ashankar): Fix this.
|
||||||
|
if (handle_dtype != TF_INT32) {
|
||||||
|
// Note that this is a shallow copy and will share the underlying buffer
|
||||||
|
// if copying to the same device.
|
||||||
|
handle = tensorflow::make_safe(CopyToDevice(handle.get(), context, device));
|
||||||
|
if (handle == nullptr) return -1;
|
||||||
|
}
|
||||||
|
self->handle = handle.release();
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
// tp_dealloc for EagerTensor.
|
||||||
|
void EagerTensor_dealloc(EagerTensor* self) {
|
||||||
|
Py_DECREF(self->handle_data);
|
||||||
|
Py_DECREF(self->keras_mask);
|
||||||
|
TFE_DeleteTensorHandle(self->handle);
|
||||||
|
self->handle = nullptr;
|
||||||
|
PyObject* id = PyLong_FromLongLong(self->id);
|
||||||
|
PyObject* func = PyObject_GetAttrString(reinterpret_cast<PyObject*>(self),
|
||||||
|
"_delete_trace");
|
||||||
|
Py_TYPE(self)->tp_free(self);
|
||||||
|
self = nullptr;
|
||||||
|
// Note that we run `func` after calling `tp_free`. Otherwise calling that
|
||||||
|
// function can potentially trigger garbage collection that observes `self`
|
||||||
|
// in this half deleted state and crashes.
|
||||||
|
// Note that `func` is a staticmethod and does not need `self` to be around
|
||||||
|
// for running.
|
||||||
|
// We clear (and later restore) any errors that have already been set. Else
|
||||||
|
// these erorrs may appear randomly as part of the function execution.
|
||||||
|
PyObject *a, *b, *c;
|
||||||
|
PyErr_Fetch(&a, &b, &c);
|
||||||
|
PyObject_CallFunctionObjArgs(func, id, nullptr);
|
||||||
|
PyErr_Restore(a, b, c);
|
||||||
|
Py_DECREF(func);
|
||||||
|
Py_DECREF(id);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Getter for `_id`.
|
||||||
|
static PyObject* EagerTensor_getid(EagerTensor* self, void* closure) {
|
||||||
|
return PyLong_FromLongLong(self->id);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Getter for `_datatype_enum`.
|
||||||
|
static PyObject* EagerTensor_datatype_enum(EagerTensor* self) {
|
||||||
|
return PyIntFromDataType(TFE_TensorHandleDataType(self->handle));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Getter for `_shape_tuple`.
|
||||||
|
static PyObject* EagerTensor_shape_tuple(EagerTensor* self) {
|
||||||
|
auto handle = self->handle;
|
||||||
|
int n = TFE_TensorHandleNumDims(handle);
|
||||||
|
PyObject* shape = PyTuple_New(n);
|
||||||
|
if (PyErr_Occurred()) return nullptr;
|
||||||
|
for (int i = 0; i < n; ++i) {
|
||||||
|
PyObject* dim = PyLong_FromLongLong(TFE_TensorHandleDim(handle, i));
|
||||||
|
if (dim == nullptr || PyTuple_SetItem(shape, i, dim) != 0) {
|
||||||
|
Py_DECREF(shape);
|
||||||
|
if (dim != nullptr) Py_DECREF(dim);
|
||||||
|
PyErr_SetString(PyExc_RuntimeError, "Error while creating shape");
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return shape;
|
||||||
|
}
|
||||||
|
|
||||||
|
static PyObject* EagerTensor_tensor_handle(EagerTensor* self, void* unused) {
|
||||||
|
Py_INCREF(self->handle_data);
|
||||||
|
return self->handle_data;
|
||||||
|
}
|
||||||
|
|
||||||
|
static int EagerTensor_settensor_handle(EagerTensor* self, PyObject* value,
|
||||||
|
void* unused) {
|
||||||
|
Py_DECREF(self->handle_data);
|
||||||
|
Py_INCREF(value);
|
||||||
|
self->handle_data = value;
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
static PyObject* EagerTensor_keras_mask(EagerTensor* self, void* unused) {
|
||||||
|
Py_INCREF(self->keras_mask);
|
||||||
|
return self->keras_mask;
|
||||||
|
}
|
||||||
|
|
||||||
|
static int EagerTensor_setkeras_mask(EagerTensor* self, PyObject* value,
|
||||||
|
void* unused) {
|
||||||
|
Py_DECREF(self->keras_mask);
|
||||||
|
Py_INCREF(value);
|
||||||
|
self->keras_mask = value;
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
// Function `_copy_to_device`.
|
||||||
|
static PyObject* EagerTensor_copy_to_device(EagerTensor* self, PyObject* args,
|
||||||
|
PyObject* kwds) {
|
||||||
|
const char* kwlist[] = {"context", "device", nullptr};
|
||||||
|
PyObject* ctx = nullptr;
|
||||||
|
PyObject* dev = nullptr;
|
||||||
|
if (!PyArg_ParseTupleAndKeywords(args, kwds, "OO", const_cast<char**>(kwlist),
|
||||||
|
&ctx, &dev) ||
|
||||||
|
!ctx || !dev) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
auto handle = CopyToDevice(self->handle, ctx, dev);
|
||||||
|
return EagerTensorFromHandle(handle);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Function `_numpy`.
|
||||||
|
// Convert an EagerTensor to a Python numpy.ndarray object.
|
||||||
|
// The two may share underlying storage so changes to one may reflect in the
|
||||||
|
// other.
|
||||||
|
// Note that if `self` is not on CPU, we raise an Exception.
|
||||||
|
static PyObject* EagerTensor_numpy(EagerTensor* self) {
|
||||||
|
auto status = tensorflow::make_safe(TF_NewStatus());
|
||||||
|
const tensorflow::Tensor* t =
|
||||||
|
TFE_TensorHandleUnderlyingTensorInHostMemory(self->handle, status.get());
|
||||||
|
if (TF_GetCode(status.get()) != TF_OK) {
|
||||||
|
PyErr_SetString(PyExc_RuntimeError, TF_Message(status.get()));
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
PyObject* ret = nullptr;
|
||||||
|
auto cppstatus = tensorflow::TensorToNdarray(*t, &ret);
|
||||||
|
if (MaybeRaiseExceptionFromStatus(cppstatus, PyExc_RuntimeError)) {
|
||||||
|
Py_XDECREF(ret);
|
||||||
|
return nullptr;
|
||||||
|
} else {
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Getter `device`.
|
||||||
|
static PyObject* EagerTensor_device(EagerTensor* self) {
|
||||||
|
#if PY_MAJOR_VERSION >= 3
|
||||||
|
return PyUnicode_FromString(TFE_TensorHandleDeviceName(self->handle));
|
||||||
|
#else
|
||||||
|
return PyBytes_FromString(TFE_TensorHandleDeviceName(self->handle));
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
static PyGetSetDef EagerTensor_getseters[] = {
|
||||||
|
{const_cast<char*>("_id"), (getter)EagerTensor_getid, nullptr,
|
||||||
|
const_cast<char*>("_id"), nullptr},
|
||||||
|
{const_cast<char*>("device"), (getter)EagerTensor_device, nullptr,
|
||||||
|
const_cast<char*>("device"), nullptr},
|
||||||
|
{const_cast<char*>("_handle_data"), (getter)EagerTensor_tensor_handle,
|
||||||
|
(setter)EagerTensor_settensor_handle, const_cast<char*>("_tensor_handle"),
|
||||||
|
nullptr},
|
||||||
|
{const_cast<char*>("_keras_mask"), (getter)EagerTensor_keras_mask,
|
||||||
|
(setter)EagerTensor_setkeras_mask, const_cast<char*>("_keras_mask"),
|
||||||
|
nullptr},
|
||||||
|
{nullptr} /* Sentinel */
|
||||||
|
};
|
||||||
|
|
||||||
|
static PyMethodDef EagerTensor_methods[] = {
|
||||||
|
{"_numpy", (PyCFunction)EagerTensor_numpy, METH_NOARGS,
|
||||||
|
PyDoc_STR("_numpy")},
|
||||||
|
{"_datatype_enum", (PyCFunction)EagerTensor_datatype_enum, METH_NOARGS,
|
||||||
|
PyDoc_STR("_datatype_enum")},
|
||||||
|
{"_shape_tuple", (PyCFunction)EagerTensor_shape_tuple, METH_NOARGS,
|
||||||
|
PyDoc_STR("_shape_tuple")},
|
||||||
|
{"_copy_to_device", (PyCFunction)EagerTensor_copy_to_device,
|
||||||
|
METH_VARARGS | METH_KEYWORDS, PyDoc_STR("_copy_to_device")},
|
||||||
|
{nullptr, nullptr},
|
||||||
|
};
|
||||||
|
|
||||||
|
// Note that here we are trying to dynamically create a new class as a subclass
|
||||||
|
// of a "HEAPTYPE" class that is itself created in python code and passed in at
|
||||||
|
// runtime. This is fairly atypical and undocumented.
|
||||||
|
//
|
||||||
|
// We use the following strategy for this. Unfortunately, we have to use
|
||||||
|
// different approaches for python2.x vs python3.x
|
||||||
|
// For python2.x, we create the class as a static type and set its tp_base to
|
||||||
|
// the passed in type. Unfortunately setting tp_flags to include
|
||||||
|
// Py_TPFLAGS_HEAPTYPE does not work by itself since it needs some more
|
||||||
|
// initialization of the underlying PyHeapTypeObject and not doing that leads to
|
||||||
|
// some random crashes especially during garbage collection.
|
||||||
|
// python3.x explicitly disables a static subclass of a HEAPTYPE base class.
|
||||||
|
// However it provides a new function, PyType_FromSpecWithBases, to create
|
||||||
|
// types dynamically.
|
||||||
|
|
||||||
|
// Type object for EagerTensor. This is set by TFE_Py_InitEagerTensor.
|
||||||
|
PyTypeObject* EagerTensorType = nullptr;
|
||||||
|
|
||||||
|
#if PY_MAJOR_VERSION >= 3
|
||||||
|
static PyType_Slot EagerTensor_Type_slots[] = {
|
||||||
|
Py_tp_dealloc,
|
||||||
|
reinterpret_cast<void*>(EagerTensor_dealloc),
|
||||||
|
Py_tp_methods,
|
||||||
|
reinterpret_cast<void*>(EagerTensor_methods),
|
||||||
|
Py_tp_getset,
|
||||||
|
reinterpret_cast<void*>(EagerTensor_getseters),
|
||||||
|
Py_tp_init,
|
||||||
|
reinterpret_cast<void*>(EagerTensor_init),
|
||||||
|
0,
|
||||||
|
nullptr,
|
||||||
|
};
|
||||||
|
|
||||||
|
PyType_Spec EagerTensor_Type_spec = {"EagerTensor", sizeof(EagerTensor), 0,
|
||||||
|
Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HEAPTYPE,
|
||||||
|
EagerTensor_Type_slots};
|
||||||
|
#else
|
||||||
|
// TODO(agarwal): support active_trace.
|
||||||
|
static PyTypeObject _EagerTensorType = {
|
||||||
|
// clang-format off
|
||||||
|
PyVarObject_HEAD_INIT(nullptr, 0)
|
||||||
|
// clang-format on
|
||||||
|
"EagerTensor", /* tp_name */
|
||||||
|
sizeof(EagerTensor), /* tp_basicsize */
|
||||||
|
0, /* tp_itemsize */
|
||||||
|
(destructor)EagerTensor_dealloc, /* tp_dealloc */
|
||||||
|
nullptr, /* tp_print */
|
||||||
|
nullptr, /* tp_getattr */
|
||||||
|
nullptr, /* tp_setattr */
|
||||||
|
nullptr, /* tp_compare */
|
||||||
|
nullptr, /* tp_repr */
|
||||||
|
nullptr, /* tp_as_number */
|
||||||
|
nullptr, /* tp_as_sequence */
|
||||||
|
nullptr, /* tp_as_mapping */
|
||||||
|
nullptr, /* tp_hash */
|
||||||
|
nullptr, /* tp_call */
|
||||||
|
nullptr, /* tp_str */
|
||||||
|
nullptr, /* tp_getattro */
|
||||||
|
nullptr, /* tp_setattro */
|
||||||
|
nullptr, /* tp_as_buffer */
|
||||||
|
Py_TPFLAGS_DEFAULT, /* tp_flags */
|
||||||
|
nullptr, /* tp_doc */
|
||||||
|
nullptr, /* tp_traverse */
|
||||||
|
nullptr, /* tp_clear */
|
||||||
|
nullptr, /* tp_richcompare */
|
||||||
|
0, /* tp_weaklistoffset */
|
||||||
|
nullptr, /* tp_iter */
|
||||||
|
nullptr, /* tp_iternext */
|
||||||
|
EagerTensor_methods, /* tp_methods */
|
||||||
|
nullptr, /* tp_members */
|
||||||
|
EagerTensor_getseters, /* tp_getset */
|
||||||
|
nullptr, /* tp_base */
|
||||||
|
nullptr, /* tp_dict */
|
||||||
|
nullptr, /* tp_descr_get */
|
||||||
|
nullptr, /* tp_descr_set */
|
||||||
|
0, /* tp_dictoffset */
|
||||||
|
(initproc)EagerTensor_init, /* tp_init */
|
||||||
|
nullptr, /* tp_alloc */
|
||||||
|
nullptr, /* tp_new */
|
||||||
|
};
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
|
} // extern "C"
|
||||||
|
|
||||||
|
bool EagerTensor_CheckExact(const PyObject* o) {
|
||||||
|
return Py_TYPE(o) == EagerTensorType;
|
||||||
|
}
|
||||||
|
|
||||||
|
TFE_TensorHandle* EagerTensorHandle(const PyObject* o) {
|
||||||
|
return reinterpret_cast<const EagerTensor*>(o)->handle;
|
||||||
|
}
|
||||||
|
|
||||||
|
PyObject* EagerTensorFromHandle(TFE_TensorHandle* handle) {
|
||||||
|
if (handle == nullptr) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
EagerTensor* t = reinterpret_cast<EagerTensor*>(
|
||||||
|
EagerTensorType->tp_new(EagerTensorType, Py_None, Py_None));
|
||||||
|
if (t != nullptr) {
|
||||||
|
t->id = get_uid();
|
||||||
|
Py_INCREF(Py_None);
|
||||||
|
t->handle_data = Py_None;
|
||||||
|
Py_INCREF(Py_None);
|
||||||
|
t->keras_mask = Py_None;
|
||||||
|
t->handle = handle;
|
||||||
|
}
|
||||||
|
return reinterpret_cast<PyObject*>(t);
|
||||||
|
}
|
||||||
|
|
||||||
|
PyObject* TFE_Py_InitEagerTensor(PyObject* base_class) {
|
||||||
|
if (!PyType_Check(base_class)) {
|
||||||
|
PyErr_SetString(
|
||||||
|
PyExc_TypeError,
|
||||||
|
tensorflow::strings::StrCat(
|
||||||
|
"Expecting a class definition for `base_class` passed to ",
|
||||||
|
"TFE_InitEagerTensor. Got ", Py_TYPE(base_class)->tp_name)
|
||||||
|
.c_str());
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
// Note that we allocated kMaxEagerTensorParentSize bytes of unused space in
|
||||||
|
// EagerTensor to allow for the space usage of the base class.
|
||||||
|
PyTypeObject* base_class_type = reinterpret_cast<PyTypeObject*>(base_class);
|
||||||
|
if (base_class_type->tp_basicsize > kMaxEagerTensorParentSize) {
|
||||||
|
PyErr_SetString(
|
||||||
|
PyExc_TypeError,
|
||||||
|
tensorflow::strings::StrCat(
|
||||||
|
"Unable to create subclass EagerTensor from base class ",
|
||||||
|
Py_TYPE(base_class)->tp_name,
|
||||||
|
". Need its size to be <= ", kMaxEagerTensorParentSize)
|
||||||
|
.c_str());
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
if (base_class_type->tp_itemsize != 0) {
|
||||||
|
PyErr_SetString(
|
||||||
|
PyExc_TypeError,
|
||||||
|
tensorflow::strings::StrCat(
|
||||||
|
"Unable to create subclass EagerTensor from base class ",
|
||||||
|
Py_TYPE(base_class)->tp_name,
|
||||||
|
" which supports variable length instances.")
|
||||||
|
.c_str());
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
Py_INCREF(base_class);
|
||||||
|
#if PY_MAJOR_VERSION >= 3
|
||||||
|
PyObject* bases = PyTuple_New(1);
|
||||||
|
PyTuple_SET_ITEM(bases, 0, base_class);
|
||||||
|
EagerTensorType = reinterpret_cast<PyTypeObject*>(
|
||||||
|
PyType_FromSpecWithBases(&EagerTensor_Type_spec, bases));
|
||||||
|
if (PyErr_Occurred()) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
if (EagerTensorType == nullptr) {
|
||||||
|
PyErr_SetString(PyExc_RuntimeError, "Error while creating EagerTensorType");
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
#else
|
||||||
|
_EagerTensorType.tp_base = reinterpret_cast<PyTypeObject*>(base_class);
|
||||||
|
|
||||||
|
if (PyType_Ready(&_EagerTensorType) < 0) {
|
||||||
|
if (PyErr_Occurred()) return nullptr;
|
||||||
|
PyErr_SetString(PyExc_RuntimeError,
|
||||||
|
"Error while creating EagerTensor type.");
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
EagerTensorType = &_EagerTensorType;
|
||||||
|
Py_INCREF(EagerTensorType);
|
||||||
|
#endif
|
||||||
|
// We disable instance based attribute lookup. Its not clear if these
|
||||||
|
// dictionaries are correctly initialized in the first place.
|
||||||
|
EagerTensorType->tp_dictoffset = 0;
|
||||||
|
return reinterpret_cast<PyObject*>(EagerTensorType);
|
||||||
|
}
|
@ -17,6 +17,7 @@ limitations under the License.
|
|||||||
#define TENSORFLOW_PYTHON_EAGER_PYWRAP_TFE_H_
|
#define TENSORFLOW_PYTHON_EAGER_PYWRAP_TFE_H_
|
||||||
|
|
||||||
#include "tensorflow/c/eager/c_api.h"
|
#include "tensorflow/c/eager/c_api.h"
|
||||||
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
#include "tensorflow/core/lib/gtl/inlined_vector.h"
|
#include "tensorflow/core/lib/gtl/inlined_vector.h"
|
||||||
#include <Python.h>
|
#include <Python.h>
|
||||||
|
|
||||||
@ -44,38 +45,46 @@ void TFE_Py_Execute(TFE_Context* ctx, const char* device_name,
|
|||||||
PyObject* attrs, TFE_OutputTensorHandles* outputs,
|
PyObject* attrs, TFE_OutputTensorHandles* outputs,
|
||||||
TF_Status* out_status);
|
TF_Status* out_status);
|
||||||
|
|
||||||
// Convert a TFE_TensorHandle to a Python numpy.ndarray object.
|
|
||||||
//
|
|
||||||
// The two may share underlying storage so changes to one may reflect in the
|
|
||||||
// other.
|
|
||||||
PyObject* TFE_Py_TensorHandleToNumpy(TFE_TensorHandle* h, TF_Status* status);
|
|
||||||
|
|
||||||
// Convert a Python numpy.ndarray object to a TFE_TensorHandle.
|
|
||||||
//
|
|
||||||
// The two may share underlying storage so changes to one may reflect in the
|
|
||||||
// other.
|
|
||||||
TFE_TensorHandle* TFE_Py_NumpyToTensorHandle(PyObject* obj);
|
|
||||||
|
|
||||||
// Convert a Python sequence value to a TFE_TensorHandle.
|
|
||||||
//
|
|
||||||
// The dtype of the result is determined by the type of values found
|
|
||||||
// in *obj, *dtype is the desired type but it is only considered a
|
|
||||||
// hint. *dtype should be an integer representing the desired DataType
|
|
||||||
// enum value, or Py_None. Unlike TFE_Py_NumpyToTensorHandle, this
|
|
||||||
// always makes a copy. Returns nullptr and raises an exception on
|
|
||||||
// error.
|
|
||||||
// TODO(josh11b): Cast to dtype automatically.
|
|
||||||
TFE_TensorHandle* TFE_Py_SequenceToTensorHandle(PyObject* obj, PyObject* dtype);
|
|
||||||
|
|
||||||
// Registers e as the Exception class for handling not ok Status. Returns
|
// Registers e as the Exception class for handling not ok Status. Returns
|
||||||
// Py_None if registration succeeds, else throws a TypeError and returns NULL.
|
// Py_None if registration succeeds, else throws a TypeError and returns NULL.
|
||||||
PyObject* TFE_Py_RegisterExceptionClass(PyObject* e);
|
PyObject* TFE_Py_RegisterExceptionClass(PyObject* e);
|
||||||
|
|
||||||
// Returns 0 if 'status' is TF_OK. Otherwise, raises an exception (using the
|
// Returns 0 if 'status' is TF_OK. Otherwise, raises an exception (using
|
||||||
// class registered via TFE_Py_RegisterExceptionClass) and returns -1.
|
// `exception` if not nullptr, else using the class registered via
|
||||||
int TFE_Py_MaybeRaiseException(TF_Status* status);
|
// TFE_Py_RegisterExceptionClass), and returns -1.
|
||||||
|
int MaybeRaiseExceptionFromTFStatus(TF_Status* status, PyObject* exception);
|
||||||
|
|
||||||
|
// Returns 0 if 'status' is ok. Otherwise, raises an exception (using
|
||||||
|
// `exception` if not nullptr, else using the class registered via
|
||||||
|
// TFE_Py_RegisterExceptionClass), and returns -1.
|
||||||
|
int MaybeRaiseExceptionFromStatus(const tensorflow::Status& status,
|
||||||
|
PyObject* exception);
|
||||||
|
|
||||||
// Returns the string associated with the passed-in python object.
|
// Returns the string associated with the passed-in python object.
|
||||||
char* TFE_GetPythonString(PyObject* o);
|
char* TFE_GetPythonString(PyObject* o);
|
||||||
|
|
||||||
|
// Returns a unique id on each call.
|
||||||
|
int64_t get_uid();
|
||||||
|
|
||||||
|
// Wraps the output of get_uid as a Python Long object. Ownership is passed to
|
||||||
|
// the caller.
|
||||||
|
PyObject* TFE_Py_UID();
|
||||||
|
|
||||||
|
// Deleter for Context objects, called from the Capsule that owns it.
|
||||||
|
void TFE_DeleteContextCapsule(PyObject* context);
|
||||||
|
|
||||||
|
// Returns true if o is an instance of EagerTensor, but not a subclass. Else
|
||||||
|
// returns false.
|
||||||
|
bool EagerTensor_CheckExact(const PyObject* o);
|
||||||
|
|
||||||
|
// Helper function to construct a new EagerTensor from a TFE_TensorHandle.
|
||||||
|
PyObject* EagerTensorFromHandle(TFE_TensorHandle* handle);
|
||||||
|
|
||||||
|
// Extracts the handle inside EagerTensor object `o`. Returns nullptr on error.
|
||||||
|
TFE_TensorHandle* EagerTensorHandle(const PyObject* o);
|
||||||
|
|
||||||
|
// Creates the `EagerTensor` class by subclassing `base_class` and returns the
|
||||||
|
// newly created type, or nullptr on error.
|
||||||
|
PyObject* TFE_Py_InitEagerTensor(PyObject* base_class);
|
||||||
|
|
||||||
#endif // TENSORFLOW_PYTHON_EAGER_PYWRAP_TFE_H_
|
#endif // TENSORFLOW_PYTHON_EAGER_PYWRAP_TFE_H_
|
||||||
|
@ -13,16 +13,12 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
// Must be included first.
|
|
||||||
#include "tensorflow/python/lib/core/numpy.h"
|
|
||||||
|
|
||||||
#include "tensorflow/python/eager/pywrap_tfe.h"
|
#include "tensorflow/python/eager/pywrap_tfe.h"
|
||||||
|
|
||||||
#include "tensorflow/c/c_api.h"
|
#include "tensorflow/c/c_api.h"
|
||||||
#include "tensorflow/core/lib/strings/strcat.h"
|
#include "tensorflow/core/lib/strings/strcat.h"
|
||||||
#include "tensorflow/core/platform/mutex.h"
|
#include "tensorflow/core/platform/mutex.h"
|
||||||
#include "tensorflow/python/lib/core/ndarray_tensor.h"
|
#include "tensorflow/core/platform/types.h"
|
||||||
#include "tensorflow/python/lib/core/py_seq_tensor.h"
|
|
||||||
|
|
||||||
using tensorflow::string;
|
using tensorflow::string;
|
||||||
|
|
||||||
@ -320,6 +316,14 @@ void SetOpAttrs(TFE_Context* ctx, TFE_Op* op, PyObject* attrs,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Python subclass of Exception that is created on not ok Status.
|
||||||
|
tensorflow::mutex exception_class_mutex(tensorflow::LINKER_INITIALIZED);
|
||||||
|
PyObject* exception_class GUARDED_BY(exception_class_mutex) = nullptr;
|
||||||
|
|
||||||
|
static tensorflow::mutex _uid_mutex(tensorflow::LINKER_INITIALIZED);
|
||||||
|
static tensorflow::int64 _uid GUARDED_BY(_uid_mutex) = 0;
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
void TFE_Py_Execute(TFE_Context* ctx, const char* device_name,
|
void TFE_Py_Execute(TFE_Context* ctx, const char* device_name,
|
||||||
@ -352,65 +356,6 @@ void TFE_Py_Execute(TFE_Context* ctx, const char* device_name,
|
|||||||
TFE_DeleteOp(op);
|
TFE_DeleteOp(op);
|
||||||
}
|
}
|
||||||
|
|
||||||
PyObject* TFE_Py_TensorHandleToNumpy(TFE_TensorHandle* h, TF_Status* status) {
|
|
||||||
const tensorflow::Tensor* t =
|
|
||||||
TFE_TensorHandleUnderlyingTensorInHostMemory(h, status);
|
|
||||||
if (TF_GetCode(status) != TF_OK) {
|
|
||||||
Py_RETURN_NONE;
|
|
||||||
}
|
|
||||||
PyObject* ret = nullptr;
|
|
||||||
auto cppstatus = tensorflow::TensorToNdarray(*t, &ret);
|
|
||||||
if (!cppstatus.ok()) {
|
|
||||||
TF_SetStatus(status, TF_Code(cppstatus.code()),
|
|
||||||
cppstatus.error_message().c_str());
|
|
||||||
}
|
|
||||||
if (ret != nullptr) return ret;
|
|
||||||
Py_RETURN_NONE;
|
|
||||||
}
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
// Python subclass of Exception that is created on not ok Status.
|
|
||||||
tensorflow::mutex exception_class_mutex(tensorflow::LINKER_INITIALIZED);
|
|
||||||
PyObject* exception_class GUARDED_BY(exception_class_mutex) = nullptr;
|
|
||||||
|
|
||||||
void PyRaiseException(TF_Code error_code, const char* msg) {
|
|
||||||
tensorflow::mutex_lock l(exception_class_mutex);
|
|
||||||
if (exception_class != nullptr) {
|
|
||||||
PyErr_SetObject(exception_class, Py_BuildValue("si", msg, error_code));
|
|
||||||
} else {
|
|
||||||
PyErr_SetString(PyExc_RuntimeError, msg);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
TFE_TensorHandle* TFE_Py_NumpyToTensorHandle(PyObject* obj) {
|
|
||||||
tensorflow::Tensor t;
|
|
||||||
auto cppstatus = tensorflow::NdarrayToTensor(obj, &t);
|
|
||||||
if (cppstatus.ok()) {
|
|
||||||
return TFE_NewTensorHandle(t);
|
|
||||||
} else {
|
|
||||||
PyRaiseException(TF_INVALID_ARGUMENT,
|
|
||||||
tensorflow::strings::StrCat(
|
|
||||||
"failed to convert numpy ndarray to a Tensor (",
|
|
||||||
cppstatus.error_message(), ")")
|
|
||||||
.c_str());
|
|
||||||
}
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
TFE_TensorHandle* TFE_Py_SequenceToTensorHandle(PyObject* obj,
|
|
||||||
PyObject* dtype) {
|
|
||||||
tensorflow::Tensor t;
|
|
||||||
auto cppstatus = tensorflow::PySeqToTensor(obj, dtype, &t);
|
|
||||||
if (cppstatus.ok()) {
|
|
||||||
return TFE_NewTensorHandle(t);
|
|
||||||
} else {
|
|
||||||
PyRaiseException(TF_INVALID_ARGUMENT, cppstatus.error_message().c_str());
|
|
||||||
}
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
PyObject* TFE_Py_RegisterExceptionClass(PyObject* e) {
|
PyObject* TFE_Py_RegisterExceptionClass(PyObject* e) {
|
||||||
tensorflow::mutex_lock l(exception_class_mutex);
|
tensorflow::mutex_lock l(exception_class_mutex);
|
||||||
if (exception_class != nullptr) {
|
if (exception_class != nullptr) {
|
||||||
@ -429,9 +374,39 @@ PyObject* TFE_Py_RegisterExceptionClass(PyObject* e) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
int TFE_Py_MaybeRaiseException(TF_Status* status) {
|
int MaybeRaiseExceptionFromTFStatus(TF_Status* status, PyObject* exception) {
|
||||||
if (TF_GetCode(status) == TF_OK) return 0;
|
if (TF_GetCode(status) == TF_OK) return 0;
|
||||||
PyRaiseException(TF_GetCode(status), TF_Message(status));
|
const char* msg = TF_Message(status);
|
||||||
|
if (exception == nullptr) {
|
||||||
|
tensorflow::mutex_lock l(exception_class_mutex);
|
||||||
|
if (exception_class != nullptr) {
|
||||||
|
PyErr_SetObject(exception_class,
|
||||||
|
Py_BuildValue("si", msg, TF_GetCode(status)));
|
||||||
|
return -1;
|
||||||
|
} else {
|
||||||
|
exception = PyExc_RuntimeError;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// May be update already set exception.
|
||||||
|
PyErr_SetString(exception, msg);
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
int MaybeRaiseExceptionFromStatus(const tensorflow::Status& status,
|
||||||
|
PyObject* exception) {
|
||||||
|
if (status.ok()) return 0;
|
||||||
|
const char* msg = status.error_message().c_str();
|
||||||
|
if (exception == nullptr) {
|
||||||
|
tensorflow::mutex_lock l(exception_class_mutex);
|
||||||
|
if (exception_class != nullptr) {
|
||||||
|
PyErr_SetObject(exception_class, Py_BuildValue("si", msg, status.code()));
|
||||||
|
return -1;
|
||||||
|
} else {
|
||||||
|
exception = PyExc_RuntimeError;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// May be update already set exception.
|
||||||
|
PyErr_SetString(exception, msg);
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -446,3 +421,18 @@ char* TFE_GetPythonString(PyObject* o) {
|
|||||||
#endif
|
#endif
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int64_t get_uid() {
|
||||||
|
tensorflow::mutex_lock l(_uid_mutex);
|
||||||
|
return _uid++;
|
||||||
|
}
|
||||||
|
|
||||||
|
PyObject* TFE_Py_UID() { return PyLong_FromLongLong(get_uid()); }
|
||||||
|
|
||||||
|
void TFE_DeleteContextCapsule(PyObject* context) {
|
||||||
|
TF_Status* status = TF_NewStatus();
|
||||||
|
TFE_Context* ctx =
|
||||||
|
reinterpret_cast<TFE_Context*>(PyCapsule_GetPointer(context, nullptr));
|
||||||
|
TFE_DeleteContext(ctx, status);
|
||||||
|
TF_DeleteStatus(status);
|
||||||
|
}
|
||||||
|
@ -135,9 +135,9 @@ class Tape(object):
|
|||||||
# adding an explicit stack if this ever gets out of hand
|
# adding an explicit stack if this ever gets out of hand
|
||||||
self._delete_tensor_id(tensor_id)
|
self._delete_tensor_id(tensor_id)
|
||||||
|
|
||||||
def delete_trace(self, tensor):
|
def delete_trace(self, tensor_id):
|
||||||
"""Deletes any trace we have for this tensor."""
|
"""Deletes any trace we have for this tensor."""
|
||||||
self._delete_tensor_id(tid(tensor))
|
self._delete_tensor_id(tensor_id)
|
||||||
|
|
||||||
def export(self):
|
def export(self):
|
||||||
"""Exports the internal state of this tape.
|
"""Exports the internal state of this tape.
|
||||||
@ -237,10 +237,10 @@ def record_operation(op_type, output_tensors, input_tensors, side_outputs,
|
|||||||
backward_function)
|
backward_function)
|
||||||
|
|
||||||
|
|
||||||
def delete_trace(tensor):
|
def delete_trace(tensor_id):
|
||||||
"""Deletes traces for this Tensor from all tapes in the stack."""
|
"""Deletes traces for this Tensor from all tapes in the stack."""
|
||||||
for t in _tape_stack.stack:
|
for t in _tape_stack.stack:
|
||||||
t.delete_trace(tensor)
|
t.delete_trace(tensor_id)
|
||||||
|
|
||||||
|
|
||||||
def top_tape_watched_tensors():
|
def top_tape_watched_tensors():
|
||||||
|
@ -21,26 +21,90 @@ from __future__ import print_function
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
|
from tensorflow.python.eager import core
|
||||||
from tensorflow.python.eager import test
|
from tensorflow.python.eager import test
|
||||||
from tensorflow.python.framework import constant_op
|
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import errors
|
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import test_util
|
from tensorflow.python.framework import test_util
|
||||||
|
|
||||||
|
|
||||||
|
def _create_tensor(value, device=None, dtype=None):
|
||||||
|
ctx = context.context()
|
||||||
|
if device is None:
|
||||||
|
device = ctx.device_name
|
||||||
|
if dtype is not None:
|
||||||
|
dtype = dtype.as_datatype_enum
|
||||||
|
try:
|
||||||
|
return ops.EagerTensor(
|
||||||
|
value, context=ctx._handle, device=device, dtype=dtype)
|
||||||
|
except core._NotOkStatusException as e: # pylint: disable=protected-access
|
||||||
|
raise core._status_to_exception(e.code, e.message)
|
||||||
|
|
||||||
|
|
||||||
class TFETensorTest(test_util.TensorFlowTestCase):
|
class TFETensorTest(test_util.TensorFlowTestCase):
|
||||||
|
|
||||||
def testScalarTensor(self):
|
def testScalarTensor(self):
|
||||||
t = constant_op.constant(3)
|
t = _create_tensor(3, dtype=dtypes.int32)
|
||||||
self.assertEqual(t.numpy(), constant_op.constant(np.array(3)).numpy())
|
self.assertEqual(t.numpy(), _create_tensor(np.array(3)).numpy())
|
||||||
self.assertEqual(dtypes.int32, t.dtype)
|
self.assertEqual(dtypes.int32, t.dtype)
|
||||||
self.assertEqual(0, t.shape.ndims)
|
self.assertEqual(0, t.shape.ndims)
|
||||||
self.assertAllEqual([], t.shape.as_list())
|
self.assertAllEqual([], t.shape.as_list())
|
||||||
|
self.assertIn("tf.Tensor", str(t))
|
||||||
|
self.assertIn("tf.Tensor", repr(t))
|
||||||
|
|
||||||
|
def testBadConstructorArgs(self):
|
||||||
|
ctx = context.context()
|
||||||
|
handle = ctx._handle
|
||||||
|
device = ctx.device_name
|
||||||
|
# Missing context.
|
||||||
|
with self.assertRaisesRegexp(
|
||||||
|
TypeError, r"Required argument 'context' \(pos 2\) not found"):
|
||||||
|
ops.EagerTensor(1, device=device)
|
||||||
|
# Missing device.
|
||||||
|
with self.assertRaisesRegexp(
|
||||||
|
TypeError, r"Required argument 'device' \(pos 3\) not found"):
|
||||||
|
ops.EagerTensor(1, context=handle)
|
||||||
|
# Bad dtype type.
|
||||||
|
with self.assertRaisesRegexp(TypeError,
|
||||||
|
"Expecting a DataType value for dtype. Got"):
|
||||||
|
ops.EagerTensor(1, context=handle, device=device, dtype="1")
|
||||||
|
# Following errors happen when trying to copy to GPU.
|
||||||
|
if not context.context().num_gpus():
|
||||||
|
self.skipTest("No GPUs found")
|
||||||
|
with ops.device("/device:GPU:0"):
|
||||||
|
device = ctx.device_name
|
||||||
|
# Bad context.
|
||||||
|
with self.assertRaisesRegexp(
|
||||||
|
TypeError, "Expecting a PyCapsule encoded context handle. Got"):
|
||||||
|
ops.EagerTensor(1.0, context=1, device=device)
|
||||||
|
# Bad device.
|
||||||
|
with self.assertRaisesRegexp(
|
||||||
|
TypeError, "Error parsing device argument to CopyToDevice"):
|
||||||
|
ops.EagerTensor(1.0, context=handle, device=1)
|
||||||
|
|
||||||
|
def testNumpyValue(self):
|
||||||
|
values = np.array([3.0])
|
||||||
|
t = _create_tensor(values)
|
||||||
|
self.assertAllEqual(values, t.numpy())
|
||||||
|
|
||||||
|
def testNumpyValueWithCast(self):
|
||||||
|
values = np.array([3.0], dtype=np.float32)
|
||||||
|
t = _create_tensor(values, dtype=dtypes.float64)
|
||||||
|
self.assertAllEqual(values, t.numpy())
|
||||||
|
ctx = context.context()
|
||||||
|
# Bad dtype value.
|
||||||
|
with self.assertRaisesRegexp(TypeError, "Invalid dtype argument value"):
|
||||||
|
ops.EagerTensor(
|
||||||
|
values, context=ctx._handle, device=ctx.device_name, dtype=12345)
|
||||||
|
|
||||||
|
def testNumpyOrderHandling(self):
|
||||||
|
n = np.array([[1, 2], [3, 4]], order="F")
|
||||||
|
t = _create_tensor(n)
|
||||||
|
self.assertAllEqual([[1, 2], [3, 4]], t.numpy())
|
||||||
|
|
||||||
def testTensorAndNumpyMatrix(self):
|
def testTensorAndNumpyMatrix(self):
|
||||||
expected = np.array([[1.0, 2.0], [3.0, 4.0]], np.float32)
|
expected = np.array([[1.0, 2.0], [3.0, 4.0]], np.float32)
|
||||||
actual = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
|
actual = _create_tensor([[1.0, 2.0], [3.0, 4.0]])
|
||||||
self.assertAllEqual(expected, actual.numpy())
|
self.assertAllEqual(expected, actual.numpy())
|
||||||
self.assertEqual(np.float32, actual.numpy().dtype)
|
self.assertEqual(np.float32, actual.numpy().dtype)
|
||||||
self.assertEqual(dtypes.float32, actual.dtype)
|
self.assertEqual(dtypes.float32, actual.dtype)
|
||||||
@ -48,56 +112,50 @@ class TFETensorTest(test_util.TensorFlowTestCase):
|
|||||||
|
|
||||||
def testFloatDowncast(self):
|
def testFloatDowncast(self):
|
||||||
# Unless explicitly specified, float64->float32
|
# Unless explicitly specified, float64->float32
|
||||||
t = constant_op.constant(3.0)
|
t = _create_tensor(3.0)
|
||||||
self.assertEqual(dtypes.float32, t.dtype)
|
self.assertEqual(dtypes.float32, t.dtype)
|
||||||
t = constant_op.constant(3.0, dtype=dtypes.float64)
|
t = _create_tensor(3.0, dtype=dtypes.float64)
|
||||||
self.assertEqual(dtypes.float64, t.dtype)
|
self.assertEqual(dtypes.float64, t.dtype)
|
||||||
|
|
||||||
def testBool(self):
|
def testBool(self):
|
||||||
t = constant_op.constant(False)
|
t = _create_tensor(False)
|
||||||
if t:
|
if t:
|
||||||
self.assertFalse(True)
|
self.assertFalse(True)
|
||||||
|
|
||||||
def testIntDowncast(self):
|
def testIntDowncast(self):
|
||||||
t = constant_op.constant(3)
|
t = _create_tensor(3)
|
||||||
self.assertEqual(dtypes.int32, t.dtype)
|
self.assertEqual(dtypes.int32, t.dtype)
|
||||||
t = constant_op.constant(3, dtype=dtypes.int64)
|
t = _create_tensor(3, dtype=dtypes.int64)
|
||||||
self.assertEqual(dtypes.int64, t.dtype)
|
self.assertEqual(dtypes.int64, t.dtype)
|
||||||
t = constant_op.constant(2**33)
|
t = _create_tensor(2**33)
|
||||||
self.assertEqual(dtypes.int64, t.dtype)
|
self.assertEqual(dtypes.int64, t.dtype)
|
||||||
|
|
||||||
def testTensorCreationFailure(self):
|
def testTensorCreationFailure(self):
|
||||||
with self.assertRaises(Exception):
|
with self.assertRaises(ValueError):
|
||||||
# Should fail because the each row of the Python object has a different
|
# Should fail because the each row of the Python object has a different
|
||||||
# number of columns.
|
# number of columns.
|
||||||
self.assertEqual(None, constant_op.constant([[1], [1, 2]]))
|
self.assertEqual(None, _create_tensor([[1], [1, 2]]))
|
||||||
|
|
||||||
def testNumpyOrderHandling(self):
|
|
||||||
n = np.array([[1, 2], [3, 4]], order="F")
|
|
||||||
t = constant_op.constant(n)
|
|
||||||
self.assertAllEqual([[1, 2], [3, 4]], t.numpy())
|
|
||||||
|
|
||||||
def testMultiLineTensorStr(self):
|
def testMultiLineTensorStr(self):
|
||||||
t = constant_op.constant(np.eye(3))
|
t = _create_tensor(np.eye(3))
|
||||||
tensor_str = str(t)
|
tensor_str = str(t)
|
||||||
self.assertIn("shape=%s, dtype=%s" % (t.shape, t.dtype.name), tensor_str)
|
self.assertIn("shape=%s, dtype=%s" % (t.shape, t.dtype.name), tensor_str)
|
||||||
self.assertIn(str(t.numpy()), tensor_str)
|
self.assertIn(str(t.numpy()), tensor_str)
|
||||||
|
|
||||||
def testMultiLineTensorRepr(self):
|
def testMultiLineTensorRepr(self):
|
||||||
t = constant_op.constant(np.eye(3))
|
t = _create_tensor(np.eye(3))
|
||||||
tensor_repr = repr(t)
|
tensor_repr = repr(t)
|
||||||
self.assertTrue(tensor_repr.startswith("<"))
|
self.assertTrue(tensor_repr.startswith("<"))
|
||||||
self.assertTrue(tensor_repr.endswith(">"))
|
self.assertTrue(tensor_repr.endswith(">"))
|
||||||
self.assertIn(
|
self.assertIn("id=%d, shape=%s, dtype=%s, numpy=\n%r" %
|
||||||
"id=%d, shape=%s, dtype=%s, numpy=\n%r" % (
|
(t._id, t.shape, t.dtype.name, t.numpy()), tensor_repr)
|
||||||
t._id, t.shape, t.dtype.name, t.numpy()), tensor_repr)
|
|
||||||
|
|
||||||
def testTensorStrReprObeyNumpyPrintOptions(self):
|
def testTensorStrReprObeyNumpyPrintOptions(self):
|
||||||
orig_threshold = np.get_printoptions()["threshold"]
|
orig_threshold = np.get_printoptions()["threshold"]
|
||||||
orig_edgeitems = np.get_printoptions()["edgeitems"]
|
orig_edgeitems = np.get_printoptions()["edgeitems"]
|
||||||
np.set_printoptions(threshold=2, edgeitems=1)
|
np.set_printoptions(threshold=2, edgeitems=1)
|
||||||
|
|
||||||
t = constant_op.constant(np.arange(10, dtype=np.int32))
|
t = _create_tensor(np.arange(10, dtype=np.int32))
|
||||||
self.assertIn("[0 ..., 9]", str(t))
|
self.assertIn("[0 ..., 9]", str(t))
|
||||||
self.assertIn("[0, ..., 9]", repr(t))
|
self.assertIn("[0, ..., 9]", repr(t))
|
||||||
|
|
||||||
@ -105,30 +163,30 @@ class TFETensorTest(test_util.TensorFlowTestCase):
|
|||||||
np.set_printoptions(threshold=orig_threshold, edgeitems=orig_edgeitems)
|
np.set_printoptions(threshold=orig_threshold, edgeitems=orig_edgeitems)
|
||||||
|
|
||||||
def testZeroDimTensorStr(self):
|
def testZeroDimTensorStr(self):
|
||||||
t = constant_op.constant(42)
|
t = _create_tensor(42)
|
||||||
self.assertIn("42, shape=(), dtype=int32", str(t))
|
self.assertIn("42, shape=(), dtype=int32", str(t))
|
||||||
|
|
||||||
def testZeroDimTensorRepr(self):
|
def testZeroDimTensorRepr(self):
|
||||||
t = constant_op.constant(42)
|
t = _create_tensor(42)
|
||||||
self.assertTrue(repr(t).startswith("<"))
|
self.assertTrue(repr(t).startswith("<"))
|
||||||
self.assertTrue(repr(t).endswith(">"))
|
self.assertTrue(repr(t).endswith(">"))
|
||||||
self.assertIn("id=%d, shape=(), dtype=int32, numpy=42" % t._id, repr(t))
|
self.assertIn("id=%d, shape=(), dtype=int32, numpy=42" % t._id, repr(t))
|
||||||
|
|
||||||
def testZeroSizeTensorStr(self):
|
def testZeroSizeTensorStr(self):
|
||||||
t = constant_op.constant(np.zeros(0, dtype=np.float32))
|
t = _create_tensor(np.zeros(0, dtype=np.float32))
|
||||||
self.assertIn("[], shape=(0,), dtype=float32", str(t))
|
self.assertIn("[], shape=(0,), dtype=float32", str(t))
|
||||||
|
|
||||||
def testZeroSizeTensorRepr(self):
|
def testZeroSizeTensorRepr(self):
|
||||||
t = constant_op.constant(np.zeros(0, dtype=np.float32))
|
t = _create_tensor(np.zeros(0, dtype=np.float32))
|
||||||
self.assertTrue(repr(t).startswith("<"))
|
self.assertTrue(repr(t).startswith("<"))
|
||||||
self.assertTrue(repr(t).endswith(">"))
|
self.assertTrue(repr(t).endswith(">"))
|
||||||
self.assertIn(
|
self.assertIn("id=%d, shape=(0,), dtype=float32, numpy=%r" % (t._id,
|
||||||
"id=%d, shape=(0,), dtype=float32, numpy=%r" % (t._id, t.numpy()),
|
t.numpy()),
|
||||||
repr(t))
|
repr(t))
|
||||||
|
|
||||||
def testStringTensor(self):
|
def testStringTensor(self):
|
||||||
t_np_orig = np.array([[b"a", b"ab"], [b"abc", b"abcd"]])
|
t_np_orig = np.array([[b"a", b"ab"], [b"abc", b"abcd"]])
|
||||||
t = constant_op.constant(t_np_orig)
|
t = _create_tensor(t_np_orig)
|
||||||
t_np = t.numpy()
|
t_np = t.numpy()
|
||||||
self.assertTrue(np.all(t_np == t_np_orig), "%s vs %s" % (t_np, t_np_orig))
|
self.assertTrue(np.all(t_np == t_np_orig), "%s vs %s" % (t_np, t_np_orig))
|
||||||
|
|
||||||
@ -137,9 +195,8 @@ class TFETensorTest(test_util.TensorFlowTestCase):
|
|||||||
self.skipTest("No GPUs found")
|
self.skipTest("No GPUs found")
|
||||||
with ops.device("/device:GPU:0"):
|
with ops.device("/device:GPU:0"):
|
||||||
with self.assertRaisesRegexp(
|
with self.assertRaisesRegexp(
|
||||||
errors.InvalidArgumentError,
|
RuntimeError, "Can't copy Tensor with type string to device"):
|
||||||
"Can't copy Tensor with type string to device"):
|
_create_tensor("test string")
|
||||||
constant_op.constant("test string")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -84,26 +84,46 @@ def _eager_identity(tensor, ctx):
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
def convert_to_eager_tensor(t, ctx, dtype=None):
|
def convert_to_eager_tensor(value, ctx, dtype=None):
|
||||||
"""Converts the given `value` to an `EagerTensor`."""
|
"""Converts the given `value` to an `EagerTensor`.
|
||||||
if isinstance(t, ops.EagerTensor):
|
|
||||||
if dtype is not None and t.dtype != dtype:
|
Note that this function could return cached copies of created constants for
|
||||||
raise TypeError("Expected tensor with type %r not %r" % (dtype, t.dtype))
|
performance reasons.
|
||||||
return t
|
|
||||||
if isinstance(t, (int, float)):
|
Args:
|
||||||
|
value: value to convert to EagerTensor.
|
||||||
|
ctx: value of context.context().
|
||||||
|
dtype: optional desired dtype of the converted EagerTensor.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
EagerTensor created from value.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TypeError: if `dtype` is not compatible with the type of t.
|
||||||
|
"""
|
||||||
|
if isinstance(value, ops.EagerTensor):
|
||||||
|
if dtype is not None and value.dtype != dtype:
|
||||||
|
raise TypeError("Expected tensor with type %r not %r" % (
|
||||||
|
dtype, value.dtype))
|
||||||
|
return value
|
||||||
|
if dtype is not None:
|
||||||
|
dtype = dtype.as_datatype_enum
|
||||||
|
device = ctx.device_name
|
||||||
|
handle = ctx._handle # pylint: disable=protected-access
|
||||||
|
if isinstance(value, (int, float)):
|
||||||
# Use a scalar cache. This will put each scalar of each type only once on
|
# Use a scalar cache. This will put each scalar of each type only once on
|
||||||
# each device. Scalars don't use much device memory but copying scalars can
|
# each device. Scalars don't use much device memory but copying scalars can
|
||||||
# trigger memcpys which are slow.
|
# trigger memcpys which are slow.
|
||||||
device = ctx.device_name
|
cache_key = device, value, dtype, type(value)
|
||||||
cache_key = device, t, dtype, type(t)
|
|
||||||
scalar_cache = ctx.scalar_cache()
|
scalar_cache = ctx.scalar_cache()
|
||||||
tensor = scalar_cache.get(cache_key, None)
|
tensor = scalar_cache.get(cache_key, None)
|
||||||
if tensor is not None:
|
if tensor is not None:
|
||||||
return tensor
|
return tensor
|
||||||
value = ops.EagerTensor(t, ctx, dtype=dtype)
|
t = ops.EagerTensor(value, context=handle, device=device, dtype=dtype)
|
||||||
scalar_cache[cache_key] = value
|
scalar_cache[cache_key] = t
|
||||||
return value
|
return t
|
||||||
return ops.EagerTensor(t, ctx, dtype=dtype)
|
else:
|
||||||
|
return ops.EagerTensor(value, context=handle, device=device, dtype=dtype)
|
||||||
|
|
||||||
|
|
||||||
def constant(value, dtype=None, shape=None, name="Const", verify_shape=False):
|
def constant(value, dtype=None, shape=None, name="Const", verify_shape=False):
|
||||||
@ -152,13 +172,13 @@ def constant(value, dtype=None, shape=None, name="Const", verify_shape=False):
|
|||||||
A Constant Tensor.
|
A Constant Tensor.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
TypeError if shape is incorrectly specified or unsupported.
|
TypeError: if shape is incorrectly specified or unsupported.
|
||||||
"""
|
"""
|
||||||
ctx = context.context()
|
ctx = context.context()
|
||||||
if not ctx.in_graph_mode():
|
if not ctx.in_graph_mode():
|
||||||
if shape is None:
|
|
||||||
return convert_to_eager_tensor(value, ctx, dtype)
|
|
||||||
t = convert_to_eager_tensor(value, ctx, dtype)
|
t = convert_to_eager_tensor(value, ctx, dtype)
|
||||||
|
if shape is None:
|
||||||
|
return t
|
||||||
shape = tensor_shape.as_shape(shape)
|
shape = tensor_shape.as_shape(shape)
|
||||||
if shape == t.shape:
|
if shape == t.shape:
|
||||||
return t
|
return t
|
||||||
|
@ -25,10 +25,9 @@ import re
|
|||||||
import sys
|
import sys
|
||||||
import threading
|
import threading
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
import six
|
import six
|
||||||
from six.moves import xrange # pylint: disable=redefined-builtin
|
from six.moves import xrange # pylint: disable=redefined-builtin
|
||||||
|
|
||||||
from tensorflow.core.framework import attr_value_pb2
|
from tensorflow.core.framework import attr_value_pb2
|
||||||
from tensorflow.core.framework import function_pb2
|
from tensorflow.core.framework import function_pb2
|
||||||
from tensorflow.core.framework import graph_pb2
|
from tensorflow.core.framework import graph_pb2
|
||||||
@ -75,10 +74,6 @@ def tensor_id(tensor):
|
|||||||
return tensor._id # pylint: disable=protected-access
|
return tensor._id # pylint: disable=protected-access
|
||||||
|
|
||||||
|
|
||||||
def _in_gpu_device(ctx):
|
|
||||||
return "GPU" == ctx.device_spec.device_type
|
|
||||||
|
|
||||||
|
|
||||||
@tf_contextlib.contextmanager
|
@tf_contextlib.contextmanager
|
||||||
def _null_contextmanager():
|
def _null_contextmanager():
|
||||||
yield
|
yield
|
||||||
@ -171,16 +166,9 @@ def register_dense_tensor_like_type(tensor_type):
|
|||||||
_TENSOR_LIKE_TYPES = tuple(list(_TENSOR_LIKE_TYPES) + [tensor_type])
|
_TENSOR_LIKE_TYPES = tuple(list(_TENSOR_LIKE_TYPES) + [tensor_type])
|
||||||
|
|
||||||
|
|
||||||
_uid_counter = 0
|
|
||||||
_uid_lock = threading.Lock()
|
|
||||||
|
|
||||||
|
|
||||||
def uid():
|
def uid():
|
||||||
"""A unique (within this program execution) integer."""
|
"""A unique (within this program execution) integer."""
|
||||||
with _uid_lock:
|
return c_api.TFE_Py_UID()
|
||||||
global _uid_counter
|
|
||||||
_uid_counter += 1
|
|
||||||
return _uid_counter
|
|
||||||
|
|
||||||
|
|
||||||
# NOTE(ebrevdo): Do not subclass this. If you do, I will break you on purpose.
|
# NOTE(ebrevdo): Do not subclass this. If you do, I will break you on purpose.
|
||||||
@ -584,127 +572,18 @@ class Tensor(_TensorLike):
|
|||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
def _eager_cast(tensor_handle, src_type_enum, dest_type_enum, ctx):
|
# TODO(agarwal): consider getting rid of this.
|
||||||
"""Cast tensor_handle from src_type_enum to dest_type_enum."""
|
class _EagerTensorBase(Tensor):
|
||||||
# pylint: disable=protected-access
|
"""Base class for EagerTensor."""
|
||||||
try:
|
|
||||||
out_handle, = c_api.TFE_Py_Execute(
|
|
||||||
ctx._handle, b"/job:localhost/replica:0/task:0/device:CPU:0", b"Cast",
|
|
||||||
[tensor_handle], (b"SrcT", src_type_enum, b"DstT", dest_type_enum), 1)
|
|
||||||
except core._NotOkStatusException as e:
|
|
||||||
six.raise_from(core._status_to_exception(e.code, e.message), None)
|
|
||||||
# pylint: enable=protected-access
|
|
||||||
# TODO(josh11b): Should we support tracing or post_execution_callbacks here?
|
|
||||||
return out_handle
|
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _delete_trace(tid):
|
||||||
|
"""Helper function to be called by __del__ of the subclass."""
|
||||||
|
tape.delete_trace(tid)
|
||||||
|
|
||||||
# TODO(agarwal): rename to TensorHandle.
|
@property
|
||||||
class EagerTensor(Tensor):
|
def dtype(self):
|
||||||
"""A TensorFlow Eager Tensor."""
|
return dtypes.as_dtype(self._datatype_enum())
|
||||||
|
|
||||||
def __init__(self, value, ctx, dtype=None): # pylint: disable=super-init-not-called
|
|
||||||
"""Creates a Tensor object from a Python object or numpy array.
|
|
||||||
|
|
||||||
May share storage with the numpy array, in which case changes to the numpy
|
|
||||||
object will reflect
|
|
||||||
in the Tensor.
|
|
||||||
|
|
||||||
Arguments:
|
|
||||||
value: A numpy.array or a Python object to create a Tensor for.
|
|
||||||
ctx: The value of context.context().
|
|
||||||
dtype: TensorFlow dtype for the returned Tensor. If None, one will be
|
|
||||||
automatically selected.
|
|
||||||
"""
|
|
||||||
# TODO(ashankar): Evaluate if we can and perhaps share code with
|
|
||||||
# tf.constant defined in
|
|
||||||
# https://www.tensorflow.org/code/tensorflow/python/framework/constant_op.py
|
|
||||||
self._id = uid()
|
|
||||||
# pylint: disable=protected-access
|
|
||||||
if isinstance(value, np.ndarray):
|
|
||||||
if dtype is not None:
|
|
||||||
npt = dtype.as_numpy_dtype
|
|
||||||
if npt != value.dtype:
|
|
||||||
value = value.astype(npt)
|
|
||||||
try:
|
|
||||||
value = np.asarray(value, order="C")
|
|
||||||
self._handle = c_api.TFE_Py_NumpyToTensorHandle(value)
|
|
||||||
except core._NotOkStatusException as e:
|
|
||||||
six.raise_from(core._status_to_exception(e.code, e.message), None)
|
|
||||||
dtype = dtypes.as_dtype(c_api.TFE_TensorHandleDataType(self._handle))
|
|
||||||
else:
|
|
||||||
dtype_enum = None if dtype is None else dtype.as_datatype_enum
|
|
||||||
try:
|
|
||||||
self._handle = c_api.TFE_Py_SequenceToTensorHandle(value, dtype_enum)
|
|
||||||
except core._NotOkStatusException as e:
|
|
||||||
six.raise_from(core._status_to_exception(e.code, e.message), None)
|
|
||||||
|
|
||||||
dtype_enum = c_api.TFE_TensorHandleDataType(self._handle)
|
|
||||||
dtype_actual = dtypes.as_dtype(dtype_enum)
|
|
||||||
if dtype is not None and dtype != dtype_actual:
|
|
||||||
self._handle = _eager_cast(self._handle, dtype_enum,
|
|
||||||
dtype.as_datatype_enum, ctx)
|
|
||||||
else:
|
|
||||||
dtype = dtype_actual
|
|
||||||
# pylint: enable=protected-access
|
|
||||||
|
|
||||||
# Almost all TensorFlow kernels for GPU devices keep int32 tensors in host
|
|
||||||
# memory. This change approximates the same behavior for eager execution -
|
|
||||||
# keeping int32 tensors in host memory.
|
|
||||||
#
|
|
||||||
# We do so to preclude the need for callers into such kernels from having to
|
|
||||||
# explicitly place the int32 tensors in host memory. For example, prior to
|
|
||||||
# this change one needed:
|
|
||||||
#
|
|
||||||
# with tf.device('/gpu:0'):
|
|
||||||
# ... # code here
|
|
||||||
# with tf.device('/cpu:0'):
|
|
||||||
# shape = tf.constant(...)
|
|
||||||
# y = tf.random_uniform(shape)
|
|
||||||
#
|
|
||||||
# Without the CPU device block tfe.ops.random_uniform would fail since the
|
|
||||||
# kernel expects the shape in host memory.
|
|
||||||
#
|
|
||||||
# After this change, we simplify the code:
|
|
||||||
#
|
|
||||||
# with tf.device('/gpu:0'):
|
|
||||||
# y = tf.random_uniform(...)
|
|
||||||
#
|
|
||||||
# The approximation is not exact there are GPU kernels which do not
|
|
||||||
# require host memory for int32 tensors. This will lead to a discrepancy
|
|
||||||
# between eager and graph execution.
|
|
||||||
# TODO(ashankar): Fix this.
|
|
||||||
if _in_gpu_device(ctx) and dtype != dtypes.int32:
|
|
||||||
# pylint: disable=protected-access
|
|
||||||
device_name = ctx.device_name
|
|
||||||
with errors.raise_exception_on_not_ok_status() as status:
|
|
||||||
self._handle = c_api.TFE_TensorHandleCopyToDevice(
|
|
||||||
self._handle, ctx._handle, device_name, status)
|
|
||||||
# pylint: enable=protected-access
|
|
||||||
|
|
||||||
self._dtype = dtype
|
|
||||||
|
|
||||||
# This mirrors tensorflow.core.framework.ops.Tensor._handle_data Which will
|
|
||||||
# be None for tensors of type other than DT_REOSURCE. For DT_RESOURCE
|
|
||||||
# tensors, this will contain a serialized HandleData proto with shape
|
|
||||||
# inference metadata about shapes and dtypes of resources accessible from
|
|
||||||
# this handle.
|
|
||||||
self._handle_data = None
|
|
||||||
if core.active_trace() is not None:
|
|
||||||
core.active_trace().record_tensor("MANUAL",
|
|
||||||
tensor_id(self), self.device,
|
|
||||||
self.shape.num_elements())
|
|
||||||
|
|
||||||
def __del__(self):
|
|
||||||
try:
|
|
||||||
tape.delete_trace(self)
|
|
||||||
if c_api is not None and c_api.TFE_DeleteTensorHandle is not None:
|
|
||||||
c_api.TFE_DeleteTensorHandle(self._handle)
|
|
||||||
if core.active_trace() is not None:
|
|
||||||
core.active_trace().delete_tensor(tensor_id(self))
|
|
||||||
except (AttributeError, TypeError):
|
|
||||||
# Sometimes deletion during program shutdown throws exception as other
|
|
||||||
# modules are no longer available.
|
|
||||||
pass
|
|
||||||
|
|
||||||
def _numpy_text(self, is_repr=False):
|
def _numpy_text(self, is_repr=False):
|
||||||
if self.dtype.is_numpy_compatible:
|
if self.dtype.is_numpy_compatible:
|
||||||
@ -715,19 +594,6 @@ class EagerTensor(Tensor):
|
|||||||
numpy_text = "\n" + numpy_text
|
numpy_text = "\n" + numpy_text
|
||||||
return numpy_text
|
return numpy_text
|
||||||
|
|
||||||
def __str__(self):
|
|
||||||
return "tf.Tensor(%s, shape=%s, dtype=%s)" % (self._numpy_text(),
|
|
||||||
self.shape,
|
|
||||||
self.dtype.name)
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
return "<tf.Tensor: id=%s, shape=%s, dtype=%s, numpy=%s>" % (
|
|
||||||
self._id, self.shape, self.dtype.name, self._numpy_text(is_repr=True))
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _override_operator(name, func):
|
|
||||||
setattr(EagerTensor, name, func)
|
|
||||||
|
|
||||||
def numpy(self):
|
def numpy(self):
|
||||||
"""Returns a numpy array with the same contents as the Tensor.
|
"""Returns a numpy array with the same contents as the Tensor.
|
||||||
|
|
||||||
@ -742,69 +608,13 @@ class EagerTensor(Tensor):
|
|||||||
A numpy array that may share memory with the Tensor object. Any changes
|
A numpy array that may share memory with the Tensor object. Any changes
|
||||||
to one may be reflected in the other.
|
to one may be reflected in the other.
|
||||||
"""
|
"""
|
||||||
# TODO(ashankar): This with status business seems expensive. Profile/avoid?
|
return self.as_cpu_tensor()._numpy() # pylint: disable=protected-access
|
||||||
cpu = self.as_cpu_tensor()
|
|
||||||
with errors.raise_exception_on_not_ok_status() as status:
|
|
||||||
return c_api.TFE_Py_TensorHandleToNumpy(cpu._handle, status) # pylint: disable=protected-access
|
|
||||||
|
|
||||||
def _copy(self, ctx=None, device_name=None):
|
def _numpy(self):
|
||||||
"""Copies tensor to dest device."""
|
raise NotImplementedError()
|
||||||
# pylint: disable=protected-access
|
|
||||||
# Creates a new tensor on the dest device.
|
|
||||||
if ctx is None:
|
|
||||||
ctx = context.context()
|
|
||||||
if device_name is None:
|
|
||||||
device_name = ctx.device_name
|
|
||||||
with errors.raise_exception_on_not_ok_status() as status:
|
|
||||||
h = c_api.TFE_TensorHandleCopyToDevice(self._handle, ctx._handle,
|
|
||||||
device_name, status)
|
|
||||||
new_tensor = _tensor_from_handle(h)
|
|
||||||
if core.active_trace() is not None:
|
|
||||||
core.active_trace().record_tensor("COPY",
|
|
||||||
tensor_id(new_tensor),
|
|
||||||
new_tensor.device,
|
|
||||||
new_tensor.shape.num_elements())
|
|
||||||
|
|
||||||
# Record the copy on tape and define backprop copy as well.
|
def _datatype_enum(self):
|
||||||
if not context.in_graph_mode():
|
raise NotImplementedError()
|
||||||
self_device = self.device
|
|
||||||
def grad_fun(dresult):
|
|
||||||
with errors.raise_exception_on_not_ok_status() as status:
|
|
||||||
grad_h = c_api.TFE_TensorHandleCopyToDevice(
|
|
||||||
dresult._handle, ctx._handle, self_device, status)
|
|
||||||
return _tensor_from_handle(grad_h)
|
|
||||||
tape.record_operation("_copy", [new_tensor], [self], [], grad_fun)
|
|
||||||
return new_tensor
|
|
||||||
# pylint: enable=protected-access
|
|
||||||
|
|
||||||
def _dup(self):
|
|
||||||
return self._copy(device_name=self.device)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def device(self):
|
|
||||||
return c_api.TFE_TensorHandleDeviceName(self._handle)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def dtype(self):
|
|
||||||
return self._dtype
|
|
||||||
|
|
||||||
@property
|
|
||||||
def shape(self):
|
|
||||||
"""The shape of this Tensor as a TensorShape object."""
|
|
||||||
n = c_api.TFE_TensorHandleNumDims(self._handle)
|
|
||||||
# As of May 2017, TFE_TensorHandle objects were always backed by concrete
|
|
||||||
# tensors (which have a valid, known shape). There were vague plans to
|
|
||||||
# change this so that the Tensor class can also represent Tensors that have
|
|
||||||
# not yet been computed.
|
|
||||||
# If that happens, handle that (e.g., if n < 0: return tensor_shape(None))
|
|
||||||
# and also handle -1s returned by TFE_TensorHandleDim.
|
|
||||||
assert n >= 0, "See comment in source code"
|
|
||||||
return tensor_shape.TensorShape(
|
|
||||||
[c_api.TFE_TensorHandleDim(self._handle, x) for x in range(n)])
|
|
||||||
|
|
||||||
def get_shape(self):
|
|
||||||
"""Alias of Tensor.shape."""
|
|
||||||
return self.shape
|
|
||||||
|
|
||||||
def _shape_tuple(self):
|
def _shape_tuple(self):
|
||||||
"""The shape of this Tensor, as a tuple.
|
"""The shape of this Tensor, as a tuple.
|
||||||
@ -819,15 +629,62 @@ class EagerTensor(Tensor):
|
|||||||
Returns:
|
Returns:
|
||||||
tuple with the shape.
|
tuple with the shape.
|
||||||
"""
|
"""
|
||||||
n = c_api.TFE_TensorHandleNumDims(self._handle)
|
raise NotImplementedError()
|
||||||
# As of May 2017, TFE_TensorHandle objects were always backed by concrete
|
|
||||||
# tensors (which have a valid, known shape). There were vague plans to
|
def _copy_to_device(self, context, device): # pylint: disable=redefined-outer-name
|
||||||
# change this so that the Tensor class can also represent Tensors that have
|
raise NotImplementedError()
|
||||||
# not yet been computed.
|
|
||||||
# If that happens, handle that (e.g., if n < 0: return tensor_shape(None))
|
def __str__(self):
|
||||||
# and also handle -1s returned by TFE_TensorHandleDim.
|
return "tf.Tensor(%s, shape=%s, dtype=%s)" % (self._numpy_text(),
|
||||||
assert n >= 0, "See comment in source code"
|
self.shape,
|
||||||
return tuple(c_api.TFE_TensorHandleDim(self._handle, x) for x in range(n))
|
self.dtype.name)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return "<tf.Tensor: id=%s, shape=%s, dtype=%s, numpy=%s>" % (
|
||||||
|
self._id, self.shape, self.dtype.name, self._numpy_text(is_repr=True))
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _override_operator(name, func):
|
||||||
|
setattr(_EagerTensorBase, name, func)
|
||||||
|
|
||||||
|
def _copy(self, ctx=None, device_name=None):
|
||||||
|
"""Copies tensor to dest device."""
|
||||||
|
# pylint: disable=protected-access
|
||||||
|
# Creates a new tensor on the dest device.
|
||||||
|
if ctx is None:
|
||||||
|
ctx = context.context()
|
||||||
|
if device_name is None:
|
||||||
|
device_name = ctx.device_name
|
||||||
|
# pylint: disable=protected-access
|
||||||
|
try:
|
||||||
|
new_tensor = self._copy_to_device(context=ctx._handle, device=device_name)
|
||||||
|
except core._NotOkStatusException as e:
|
||||||
|
six.raise_from(core._status_to_exception(e.code, e.message), None)
|
||||||
|
if core.active_trace() is not None:
|
||||||
|
core.active_trace().record_tensor("COPY",
|
||||||
|
tensor_id(new_tensor),
|
||||||
|
new_tensor.device,
|
||||||
|
new_tensor.shape.num_elements())
|
||||||
|
|
||||||
|
# Record the copy on tape and define backprop copy as well.
|
||||||
|
if not context.in_graph_mode():
|
||||||
|
self_device = self.device
|
||||||
|
def grad_fun(dresult):
|
||||||
|
return dresult._copy(device_name=self_device)
|
||||||
|
tape.record_operation("_copy", [new_tensor], [self], [], grad_fun)
|
||||||
|
return new_tensor
|
||||||
|
# pylint: enable=protected-access
|
||||||
|
|
||||||
|
def _dup(self):
|
||||||
|
return self._copy(device_name=self.device)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def shape(self):
|
||||||
|
return tensor_shape.TensorShape(self._shape_tuple())
|
||||||
|
|
||||||
|
def get_shape(self):
|
||||||
|
"""Alias of Tensor.shape."""
|
||||||
|
return self.shape
|
||||||
|
|
||||||
def _shape_as_list(self):
|
def _shape_as_list(self):
|
||||||
"""The shape of the tensor as a list."""
|
"""The shape of the tensor as a list."""
|
||||||
@ -899,35 +756,9 @@ class EagerTensor(Tensor):
|
|||||||
raise NotImplementedError("eval not supported for Eager Tensors.")
|
raise NotImplementedError("eval not supported for Eager Tensors.")
|
||||||
|
|
||||||
|
|
||||||
def _tensor_from_handle(handle):
|
# This call creates an EagerTensor class, as a subclass of _EagerTensorBase, and
|
||||||
"""'Private' constructor for the Tensor object.
|
# registers it with the current module.
|
||||||
|
EagerTensor = c_api.TFE_Py_InitEagerTensor(_EagerTensorBase)
|
||||||
The existence of a 'handle' is an implementation detail that should be hidden
|
|
||||||
from users of this module. Functions within this module do need to create a
|
|
||||||
Tensor object from a handle though.
|
|
||||||
|
|
||||||
One option would be to have an __init__(self, handle) method on the
|
|
||||||
Tensor class, but that would make the existence and use of a handle
|
|
||||||
'public'.
|
|
||||||
|
|
||||||
Instead, this function avoids exposing a Tensor.__init__ that understands
|
|
||||||
handles and yet allows functions within this module to create Tensor
|
|
||||||
objects from a handle.
|
|
||||||
|
|
||||||
Arguments:
|
|
||||||
handle: A valid TFE_TensorHandle object.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A Tensor object.
|
|
||||||
"""
|
|
||||||
# pylint: disable=protected-access
|
|
||||||
t = EagerTensor.__new__(EagerTensor)
|
|
||||||
t._id = uid()
|
|
||||||
t._handle = handle
|
|
||||||
t._dtype = dtypes.as_dtype(c_api.TFE_TensorHandleDataType(handle))
|
|
||||||
t._handle_data = None
|
|
||||||
return t
|
|
||||||
# pylint: enable=protected-access
|
|
||||||
|
|
||||||
|
|
||||||
def _TensorTensorConversionFunction(t, dtype=None, name=None, as_ref=False):
|
def _TensorTensorConversionFunction(t, dtype=None, name=None, as_ref=False):
|
||||||
|
@ -298,9 +298,12 @@ class OperationTest(test_util.TensorFlowTestCase):
|
|||||||
|
|
||||||
def testConvertToTensorEager(self):
|
def testConvertToTensorEager(self):
|
||||||
with context.eager_mode():
|
with context.eager_mode():
|
||||||
t = ops.EagerTensor(1, context.context())
|
t = constant_op.constant(1)
|
||||||
|
self.assertTrue(isinstance(t, ops.EagerTensor))
|
||||||
converted = ops.convert_to_tensor(t)
|
converted = ops.convert_to_tensor(t)
|
||||||
self.assertTrue(isinstance(converted, ops.EagerTensor))
|
self.assertTrue(isinstance(converted, ops.EagerTensor))
|
||||||
|
converted = ops.convert_to_tensor(1)
|
||||||
|
self.assertTrue(isinstance(converted, ops.EagerTensor))
|
||||||
|
|
||||||
def testConvertToTensorNestedTuple(self):
|
def testConvertToTensorNestedTuple(self):
|
||||||
with self.test_session():
|
with self.test_session():
|
||||||
|
@ -103,8 +103,7 @@ class ConstantTest(test.TestCase):
|
|||||||
|
|
||||||
# This integer is larger than all non-infinite numbers representable
|
# This integer is larger than all non-infinite numbers representable
|
||||||
# by a double, raises an exception.
|
# by a double, raises an exception.
|
||||||
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
|
with self.assertRaisesRegexp(ValueError, "out-of-range integer"):
|
||||||
"out-of-range integer"):
|
|
||||||
constant_op.constant(10**310, dtypes_lib.float64)
|
constant_op.constant(10**310, dtypes_lib.float64)
|
||||||
|
|
||||||
def testInt32(self):
|
def testInt32(self):
|
||||||
@ -126,8 +125,7 @@ class ConstantTest(test.TestCase):
|
|||||||
self.assertAllClose(np.array(orig), tf_ans.numpy())
|
self.assertAllClose(np.array(orig), tf_ans.numpy())
|
||||||
|
|
||||||
# Out of range for an int64
|
# Out of range for an int64
|
||||||
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
|
with self.assertRaisesRegexp(ValueError, "out-of-range integer"):
|
||||||
"out-of-range integer"):
|
|
||||||
constant_op.constant([2**72])
|
constant_op.constant([2**72])
|
||||||
|
|
||||||
def testComplex64(self):
|
def testComplex64(self):
|
||||||
@ -240,14 +238,13 @@ class ConstantTest(test.TestCase):
|
|||||||
self._testAll((x, 1))
|
self._testAll((x, 1))
|
||||||
|
|
||||||
def testSparseValuesRaiseErrors(self):
|
def testSparseValuesRaiseErrors(self):
|
||||||
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
|
with self.assertRaisesRegexp(ValueError, "non-rectangular Python sequence"):
|
||||||
"non-rectangular Python sequence"):
|
|
||||||
constant_op.constant([[1, 2], [3]], dtype=dtypes_lib.int32)
|
constant_op.constant([[1, 2], [3]], dtype=dtypes_lib.int32)
|
||||||
|
|
||||||
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, None):
|
with self.assertRaisesRegexp(ValueError, None):
|
||||||
constant_op.constant([[1, 2], [3]])
|
constant_op.constant([[1, 2], [3]])
|
||||||
|
|
||||||
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, None):
|
with self.assertRaisesRegexp(ValueError, None):
|
||||||
constant_op.constant([[1, 2], [3], [4, 5]])
|
constant_op.constant([[1, 2], [3], [4, 5]])
|
||||||
|
|
||||||
|
|
||||||
|
@ -128,7 +128,7 @@ class VariableScopeTest(test.TestCase):
|
|||||||
with self.assertRaises(TypeError):
|
with self.assertRaises(TypeError):
|
||||||
variable_scope.get_variable("x4", initializer={})
|
variable_scope.get_variable("x4", initializer={})
|
||||||
else:
|
else:
|
||||||
with self.assertRaises(errors.InvalidArgumentError):
|
with self.assertRaises(ValueError):
|
||||||
variable_scope.get_variable("x4", initializer={})
|
variable_scope.get_variable("x4", initializer={})
|
||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes()
|
@test_util.run_in_graph_and_eager_modes()
|
||||||
|
@ -30,4 +30,11 @@ Safe_TF_TensorPtr make_safe(TF_Tensor* tensor) {
|
|||||||
return Safe_TF_TensorPtr(tensor, TF_DeleteTensor);
|
return Safe_TF_TensorPtr(tensor, TF_DeleteTensor);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Safe_TFE_TensorHandlePtr make_safe(TFE_TensorHandle* handle) {
|
||||||
|
return Safe_TFE_TensorHandlePtr(handle, TFE_DeleteTensorHandle);
|
||||||
|
}
|
||||||
|
|
||||||
|
Safe_TF_StatusPtr make_safe(TF_Status* status) {
|
||||||
|
return Safe_TF_StatusPtr(status, TF_DeleteStatus);
|
||||||
|
}
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
|||||||
#include <Python.h>
|
#include <Python.h>
|
||||||
|
|
||||||
#include "tensorflow/c/c_api.h"
|
#include "tensorflow/c/c_api.h"
|
||||||
|
#include "tensorflow/c/eager/c_api.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
@ -36,6 +37,21 @@ typedef void (*TF_DeleteTensor_type)(TF_Tensor*);
|
|||||||
typedef std::unique_ptr<TF_Tensor, TF_DeleteTensor_type> Safe_TF_TensorPtr;
|
typedef std::unique_ptr<TF_Tensor, TF_DeleteTensor_type> Safe_TF_TensorPtr;
|
||||||
Safe_TF_TensorPtr make_safe(TF_Tensor* tensor);
|
Safe_TF_TensorPtr make_safe(TF_Tensor* tensor);
|
||||||
|
|
||||||
|
// Safe containers for an owned TFE_TensorHandle. On destruction, the handle
|
||||||
|
// will be deleted by TFE_DeleteTensorHandle. Note: can't use
|
||||||
|
// decltype(&TFE_DeleteTensorHandle) due to SWIG
|
||||||
|
typedef void (*TFE_DeleteTensorHandle_type)(TFE_TensorHandle*);
|
||||||
|
typedef std::unique_ptr<TFE_TensorHandle, TFE_DeleteTensorHandle_type>
|
||||||
|
Safe_TFE_TensorHandlePtr;
|
||||||
|
Safe_TFE_TensorHandlePtr make_safe(TFE_TensorHandle* handle);
|
||||||
|
|
||||||
|
// Safe containers for an owned TF_Status. On destruction, the handle
|
||||||
|
// will be deleted by TF_DeleteStatus. Note: can't use
|
||||||
|
// decltype(&TF_DeleteStatus) due to SWIG
|
||||||
|
typedef void (*TF_DeleteStatus_type)(TF_Status*);
|
||||||
|
typedef std::unique_ptr<TF_Status, TF_DeleteStatus_type> Safe_TF_StatusPtr;
|
||||||
|
Safe_TF_StatusPtr make_safe(TF_Status* status);
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
#endif // THIRD_PARTY_TENSORFLOW_PYTHON_LIB_CORE_SAFE_PTR_H_
|
#endif // THIRD_PARTY_TENSORFLOW_PYTHON_LIB_CORE_SAFE_PTR_H_
|
||||||
|
@ -15,24 +15,16 @@ limitations under the License.
|
|||||||
|
|
||||||
%ignore "";
|
%ignore "";
|
||||||
|
|
||||||
%rename("%s") TFE_Py_RegisterExceptionClass;
|
|
||||||
%rename("%s") TFE_Py_NumpyToTensorHandle;
|
|
||||||
%rename("%s") TFE_Py_SequenceToTensorHandle;
|
|
||||||
%rename("%s") TFE_Py_AllEqualInt64;
|
|
||||||
%rename("%s") TFE_NewContext;
|
%rename("%s") TFE_NewContext;
|
||||||
%rename("%s") TFE_DeleteContext;
|
%rename("%s") TFE_DeleteContext;
|
||||||
%rename("%s") TFE_ContextListDevices;
|
%rename("%s") TFE_ContextListDevices;
|
||||||
%rename("%s") TFE_TensorHandleDataType;
|
|
||||||
%rename("%s") TFE_TensorHandleNumDims;
|
|
||||||
%rename("%s") TFE_DeleteTensorHandle;
|
|
||||||
%rename("%s") TFE_Py_Execute;
|
|
||||||
%rename("%s") TFE_ContextAddFunctionDef;
|
%rename("%s") TFE_ContextAddFunctionDef;
|
||||||
%rename("%s") TFE_TensorHandleDim;
|
|
||||||
%rename("%s") TFE_TensorHandleDeviceName;
|
|
||||||
%rename("%s") TFE_TensorHandleCopyToDevice;
|
|
||||||
%rename("%s") TFE_NewOp;
|
%rename("%s") TFE_NewOp;
|
||||||
%rename("%s") TFE_Py_TensorHandleToNumpy;
|
|
||||||
%rename("%s") TFE_OpGetAttrType;
|
%rename("%s") TFE_OpGetAttrType;
|
||||||
|
%rename("%s") TFE_Py_InitEagerTensor;
|
||||||
|
%rename("%s") TFE_Py_RegisterExceptionClass;
|
||||||
|
%rename("%s") TFE_Py_Execute;
|
||||||
|
%rename("%s") TFE_Py_UID;
|
||||||
|
|
||||||
|
|
||||||
%{
|
%{
|
||||||
@ -79,6 +71,18 @@ limitations under the License.
|
|||||||
$1 = TFE_GetPythonString($input);
|
$1 = TFE_GetPythonString($input);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
%typemap(in) (TFE_Context*) {
|
||||||
|
$1 = (TFE_Context*)PyCapsule_GetPointer($input, nullptr);
|
||||||
|
|
||||||
|
}
|
||||||
|
%typemap(out) (TFE_Context*) {
|
||||||
|
if ($1 == nullptr) {
|
||||||
|
SWIG_fail;
|
||||||
|
} else {
|
||||||
|
$result = PyCapsule_New($1, nullptr, TFE_DeleteContextCapsule);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
%include "tensorflow/c/eager/c_api.h"
|
%include "tensorflow/c/eager/c_api.h"
|
||||||
|
|
||||||
%typemap(in) TFE_InputTensorHandles* inputs (TFE_InputTensorHandles temp) {
|
%typemap(in) TFE_InputTensorHandles* inputs (TFE_InputTensorHandles temp) {
|
||||||
@ -95,15 +99,13 @@ limitations under the License.
|
|||||||
if (!elem) {
|
if (!elem) {
|
||||||
SWIG_fail;
|
SWIG_fail;
|
||||||
}
|
}
|
||||||
void* thp = nullptr;
|
if (EagerTensor_CheckExact(elem)) {
|
||||||
int res = SWIG_ConvertPtr(elem, &thp,
|
(*$1)[i] = EagerTensorHandle(elem);
|
||||||
$descriptor(TFE_TensorHandle*), 0 | 0);
|
} else {
|
||||||
if (!SWIG_IsOK(res)) {
|
SWIG_exception_fail(SWIG_TypeError,
|
||||||
SWIG_exception_fail(SWIG_ArgError(res),
|
|
||||||
"provided list of inputs contains objects other "
|
"provided list of inputs contains objects other "
|
||||||
"than 'TFE_TensorHandle*'");
|
"than 'EagerTensor'");
|
||||||
}
|
}
|
||||||
(*$1)[i] = reinterpret_cast<TFE_TensorHandle*>(thp);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -129,45 +131,32 @@ limitations under the License.
|
|||||||
}
|
}
|
||||||
|
|
||||||
%typemap(argout) (TFE_OutputTensorHandles* outputs, TF_Status* out_status) {
|
%typemap(argout) (TFE_OutputTensorHandles* outputs, TF_Status* out_status) {
|
||||||
if (TFE_Py_MaybeRaiseException($2)) {
|
if (MaybeRaiseExceptionFromTFStatus($2, nullptr)) {
|
||||||
SWIG_fail;
|
SWIG_fail;
|
||||||
} else {
|
} else {
|
||||||
int num_outputs = $1->size();
|
int num_outputs = $1->size();
|
||||||
$result = PyList_New(num_outputs);
|
$result = PyList_New(num_outputs);
|
||||||
for (int i = 0; i < num_outputs; ++i) {
|
for (int i = 0; i < num_outputs; ++i) {
|
||||||
PyList_SetItem($result, i, SWIG_NewPointerObj(SWIG_as_voidptr($1->at(i)),
|
PyObject *output;
|
||||||
$descriptor(TFE_TensorHandle*),
|
output = EagerTensorFromHandle($1->at(i));
|
||||||
0 | 0));
|
PyList_SetItem($result, i, output);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Note that we need to use a typemap for TFE_TensorHandle* so that we can call
|
|
||||||
// SWIG_fail in case the value is nullptr. Otherwise SWIG will wrap the
|
|
||||||
// nullptr and return it to python as an opaque object, and python does not
|
|
||||||
// know that it needs to check if an Exception has been raised.
|
|
||||||
// TODO(agarwal): check if we can get rid of this typemap.
|
|
||||||
%typemap(out) (TFE_TensorHandle*) {
|
|
||||||
if ($1 == nullptr) {
|
|
||||||
SWIG_fail;
|
|
||||||
} else {
|
|
||||||
$result = SWIG_NewPointerObj(SWIG_as_voidptr($1),
|
|
||||||
$descriptor(TFE_TensorHandle*), 0 | 0);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
%include "tensorflow/python/eager/pywrap_tfe.h"
|
%include "tensorflow/python/eager/pywrap_tfe.h"
|
||||||
|
|
||||||
|
|
||||||
// Clear all typemaps127
|
// Clear all typemaps.
|
||||||
%typemap(out) TF_DataType;
|
%typemap(out) TF_DataType;
|
||||||
%typemap(out) int64_t;
|
%typemap(out) int64_t;
|
||||||
%typemap(out) TF_AttrType;
|
%typemap(out) TF_AttrType;
|
||||||
%typemap(in, numinputs=0) TF_Status *out_status;
|
%typemap(in, numinputs=0) TF_Status *out_status;
|
||||||
%typemap(argout) unsigned char* is_list;
|
%typemap(argout) unsigned char* is_list;
|
||||||
%typemap(in) TFE_InputTensorHandles* inputs (TFE_InputTensorHandles temp);
|
%typemap(in) (TFE_Context*);
|
||||||
|
%typemap(out) (TFE_Context*);
|
||||||
%typemap(in) TFE_OutputTensorHandles* outputs (TFE_OutputTensorHandles temp);
|
%typemap(in) TFE_OutputTensorHandles* outputs (TFE_OutputTensorHandles temp);
|
||||||
%typemap(in, numinputs=0) TF_Status *out_status;
|
%typemap(in, numinputs=0) TF_Status *out_status;
|
||||||
%typemap(freearg) (TF_Status* out_status);
|
%typemap(freearg) (TF_Status* out_status);
|
||||||
%typemap(argout) (TFE_OutputTensorHandles* outputs, TF_Status* out_status);
|
%typemap(argout) (TFE_OutputTensorHandles* outputs, TF_Status* out_status);
|
||||||
%typemap(out) (TFE_TensorHandle*);
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user