Allow creating tensors from numpy arrays, and various other constants - try #2

- Allow type inference from a different input tensor, similar to args_to_matching_eager.
- Update TFE_Py_TensorShapeSlice to take tuples as well as lists.
- Update int-valued attr parsing to accept both int and long in Python 2.

END_PUBLIC BEGIN_PUBLIC Automated g4 rollback of changelist 192184809 PiperOrigin-RevId: 193696790
parent 5fbb1feecd
commit 712bbc5d7b
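At the Python level, the change lets eager tensors be created directly from numpy arrays and plain Python constants. A rough usage sketch (assuming a TF 1.x-era build with eager execution enabled; values are illustrative):

    import numpy as np
    import tensorflow as tf

    tf.enable_eager_execution()

    # numpy arrays convert directly into EagerTensors.
    a = tf.constant(np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32))

    # Plain Python scalars and sequences also go through the new
    # ConvertToEagerTensor path added below.
    b = tf.constant([1, 2, 3])
    c = tf.constant(2.5)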
@@ -60,42 +60,6 @@ TFE_TensorHandle* NumpyToTensorHandle(PyObject* obj) {
   }
 }
 
-// Casts data referred to by `handle` from type `src_type_enum` to type
-// `dst_type_enum`.
-TFE_TensorHandle* EagerCast(TFE_Context* ctx, TFE_TensorHandle* handle,
-                            TF_DataType src_type_enum,
-                            TF_DataType dst_type_enum, TF_Status* out_status) {
-  if (ctx == nullptr) return nullptr;
-  const char* op_name = "Cast";
-  const char* device_name = "/job:localhost/replica:0/task:0/device:CPU:0";
-  TFE_Op* op = TFE_NewOp(ctx, op_name, out_status);
-#define RETURN_ERROR  \
-  {                   \
-    TFE_DeleteOp(op); \
-    return nullptr;   \
-  }
-  if (TF_GetCode(out_status) != TF_OK) RETURN_ERROR
-  TFE_OpSetDevice(op, device_name, out_status);
-  if (TF_GetCode(out_status) != TF_OK) RETURN_ERROR
-  TFE_OpAddInput(op, handle, out_status);
-  if (TF_GetCode(out_status) != TF_OK) RETURN_ERROR
-  TFE_OpSetAttrType(op, "SrcT", src_type_enum);
-  TFE_OpSetAttrType(op, "DstT", dst_type_enum);
-  TFE_TensorHandle* output = nullptr;
-  int num_outputs = 1;
-  TFE_Execute(op, &output, &num_outputs, out_status);
-  if (TF_GetCode(out_status) != TF_OK || num_outputs != 1 ||
-      output == nullptr) {
-    if (output != nullptr) {
-      TFE_DeleteTensorHandle(output);
-    }
-    RETURN_ERROR
-  }
-  TFE_DeleteOp(op);
-  return output;
-#undef RETURN_ERROR
-}
-
 TFE_TensorHandle* CopyToDevice(TFE_TensorHandle* handle, PyObject* ctx,
                                PyObject* dev) {
   const char* device = "";
@@ -161,6 +125,100 @@ PyObject* PyIntFromDataType(TF_DataType l) {
 
 }  // namespace
 
+namespace tensorflow {
+// Casts data referred to by `handle` from type `src_type_enum` to type
+// `dst_type_enum`.
+TFE_TensorHandle* EagerCast(TFE_Context* ctx, TFE_TensorHandle* handle,
+                            TF_DataType src_type_enum,
+                            TF_DataType dst_type_enum, TF_Status* out_status) {
+  if (ctx == nullptr) return nullptr;
+  const char* op_name = "Cast";
+  const char* device_name = "/job:localhost/replica:0/task:0/device:CPU:0";
+  TFE_Op* op = TFE_NewOp(ctx, op_name, out_status);
+#define RETURN_ERROR  \
+  {                   \
+    TFE_DeleteOp(op); \
+    return nullptr;   \
+  }
+  if (TF_GetCode(out_status) != TF_OK) RETURN_ERROR
+  TFE_OpSetDevice(op, device_name, out_status);
+  if (TF_GetCode(out_status) != TF_OK) RETURN_ERROR
+  TFE_OpAddInput(op, handle, out_status);
+  if (TF_GetCode(out_status) != TF_OK) RETURN_ERROR
+  TFE_OpSetAttrType(op, "SrcT", src_type_enum);
+  TFE_OpSetAttrType(op, "DstT", dst_type_enum);
+  TFE_TensorHandle* output = nullptr;
+  int num_outputs = 1;
+  TFE_Execute(op, &output, &num_outputs, out_status);
+  if (TF_GetCode(out_status) != TF_OK || num_outputs != 1 ||
+      output == nullptr) {
+    if (output != nullptr) {
+      TFE_DeleteTensorHandle(output);
+    }
+    RETURN_ERROR
+  }
+  TFE_DeleteOp(op);
+  return output;
+#undef RETURN_ERROR
+}
+
+TFE_TensorHandle* ConvertToEagerTensor(PyObject* value, PyObject* dtype) {
+  int desired_dtype = -1;
+  if (dtype != Py_None) {
+    if (!PyIntToDataType(dtype, &desired_dtype)) {
+      PyErr_SetString(PyExc_TypeError,
+                      tensorflow::strings::StrCat(
+                          "Expecting a DataType value for dtype. Got ",
+                          Py_TYPE(dtype)->tp_name)
+                          .c_str());
+      return nullptr;
+    }
+  }
+  if (PyArray_Check(value)) {
+    int desired_np_dtype = -1;
+    if (desired_dtype >= 0) {
+      if (!tensorflow::TF_DataType_to_PyArray_TYPE(
+               static_cast<TF_DataType>(desired_dtype), &desired_np_dtype)
+               .ok()) {
+        PyErr_SetString(PyExc_TypeError,
+                        tensorflow::strings::StrCat(
+                            "Invalid dtype argument value ", desired_dtype)
+                            .c_str());
+        return nullptr;
+      }
+    }
+    PyArrayObject* array = reinterpret_cast<PyArrayObject*>(value);
+    int current_np_dtype = PyArray_TYPE(array);
+    auto safe_value = tensorflow::make_safe(static_cast<PyObject*>(nullptr));
+    if ((desired_np_dtype >= 0 && desired_np_dtype != current_np_dtype) ||
+        !PyArray_ISCARRAY(array)) {
+      int new_dtype =
+          desired_np_dtype >= 0 ? desired_np_dtype : current_np_dtype;
+      safe_value = tensorflow::make_safe(
+          PyArray_FromAny(value, PyArray_DescrFromType(new_dtype), 0, 0,
+                          NPY_ARRAY_CARRAY | NPY_ARRAY_FORCECAST, nullptr));
+      if (PyErr_Occurred()) return nullptr;
+      if (safe_value == nullptr) {
+        PyErr_SetString(PyExc_ValueError, "Error while casting a numpy value");
+        return nullptr;
+      }
+      value = safe_value.get();
+    }
+    return NumpyToTensorHandle(value);
+  } else {
+    tensorflow::Tensor t;
+    // TODO(josh11b): Have PySeqToTensor set python errors instead of
+    // returning Status.
+    auto cppstatus = tensorflow::PySeqToTensor(value, dtype, &t);
+    if (!cppstatus.ok()) {
+      PyErr_SetString(PyExc_ValueError, cppstatus.error_message().c_str());
+      return nullptr;
+    }
+    return TFE_NewTensorHandle(t);
+  }
+}
+}  // namespace tensorflow
+
 extern "C" {
 
 static const int kMaxEagerTensorParentSize = 64;
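ConvertToEagerTensor honors an explicit dtype, and callers fall back to EagerCast when the requested type differs from the converted handle's type. A hedged sketch of the observable Python-side behavior:

    import numpy as np
    import tensorflow as tf

    tf.enable_eager_execution()

    # The int64 numpy array is converted first, then cast to the requested
    # float64 via the eager "Cast" op shown above.
    t = tf.constant(np.arange(4), dtype=tf.float64)
    print(t.dtype)  # tf.float64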
@@ -230,61 +288,16 @@ int EagerTensor_init(EagerTensor* self, PyObject* args, PyObject* kwds) {
       return -1;
     }
   }
-  tensorflow::Safe_TFE_TensorHandlePtr handle =
-      tensorflow::make_safe(static_cast<TFE_TensorHandle*>(nullptr));
-  PyErr_Clear();
-  if (PyArray_Check(value)) {
-    int desired_np_dtype = -1;
-    if (desired_dtype >= 0) {
-      if (!tensorflow::TF_DataType_to_PyArray_TYPE(
-               static_cast<TF_DataType>(desired_dtype), &desired_np_dtype)
-               .ok()) {
-        PyErr_SetString(PyExc_TypeError,
-                        tensorflow::strings::StrCat(
-                            "Invalid dtype argument value ", desired_dtype)
-                            .c_str());
-        return -1;
-      }
-    }
-    PyArrayObject* array = reinterpret_cast<PyArrayObject*>(value);
-    int current_np_dtype = PyArray_TYPE(array);
-    auto safe_value = tensorflow::make_safe(static_cast<PyObject*>(nullptr));
-    if ((desired_np_dtype >= 0 && desired_np_dtype != current_np_dtype) ||
-        !PyArray_ISCARRAY(array)) {
-      int new_dtype =
-          desired_np_dtype >= 0 ? desired_np_dtype : current_np_dtype;
-      safe_value = tensorflow::make_safe(
-          PyArray_FromAny(value, PyArray_DescrFromType(new_dtype), 0, 0,
-                          NPY_ARRAY_CARRAY | NPY_ARRAY_FORCECAST, nullptr));
-      if (PyErr_Occurred()) return -1;
-      if (safe_value == nullptr) {
-        PyErr_SetString(PyExc_ValueError, "Error while casting a numpy value");
-        return -1;
-      }
-      value = safe_value.get();
-    }
-    handle = tensorflow::make_safe(NumpyToTensorHandle(value));
-  } else {
-    tensorflow::Tensor t;
-    // TODO(josh11b): Have PySeqToTensor set python errors instead of
-    // returning Status.
-    auto cppstatus = tensorflow::PySeqToTensor(value, dtype, &t);
-    if (!cppstatus.ok()) {
-      PyErr_SetString(PyExc_ValueError, cppstatus.error_message().c_str());
-      return -1;
-    }
-    handle = tensorflow::make_safe(TFE_NewTensorHandle(t));
-  }
-  if (PyErr_Occurred()) return -1;
-  if (handle == nullptr) {
-    PyErr_SetString(PyExc_ValueError, "Error while creating an EagerTensor");
-    return -1;
-  }
+  tensorflow::Safe_TFE_TensorHandlePtr handle =
+      tensorflow::make_safe(static_cast<TFE_TensorHandle*>(
+          tensorflow::ConvertToEagerTensor(value, dtype)));
+  if (handle == nullptr) return -1;
   TF_DataType handle_dtype = TFE_TensorHandleDataType(handle.get());
   if (desired_dtype >= 0 && desired_dtype != handle_dtype) {
-    handle = tensorflow::make_safe(
-        EagerCast(GetContext(context), handle.get(), handle_dtype,
-                  static_cast<TF_DataType>(desired_dtype), self->status));
+    handle = tensorflow::make_safe(tensorflow::EagerCast(
+        GetContext(context), handle.get(), handle_dtype,
+        static_cast<TF_DataType>(desired_dtype), self->status));
     if (TF_GetCode(self->status) != TF_OK) {
       PyErr_SetString(PyExc_ValueError,
                       tensorflow::strings::StrCat(
@@ -701,12 +714,12 @@ PyObject* TFE_Py_InitEagerTensor(PyObject* base_class) {
   return reinterpret_cast<PyObject*>(EagerTensorType);
 }
 
-PyObject* TFE_Py_TensorShapeSlice(PyObject* tensor_list, int slice_dim) {
-  if (!PyList_Check(tensor_list)) {
+PyObject* TFE_Py_TensorShapeSlice(PyObject* tensors, int slice_dim) {
+  if (!PyList_Check(tensors) && !PyTuple_Check(tensors)) {
     PyErr_SetString(PyExc_TypeError,
                     tensorflow::strings::StrCat(
-                        "tensor_list argument must be a list. Got \"",
-                        Py_TYPE(tensor_list)->tp_name, "\"")
+                        "tensors argument must be a list or a tuple. Got \"",
+                        Py_TYPE(tensors)->tp_name, "\"")
                         .c_str());
     return nullptr;
   }
@@ -720,14 +733,14 @@ PyObject* TFE_Py_TensorShapeSlice(PyObject* tensor_list, int slice_dim) {
     return nullptr;
   }
 
-  Py_ssize_t num_tensors = PyList_Size(tensor_list);
+  Py_ssize_t num_tensors = PySequence_Fast_GET_SIZE(tensors);
   int64_t num_tensors_int = static_cast<int64_t>(num_tensors);
   auto tensor = tensorflow::make_safe(TF_AllocateTensor(
       TF_INT32, &num_tensors_int, /*num_dims=*/1, /*len=*/4 * num_tensors_int));
   int32_t* data = reinterpret_cast<int32_t*>(TF_TensorData(tensor.get()));
   auto status = tensorflow::make_safe(TF_NewStatus());
   for (Py_ssize_t i = 0; i < num_tensors; ++i) {
-    PyObject* tensor_obj = PyList_GET_ITEM(tensor_list, i);
+    PyObject* tensor_obj = PySequence_Fast_GET_ITEM(tensors, i);
     if (!EagerTensor_CheckExact(tensor_obj)) {
       PyErr_SetString(PyExc_TypeError,
                       tensorflow::strings::StrCat(
@@ -22,4 +22,14 @@ limitations under the License.
 bool EagerTensor_CheckExact(const PyObject* o);
 tensorflow::int64 EagerTensor_id(const PyObject* tensor);
 
+namespace tensorflow {
+TFE_TensorHandle* ConvertToEagerTensor(PyObject* value, PyObject* dtype);
+
+// TODO(nareshmodi): Move EagerCast and ReadVariableOp (which use the C API to
+// execute TFE Ops) to a separate common library.
+TFE_TensorHandle* EagerCast(TFE_Context* ctx, TFE_TensorHandle* handle,
+                            TF_DataType src_type_enum,
+                            TF_DataType dst_type_enum, TF_Status* out_status);
+}
+
 #endif  // TENSORFLOW_PYTHON_EAGER_PYWRAP_TENSOR_H_
@@ -186,16 +186,16 @@ PyObject* TFE_Py_RecordGradient(PyObject* op_name, PyObject* inputs,
 // Returns the set of variables watched by the given tape.
 PyObject* TFE_Py_TapeWatchedVariables(PyObject* tape);
 
-// Returns an EagerTensor of dimension [len(`tensor_list`)] containing
-// the `slice_dim`'th dimension of each tensor in `tensor_list`. In other words,
+// Returns an EagerTensor of dimension [len(`tensors`)] containing
+// the `slice_dim`'th dimension of each tensor in `tensors`. In other words,
 // TFE_Py_TensorShapeSlice takes a slice of dimensions of tensors in
-// `tensor_list`. For example, if `tensor_list` contains tensors of with shapes
+// `tensors`. For example, if `tensors` contains tensors of with shapes
 // [1, 2, 3], [4, 5], [6, 7, 8, 9], TFE_Py_TensorShapeSlice called with
 // `slice_dim` equal to 1 will return [2, 5, 7].
 // On error, returns nullptr and sets python exception.
-// REQUIRES: `tensor_list` is a python list of EagerTensors
+// REQUIRES: `tensors` is a python list/tuple of EagerTensors
 // REQUIRES: `slice_dim` is non-negative and smaller than the rank of all
-// tensors in `tensor_list`.
-PyObject* TFE_Py_TensorShapeSlice(PyObject* tensor_list, int slice_dim);
+// tensors in `tensors`.
+PyObject* TFE_Py_TensorShapeSlice(PyObject* tensors, int slice_dim);
 
 #endif  // TENSORFLOW_PYTHON_EAGER_PYWRAP_TFE_H_
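A concrete illustration of the documented behavior, matching the shapes in the comment above (a hedged sketch using the internal pywrap API, as the eager tests do; assumes eager execution is enabled):

    import numpy as np
    import tensorflow as tf
    from tensorflow.python import pywrap_tensorflow

    tf.enable_eager_execution()

    tensors = (tf.constant(np.zeros([1, 2, 3])),
               tf.constant(np.zeros([4, 5])),
               tf.constant(np.zeros([6, 7, 8, 9])))

    # With this change, a tuple is accepted as well as a list.
    print(pywrap_tensorflow.TFE_Py_TensorShapeSlice(tensors, 1))  # [2 5 7]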
@@ -38,6 +38,54 @@ using tensorflow::strings::Printf;
 
 namespace {
 
+struct InputInfo {
+  InputInfo(int i, bool is_list) : i(i), is_list(is_list) {}
+
+  int i;
+  bool is_list = false;
+};
+
+using AttrToInputsMap =
+    tensorflow::gtl::FlatMap<string,
+                             tensorflow::gtl::InlinedVector<InputInfo, 4>>;
+
+tensorflow::mutex all_attr_to_input_maps_lock(
+    tensorflow::LINKER_INITIALIZED);
+tensorflow::gtl::FlatMap<string, AttrToInputsMap*>* GetAllAttrToInputsMaps() {
+  static auto* all_attr_to_input_maps =
+      new tensorflow::gtl::FlatMap<string, AttrToInputsMap*>;
+  return all_attr_to_input_maps;
+}
+
+AttrToInputsMap* GetAttrToInputsMap(const tensorflow::OpDef& op_def) {
+  tensorflow::mutex_lock l(all_attr_to_input_maps_lock);
+  auto* all_attr_to_input_maps = GetAllAttrToInputsMaps();
+
+  auto* output =
+      tensorflow::gtl::FindPtrOrNull(*all_attr_to_input_maps, op_def.name());
+  if (output != nullptr) {
+    return output;
+  }
+
+  std::unique_ptr<AttrToInputsMap> m(new AttrToInputsMap);
+
+  // Store a list of InputIndex -> List of corresponding inputs.
+  for (int i = 0; i < op_def.input_arg_size(); i++) {
+    if (!op_def.input_arg(i).type_attr().empty()) {
+      auto it = m->find(op_def.input_arg(i).type_attr());
+      if (it == m->end()) {
+        it = m->insert({op_def.input_arg(i).type_attr(), {}}).first;
+      }
+      it->second.emplace_back(i, !op_def.input_arg(i).number_attr().empty());
+    }
+  }
+
+  auto* retval = m.get();
+  (*all_attr_to_input_maps)[op_def.name()] = m.release();
+
+  return retval;
+}
+
 struct FastPathOpExecInfo {
   TFE_Context* ctx;
   const char* device_name;
@@ -53,6 +101,14 @@ struct FastPathOpExecInfo {
   // The op type name of the main op being executed.
   PyObject* op_name;
   PyObject* callbacks;
+
+  // All the args passed into the FastPathOpExecInfo.
+  PyObject* args;
+
+  // DTypes can come from another input that has the same attr. So build that
+  // map.
+  const AttrToInputsMap* attr_to_inputs_map;
+  tensorflow::gtl::FlatMap<string, tensorflow::DataType> cached_dtypes;
 };
 
 #define PARSE_VALUE(fn_name, type, check_fn, parse_fn) \
@@ -76,12 +132,29 @@ PARSE_VALUE(ParseIntValue, int, PyLong_Check, PyLong_AsLong)
 PARSE_VALUE(ParseInt64Value, int64_t, PyLong_Check, PyLong_AsLong)
 #else
 PARSE_VALUE(ParseIntValue, int, PyInt_Check, PyInt_AsLong)
-PARSE_VALUE(ParseInt64Value, int64_t, PyInt_Check, PyInt_AsLong)
-PARSE_VALUE(ParseInt64LongValue, int64_t, PyLong_Check, PyLong_AsLong)
 #endif
 PARSE_VALUE(ParseFloatValue, float, PyFloat_Check, PyFloat_AsDouble)
 #undef PARSE_VALUE
 
+#if PY_MAJOR_VERSION < 3
+bool ParseInt64Value(const string& key, PyObject* py_value, TF_Status* status,
+                     int64_t* value) {
+  if (PyInt_Check(py_value)) {
+    *value = static_cast<int64_t>(PyInt_AsLong(py_value));
+    return true;
+  } else if (PyLong_Check(py_value)) {
+    *value = static_cast<int64_t>(PyLong_AsLong(py_value));
+    return true;
+  }
+  TF_SetStatus(
+      status, TF_INVALID_ARGUMENT,
+      tensorflow::strings::StrCat("Expecting int or long value for attr ", key,
+                                  ", got ", py_value->ob_type->tp_name)
+          .c_str());
+  return false;
+}
+#endif
+
 Py_ssize_t TensorShapeNumDims(PyObject* value) {
   const auto size = PySequence_Size(value);
   if (size == -1) {
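In Python 2, int-valued attrs previously parsed only via PyInt_Check; the hand-written ParseInt64Value above now accepts `long` too. A small Python 2 illustration (values are hypothetical):

    # Python 2 only: small literals are `int`, values beyond sys.maxint are
    # `long`. Both should now parse as an int64 attr.
    isinstance(7, int)             # True
    isinstance(2 ** 63 - 1, long)  # True; PyInt_Check alone would reject it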
@@ -234,7 +307,7 @@ bool SetOpAttrList(
     std::unique_ptr<int64_t[]> buffer(new int64_t[total_dims]);
     // Copy the input dims into the buffer and set dims to point to
     // the start of each list's dims.
-    std::unique_ptr<const int64_t* []> dims(new const int64_t*[num_values]);
+    std::unique_ptr<const int64_t*[]> dims(new const int64_t*[num_values]);
     std::unique_ptr<int[]> num_dims(new int[num_values]);
     int64_t* offset = buffer.get();
     for (int i = 0; i < num_values; ++i) {
@@ -296,7 +369,7 @@ void SetOpAttrListDefault(
     TF_Status* status) {
   if (type == TF_ATTR_STRING) {
     int num_values = attr.default_value().list().s_size();
-    std::unique_ptr<const char* []> values(new const char*[num_values]);
+    std::unique_ptr<const char*[]> values(new const char*[num_values]);
     (*attr_list_sizes)[key] = num_values;
     for (int i = 0; i < num_values; i++) {
       values[i] = attr.default_value().list().s(i).data();
@@ -349,7 +422,7 @@ void SetOpAttrListDefault(
     std::unique_ptr<int64_t[]> buffer(new int64_t[total_dims]);
     // Copy the input dims into the buffer and set dims to point to
     // the start of each list's dims.
-    std::unique_ptr<const int64_t* []> dims(new const int64_t*[num_values]);
+    std::unique_ptr<const int64_t*[]> dims(new const int64_t*[num_values]);
     std::unique_ptr<int[]> num_dims(new int[num_values]);
     int64_t* offset = buffer.get();
     for (int i = 0; i < num_values; ++i) {
@@ -369,7 +442,7 @@ void SetOpAttrListDefault(
   } else if (type == TF_ATTR_FUNC) {
     int num_values = attr.default_value().list().func_size();
     (*attr_list_sizes)[key] = num_values;
-    std::unique_ptr<const TFE_Op* []> funcs(new const TFE_Op*[num_values]);
+    std::unique_ptr<const TFE_Op*[]> funcs(new const TFE_Op*[num_values]);
     for (int i = 0; i < num_values; i++) {
       funcs[i] = GetFunc(ctx, attr.default_value().list().func(i), status);
     }
@@ -1399,10 +1472,39 @@ PyObject* GetPythonObjectFromString(const char* s) {
 #endif
 }
 
+PyObject* GetPythonObjectFromInt(int num) {
+#if PY_MAJOR_VERSION >= 3
+  return PyLong_FromLong(num);
+#else
+  return PyInt_FromLong(num);
+#endif
+}
+
 bool CheckResourceVariable(PyObject* item) {
   return PyObject_TypeCheck(item, resource_variable_type);
 }
 
+bool IsNumberType(PyObject* item) {
+#if PY_MAJOR_VERSION >= 3
+  return PyFloat_Check(item) || PyLong_Check(item);
+#else
+  return PyFloat_Check(item) || PyInt_Check(item) || PyLong_Check(item);
+#endif
+}
+
+bool CheckOneInput(PyObject* item) {
+  if (EagerTensor_CheckExact(item) || CheckResourceVariable(item) ||
+      PyArray_Check(item) || IsNumberType(item)) {
+    return true;
+  }
+
+  // Sequences are not properly handled. Sequences with purely python numeric
+  // types work, but sequences with mixes of EagerTensors and python numeric
+  // types don't work.
+  // TODO(nareshmodi): fix
+  return false;
+}
+
 bool CheckInputsOk(PyObject* seq, int start_index,
                    const tensorflow::OpDef& op_def) {
   for (int i = 0; i < op_def.input_arg_size(); i++) {
@@ -1419,8 +1521,7 @@ bool CheckInputsOk(PyObject* seq, int start_index,
       }
       for (Py_ssize_t j = 0; j < PySequence_Fast_GET_SIZE(item); j++) {
         PyObject* inner_item = PySequence_Fast_GET_ITEM(item, j);
-        if (!EagerTensor_CheckExact(inner_item) &&
-            !CheckResourceVariable(inner_item)) {
+        if (!CheckOneInput(inner_item)) {
           VLOG(1)
               << "Falling back to slow path for Op \"" << op_def.name()
               << "\", Input \"" << op_def.input_arg(i).name() << "\", Index "
@@ -1430,7 +1531,7 @@ bool CheckInputsOk(PyObject* seq, int start_index,
           return false;
         }
       }
-    } else if (!EagerTensor_CheckExact(item) && !CheckResourceVariable(item)) {
+    } else if (!CheckOneInput(item)) {
       VLOG(1)
           << "Falling back to slow path for Op \"" << op_def.name()
           << "\", Input \"" << op_def.input_arg(i).name()
@@ -1443,6 +1544,52 @@ bool CheckInputsOk(PyObject* seq, int start_index,
   return true;
 }
 
+PyObject* MaybeGetDType(PyObject* item) {
+  if (EagerTensor_CheckExact(item)) {
+    tensorflow::Safe_PyObjectPtr py_dtype(
+        PyObject_GetAttrString(item, "dtype"));
+    return PyObject_GetAttrString(py_dtype.get(), "_type_enum");
+  }
+
+  if (CheckResourceVariable(item)) {
+    tensorflow::Safe_PyObjectPtr py_dtype(
+        PyObject_GetAttrString(item, "_dtype"));
+    return PyObject_GetAttrString(py_dtype.get(), "_type_enum");
+  }
+
+  return nullptr;
+}
+
+PyObject* MaybeGetDTypeForAttr(const string& attr,
+                               FastPathOpExecInfo* op_exec_info) {
+  auto cached_it = op_exec_info->cached_dtypes.find(attr);
+  if (cached_it != op_exec_info->cached_dtypes.end()) {
+    return GetPythonObjectFromInt(cached_it->second);
+  }
+
+  auto it = op_exec_info->attr_to_inputs_map->find(attr);
+  if (it == op_exec_info->attr_to_inputs_map->end()) {
+    // No other inputs - this should never happen.
+    Py_RETURN_NONE;
+  }
+
+  for (const auto& input_info : it->second) {
+    PyObject* item = PyTuple_GET_ITEM(
+        op_exec_info->args, kFastPathExecuteInputStartIndex + input_info.i);
+    if (input_info.is_list) {
+      for (int i = 0; i < PySequence_Fast_GET_SIZE(item); i++) {
+        auto* dtype = MaybeGetDType(PySequence_Fast_GET_ITEM(item, i));
+        if (dtype != nullptr) return dtype;
+      }
+    } else {
+      auto* dtype = MaybeGetDType(item);
+      if (dtype != nullptr) return dtype;
+    }
+  }
+
+  Py_RETURN_NONE;
+}
+
 bool OpDoesntRequireOutput(const string& op_name) {
   static tensorflow::gtl::FlatSet<string>* ops_that_dont_require_outputs =
       new tensorflow::gtl::FlatSet<string>({
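With CheckOneInput, the fast path now accepts numpy arrays and plain Python numbers as op inputs, not only EagerTensors and resource variables. A hedged sketch of a call that previously fell back to the slow path (values are illustrative):

    import numpy as np
    import tensorflow as tf

    tf.enable_eager_execution()

    # The numpy array input is now converted inline by the fast-path
    # executor instead of forcing the slow fallback.
    m = tf.matmul(np.eye(2, dtype=np.float32), tf.constant([[1.0], [2.0]]))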
@@ -1668,23 +1815,80 @@ bool ReadVariableOp(const FastPathOpExecInfo& parent_op_exec_info,
 // i) input is an EagerTensor
 // ii) input is a ResourceVariable - in this case, the is_variable param is set
 // to true.
-bool ConvertToTensor(const FastPathOpExecInfo& op_exec_info, PyObject* input,
-                     tensorflow::Safe_PyObjectPtr* output_handle,
-                     TF_Status* status) {
-  if (CheckResourceVariable(input)) {
+//
+// NOTE: dtype_hint_getter must *always* return a PyObject that can be
+// decref'd. So if no hint is found, Py_RETURN_NONE (which correctly
+// increfs Py_None).
+bool ConvertToTensor(
+    const FastPathOpExecInfo& op_exec_info, PyObject* input,
+    tensorflow::Safe_PyObjectPtr* output_handle,
+    // This gets a hint for this particular input.
+    const std::function<PyObject*()>& dtype_hint_getter,
+    // This sets the dtype after conversion is complete.
+    const std::function<void(const TF_DataType& dtype)>& dtype_setter,
+    TF_Status* status) {
+  if (EagerTensor_CheckExact(input)) {
+    Py_INCREF(input);
+    output_handle->reset(input);
+    return true;
+  } else if (CheckResourceVariable(input)) {
     return ReadVariableOp(op_exec_info, input, output_handle, status);
   }
 
-  Py_INCREF(input);
-  output_handle->reset(input);
+  // The hint comes from a supposedly similarly typed tensor.
+  tensorflow::Safe_PyObjectPtr dtype_hint(dtype_hint_getter());
+  if (PyErr_Occurred()) {
+    return false;
+  }
+
+  tensorflow::Safe_TFE_TensorHandlePtr handle =
+      tensorflow::make_safe(static_cast<TFE_TensorHandle*>(
+          tensorflow::ConvertToEagerTensor(input, dtype_hint.get())));
+  if (handle == nullptr) {
+    status->status = tensorflow::errors::InvalidArgument(
+        "Unable to convert value to tensor");
+    return false;
+  }
+
+  int desired_dtype = -1;
+  if (dtype_hint.get() != Py_None) {
+    if (!ParseTypeValue("", dtype_hint.get(), status, &desired_dtype)) {
+      status->status = tensorflow::errors::InvalidArgument(
+          "Expecting a DataType value for dtype. Got ",
+          Py_TYPE(dtype_hint.get())->tp_name);
+    }
+  }
+
+  TF_DataType handle_dtype = TFE_TensorHandleDataType(handle.get());
+  if (desired_dtype >= 0 && desired_dtype != handle_dtype) {
+    handle = tensorflow::make_safe(
+        tensorflow::EagerCast(op_exec_info.ctx, handle.get(), handle_dtype,
+                              static_cast<TF_DataType>(desired_dtype), status));
+    if (!status->status.ok()) return false;
+
+    handle_dtype = TFE_TensorHandleDataType(handle.get());
+  }
+
+  if (handle_dtype != TF_INT32) {
+    // Note that this is a shallow copy and will share the underlying buffer
+    // if copying to the same device.
+    handle = tensorflow::make_safe(TFE_TensorHandleCopyToDevice(
+        handle.get(), op_exec_info.ctx, op_exec_info.device_name, status));
+    if (!status->status.ok()) return false;
+  }
+
+  output_handle->reset(EagerTensorFromHandle(handle.release()));
+
+  dtype_setter(handle_dtype);
+
   return true;
 }
 
 // Adds input and type attr to the op, and to the list of flattened
 // inputs/attrs.
-bool AddInputToOp(const FastPathOpExecInfo& op_exec_info, PyObject* input,
-                  const tensorflow::OpDef::ArgDef* input_arg,
+bool AddInputToOp(FastPathOpExecInfo* op_exec_info, PyObject* input,
+                  const bool add_type_attr,
+                  const tensorflow::OpDef::ArgDef& input_arg,
                   std::vector<tensorflow::Safe_PyObjectPtr>* flattened_attrs,
                   std::vector<tensorflow::Safe_PyObjectPtr>* flattened_inputs,
                   TFE_Op* op, TF_Status* status) {
@@ -1693,18 +1897,30 @@ bool AddInputToOp(const FastPathOpExecInfo& op_exec_info, PyObject* input,
   // out of scope in this function.
   tensorflow::Safe_PyObjectPtr py_eager_tensor = nullptr;
 
-  if (!ConvertToTensor(op_exec_info, input, &py_eager_tensor, status)) {
+  if (!ConvertToTensor(
+          *op_exec_info, input, &py_eager_tensor,
+          [&]() {
+            if (input_arg.type() != tensorflow::DataType::DT_INVALID) {
+              return GetPythonObjectFromInt(input_arg.type());
+            }
+            return MaybeGetDTypeForAttr(input_arg.type_attr(), op_exec_info);
+          },
+          [&](const TF_DataType dtype) {
+            op_exec_info->cached_dtypes[input_arg.type_attr()] =
+                static_cast<tensorflow::DataType>(dtype);
+          },
+          status)) {
     return false;
   }
 
   TFE_TensorHandle* input_handle = EagerTensor_Handle(py_eager_tensor.get());
 
-  if (input_arg != nullptr && !input_arg->type_attr().empty()) {
+  if (add_type_attr && !input_arg.type_attr().empty()) {
     auto dtype = TFE_TensorHandleDataType(input_handle);
-    TFE_OpSetAttrType(op, input_arg->type_attr().data(), dtype);
+    TFE_OpSetAttrType(op, input_arg.type_attr().data(), dtype);
     if (flattened_attrs != nullptr) {
       flattened_attrs->emplace_back(
-          GetPythonObjectFromString(input_arg->type_attr().data()));
+          GetPythonObjectFromString(input_arg.type_attr().data()));
       flattened_attrs->emplace_back(PyLong_FromLong(dtype));
     }
   }
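The dtype_hint_getter lambda is what implements the "type inference from a different input tensor" described in the commit message. The practical effect, seen from Python (a hedged sketch; exact op and values are illustrative): a plain number passed alongside a tensor now picks up that tensor's dtype on the fast path, mirroring args_to_matching_eager on the slow path.

    import numpy as np
    import tensorflow as tf

    tf.enable_eager_execution()

    x = tf.constant(np.array([1, 2], dtype=np.int64))
    # The scalar 3 is converted on the fast path; its type attr ("T") is
    # inferred from `x`, avoiding an int32-vs-int64 mismatch.
    y = x + 3
    print(y.dtype)  # tf.int64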
@@ -1844,6 +2060,7 @@ PyObject* TFE_Py_FastPathExecute_C(PyObject*, PyObject* args) {
 
   op_exec_info.ctx = reinterpret_cast<TFE_Context*>(
       PyCapsule_GetPointer(PyTuple_GET_ITEM(args, 0), nullptr));
+  op_exec_info.args = args;
 
   if (op_exec_info.ctx == nullptr) {
     // The context hasn't been initialized. It will be in the slow path.
@@ -1892,6 +2109,8 @@ PyObject* TFE_Py_FastPathExecute_C(PyObject*, PyObject* args) {
     return nullptr;
   }
 
+  op_exec_info.attr_to_inputs_map = GetAttrToInputsMap(*op_def);
+
   TF_Status* status = TF_NewStatus();
   TFE_Op* op = TFE_NewOp(op_exec_info.ctx, op_def->name().c_str(), status);
   auto cleaner = tensorflow::gtl::MakeCleanup([status, op] {
@@ -1986,17 +2205,16 @@ PyObject* TFE_Py_FastPathExecute_C(PyObject*, PyObject* args) {
 
       if (len > 0) {
         // First item adds the type attr.
-        if (!AddInputToOp(op_exec_info, PySequence_Fast_GET_ITEM(input, 0),
-                          &input_arg, flattened_attrs.get(),
+        if (!AddInputToOp(&op_exec_info, PySequence_Fast_GET_ITEM(input, 0),
+                          true, input_arg, flattened_attrs.get(),
                           flattened_inputs.get(), op, status)) {
           return nullptr;
         }
 
        for (Py_ssize_t j = 1; j < len; j++) {
          // Since the list is homogeneous, we don't need to re-add the attr.
-          if (!AddInputToOp(op_exec_info, PySequence_Fast_GET_ITEM(input, j),
-                            nullptr /* input_arg */,
-                            nullptr /* flattened_attrs */,
+          if (!AddInputToOp(&op_exec_info, PySequence_Fast_GET_ITEM(input, j),
+                            false, input_arg, nullptr /* flattened_attrs */,
                            flattened_inputs.get(), op, status)) {
            return nullptr;
          }
@@ -2018,7 +2236,8 @@ PyObject* TFE_Py_FastPathExecute_C(PyObject*, PyObject* args) {
         PyObject* py_input = PySequence_Fast_GET_ITEM(input, j);
         tensorflow::Safe_PyObjectPtr py_eager_tensor;
         if (!ConvertToTensor(op_exec_info, py_input, &py_eager_tensor,
-                             status)) {
+                             []() { Py_RETURN_NONE; },
+                             [](const TF_DataType& dtype) {}, status)) {
           return nullptr;
         }
 
@@ -2048,8 +2267,9 @@ PyObject* TFE_Py_FastPathExecute_C(PyObject*, PyObject* args) {
       attr_list_sizes[attr_name] = len;
     } else {
       // The item is a single item.
-      if (!AddInputToOp(op_exec_info, input, &input_arg, flattened_attrs.get(),
-                        flattened_inputs.get(), op, status)) {
+      if (!AddInputToOp(&op_exec_info, input, true, input_arg,
+                        flattened_attrs.get(), flattened_inputs.get(), op,
+                        status)) {
         return nullptr;
       }
     }
@@ -278,14 +278,9 @@ class TFETensorUtilTest(test_util.TensorFlowTestCase):
 
     with self.assertRaisesRegexp(
         TypeError,
-        r"tensor_list argument must be a list. Got \"EagerTensor\""):
+        r"tensors argument must be a list or a tuple. Got \"EagerTensor\""):
       pywrap_tensorflow.TFE_Py_TensorShapeSlice(t1, -2)
 
-    with self.assertRaisesRegexp(
-        TypeError,
-        r"tensor_list argument must be a list. Got \"tuple\""):
-      pywrap_tensorflow.TFE_Py_TensorShapeSlice((t1,), -2)
-
   def testNegativeSliceDim(self):
     t1 = _create_tensor([1, 2], dtype=dtypes.int32)
 
@@ -1385,6 +1385,22 @@ def register_tensor_conversion_function(base_type,
   if not callable(conversion_func):
     raise TypeError("conversion_func must be callable.")
 
+  # context._context is checked so that we don't inadvertently create it.
+  # This is because enable_eager_execution will fail when called from the main
+  # function if the context._context is already created, and the
+  # register_tensor_conversion_function calls happen when the module is
+  # imported.
+  if context._context is not None and context.executing_eagerly(
+  ) and isinstance(base_type, six.integer_types + (
+      float,
+      np.ndarray,
+  )):
+    # TODO(nareshmodi): consider setting a context variable which disables the
+    # fastpath instead.
+    raise TypeError(
+        "Cannot register conversions for numpy arrays, python number types "
+        "when executing eagerly.")
+
   try:
     funcs_at_priority = _tensor_conversion_func_registry[priority]
   except KeyError:
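This guard means conversion functions can no longer be registered for numpy arrays or Python number types once eager execution is active, since the fast path above would silently bypass them. A hedged sketch of the new failure mode (the conversion function is hypothetical):

    import numpy as np
    from tensorflow.python.framework import ops

    # With eager execution enabled, this now raises TypeError instead of
    # registering a function the fast path would never call.
    ops.register_tensor_conversion_function(
        np.ndarray, lambda value, dtype=None, name=None, as_ref=False: value)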