Ensure context is populated for constant tensors
PiperOrigin-RevId: 286059397 Change-Id: Iaa771195df8c1f1cc4dc33a136c6eb75a9181f44
This commit is contained in:
parent
a845493244
commit
19318ce1d5
@ -74,12 +74,13 @@ TFE_Context* GetContextHandle(PyObject* py_context) {
|
||||
// Convert a Python numpy.ndarray object to a TFE_TensorHandle.
|
||||
// The two may share underlying storage so changes to one may reflect in the
|
||||
// other.
|
||||
TFE_TensorHandle* NumpyToTFE_TensorHandle(PyObject* obj) {
|
||||
TFE_TensorHandle* NumpyToTFE_TensorHandle(TFE_Context* ctx, PyObject* obj) {
|
||||
tensorflow::TensorHandle* handle;
|
||||
tensorflow::Tensor t;
|
||||
auto cppstatus = tensorflow::NdarrayToTensor(obj, &t);
|
||||
if (cppstatus.ok()) {
|
||||
cppstatus = tensorflow::TensorHandle::CreateLocalHandle(t, &handle);
|
||||
cppstatus = tensorflow::TensorHandle::CreateLocalHandle(
|
||||
t, /*d=*/nullptr, /*op_device=*/nullptr, ctx->context, &handle);
|
||||
}
|
||||
if (!cppstatus.ok()) {
|
||||
PyErr_SetString(PyExc_ValueError,
|
||||
@ -251,14 +252,16 @@ TFE_TensorHandle* EagerCast(TFE_Context* ctx, TFE_TensorHandle* handle,
|
||||
#undef RETURN_ERROR
|
||||
}
|
||||
|
||||
TFE_TensorHandle* PySeqToTFE_TensorHandle(PyObject* value, DataType dtype) {
|
||||
TFE_TensorHandle* PySeqToTFE_TensorHandle(TFE_Context* ctx, PyObject* value,
|
||||
DataType dtype) {
|
||||
tensorflow::TensorHandle* handle = nullptr;
|
||||
tensorflow::Tensor t;
|
||||
// TODO(josh11b): Have PySeqToTensor set python errors instead of
|
||||
// returning Status.
|
||||
auto cppstatus = tensorflow::PySeqToTensor(value, dtype, &t);
|
||||
if (cppstatus.ok()) {
|
||||
cppstatus = tensorflow::TensorHandle::CreateLocalHandle(t, &handle);
|
||||
cppstatus = tensorflow::TensorHandle::CreateLocalHandle(
|
||||
t, /*d=*/nullptr, /*op_device=*/nullptr, ctx->context, &handle);
|
||||
}
|
||||
if (!cppstatus.ok()) {
|
||||
PyErr_SetString(PyExc_ValueError, cppstatus.error_message().c_str());
|
||||
@ -312,9 +315,9 @@ TFE_TensorHandle* ConvertToEagerTensorUncached(TFE_Context* ctx,
|
||||
}
|
||||
value = safe_value.get();
|
||||
}
|
||||
handle = make_safe(NumpyToTFE_TensorHandle(value));
|
||||
handle = make_safe(NumpyToTFE_TensorHandle(ctx, value));
|
||||
} else {
|
||||
handle = make_safe(PySeqToTFE_TensorHandle(value, dtype));
|
||||
handle = make_safe(PySeqToTFE_TensorHandle(ctx, value, dtype));
|
||||
}
|
||||
|
||||
if (handle == nullptr) return nullptr;
|
||||
|
Loading…
Reference in New Issue
Block a user