Ensure context is populated for constant tensors

PiperOrigin-RevId: 286059397
Change-Id: Iaa771195df8c1f1cc4dc33a136c6eb75a9181f44
This commit is contained in:
Gaurav Jain 2019-12-17 14:24:21 -08:00 committed by TensorFlower Gardener
parent a845493244
commit 19318ce1d5

View File

@ -74,12 +74,13 @@ TFE_Context* GetContextHandle(PyObject* py_context) {
// Convert a Python numpy.ndarray object to a TFE_TensorHandle.
// The two may share underlying storage so changes to one may reflect in the
// other.
TFE_TensorHandle* NumpyToTFE_TensorHandle(PyObject* obj) {
TFE_TensorHandle* NumpyToTFE_TensorHandle(TFE_Context* ctx, PyObject* obj) {
tensorflow::TensorHandle* handle;
tensorflow::Tensor t;
auto cppstatus = tensorflow::NdarrayToTensor(obj, &t);
if (cppstatus.ok()) {
cppstatus = tensorflow::TensorHandle::CreateLocalHandle(t, &handle);
cppstatus = tensorflow::TensorHandle::CreateLocalHandle(
t, /*d=*/nullptr, /*op_device=*/nullptr, ctx->context, &handle);
}
if (!cppstatus.ok()) {
PyErr_SetString(PyExc_ValueError,
@ -251,14 +252,16 @@ TFE_TensorHandle* EagerCast(TFE_Context* ctx, TFE_TensorHandle* handle,
#undef RETURN_ERROR
}
TFE_TensorHandle* PySeqToTFE_TensorHandle(PyObject* value, DataType dtype) {
TFE_TensorHandle* PySeqToTFE_TensorHandle(TFE_Context* ctx, PyObject* value,
DataType dtype) {
tensorflow::TensorHandle* handle = nullptr;
tensorflow::Tensor t;
// TODO(josh11b): Have PySeqToTensor set python errors instead of
// returning Status.
auto cppstatus = tensorflow::PySeqToTensor(value, dtype, &t);
if (cppstatus.ok()) {
cppstatus = tensorflow::TensorHandle::CreateLocalHandle(t, &handle);
cppstatus = tensorflow::TensorHandle::CreateLocalHandle(
t, /*d=*/nullptr, /*op_device=*/nullptr, ctx->context, &handle);
}
if (!cppstatus.ok()) {
PyErr_SetString(PyExc_ValueError, cppstatus.error_message().c_str());
@ -312,9 +315,9 @@ TFE_TensorHandle* ConvertToEagerTensorUncached(TFE_Context* ctx,
}
value = safe_value.get();
}
handle = make_safe(NumpyToTFE_TensorHandle(value));
handle = make_safe(NumpyToTFE_TensorHandle(ctx, value));
} else {
handle = make_safe(PySeqToTFE_TensorHandle(value, dtype));
handle = make_safe(PySeqToTFE_TensorHandle(ctx, value, dtype));
}
if (handle == nullptr) return nullptr;