diff --git a/tensorflow/lite/python/BUILD b/tensorflow/lite/python/BUILD
index aa0ffe90e1a..02b8b80be90 100644
--- a/tensorflow/lite/python/BUILD
+++ b/tensorflow/lite/python/BUILD
@@ -13,7 +13,6 @@ py_library(
     visibility = ["//visibility:public"],
     deps = [
         "//tensorflow/lite/python/interpreter_wrapper:tensorflow_wrap_interpreter_wrapper",
-        "//tensorflow/python:util",
         "//third_party/py/numpy",
     ],
 )
diff --git a/tensorflow/lite/python/interpreter_test.py b/tensorflow/lite/python/interpreter_test.py
index 7ec56a21c9f..b21779226f6 100644
--- a/tensorflow/lite/python/interpreter_test.py
+++ b/tensorflow/lite/python/interpreter_test.py
@@ -91,6 +91,41 @@ class InterpreterTest(test_util.TensorFlowTestCase):
     output_data = interpreter.get_tensor(output_details[0]['index'])
     self.assertTrue((expected_output == output_data).all())
 
+  def testString(self):
+    interpreter = interpreter_wrapper.Interpreter(
+        model_path=resource_loader.get_path_to_datafile(
+            'testdata/gather_string.tflite'))
+    interpreter.allocate_tensors()
+
+    input_details = interpreter.get_input_details()
+    self.assertEqual(2, len(input_details))
+    self.assertEqual('input', input_details[0]['name'])
+    self.assertEqual(np.string_, input_details[0]['dtype'])
+    self.assertTrue(([10] == input_details[0]['shape']).all())
+    self.assertEqual((0.0, 0), input_details[0]['quantization'])
+    self.assertEqual('indices', input_details[1]['name'])
+    self.assertEqual(np.int64, input_details[1]['dtype'])
+    self.assertTrue(([3] == input_details[1]['shape']).all())
+    self.assertEqual((0.0, 0), input_details[1]['quantization'])
+
+    output_details = interpreter.get_output_details()
+    self.assertEqual(1, len(output_details))
+    self.assertEqual('output', output_details[0]['name'])
+    self.assertEqual(np.string_, output_details[0]['dtype'])
+    self.assertTrue(([3] == output_details[0]['shape']).all())
+    self.assertEqual((0.0, 0), output_details[0]['quantization'])
+
+    test_input = np.array([1, 2, 3], dtype=np.int64)
+    interpreter.set_tensor(input_details[1]['index'], test_input)
+
+    test_input = np.array(['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j'])
+    expected_output = np.array([b'b', b'c', b'd'])
+    interpreter.set_tensor(input_details[0]['index'], test_input)
+    interpreter.invoke()
+
+    output_data = interpreter.get_tensor(output_details[0]['index'])
+    self.assertTrue((expected_output == output_data).all())
+
 
 class InterpreterTestErrorPropagation(test_util.TensorFlowTestCase):
 
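The testString case above is the end-to-end reference for the new string path. A minimal standalone sketch of the same flow outside the test harness follows; the direct `interpreter` module import and the literal model path are illustrative assumptions, and the same calls are available through the public `tf.lite.Interpreter` wrapper.

```python
import numpy as np
from tensorflow.lite.python import interpreter as interpreter_wrapper

# Hypothetical path; any .tflite model with a string input/output would do.
interpreter = interpreter_wrapper.Interpreter(model_path='gather_string.tflite')
interpreter.allocate_tensors()

input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# String tensors are fed as numpy arrays of bytes/str values; SetTensor
# serializes them into the TFLite string format via DynamicBuffer.
interpreter.set_tensor(input_details[0]['index'],
                       np.array([b'a', b'b', b'c', b'd', b'e',
                                 b'f', b'g', b'h', b'i', b'j']))
interpreter.set_tensor(input_details[1]['index'],
                       np.array([1, 2, 3], dtype=np.int64))
interpreter.invoke()

# GetTensor returns a string tensor as an object array of bytes.
print(interpreter.get_tensor(output_details[0]['index']))  # [b'b' b'c' b'd']
```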
"//third_party/py/numpy:headers", "//third_party/python_runtime:headers", ], diff --git a/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.cc b/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.cc index 9ccaabbfe97..41cebf867ec 100644 --- a/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.cc +++ b/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.cc @@ -21,16 +21,10 @@ limitations under the License. #include "tensorflow/lite/interpreter.h" #include "tensorflow/lite/kernels/register.h" #include "tensorflow/lite/model.h" +#include "tensorflow/lite/python/interpreter_wrapper/numpy.h" #include "tensorflow/lite/python/interpreter_wrapper/python_error_reporter.h" #include "tensorflow/lite/python/interpreter_wrapper/python_utils.h" - -// Disallow Numpy 1.7 deprecated symbols. -#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION - -#include <Python.h> - -#include "numpy/arrayobject.h" -#include "numpy/ufuncobject.h" +#include "tensorflow/lite/string_util.h" #if PY_MAJOR_VERSION >= 3 #define PY_TO_CPPSTRING PyBytes_AsStringAndSize @@ -64,12 +58,6 @@ namespace interpreter_wrapper { namespace { -// Calls PyArray's initialization to initialize all the API pointers. Note that -// this usage implies only this translation unit can use the pointers. See -// tensorflow/python/core/numpy.cc for a strategy if we ever need to extend -// this further. -void ImportNumpy() { import_array1(); } - std::unique_ptr<tflite::Interpreter> CreateInterpreter( const tflite::FlatBufferModel* model, const tflite::ops::builtin::BuiltinOpResolver& resolver) { @@ -77,7 +65,7 @@ std::unique_ptr<tflite::Interpreter> CreateInterpreter( return nullptr; } - ImportNumpy(); + ::tflite::python::ImportNumpy(); std::unique_ptr<tflite::Interpreter> interpreter; if (tflite::InterpreterBuilder(*model, resolver)(&interpreter) != kTfLiteOk) { @@ -267,7 +255,7 @@ PyObject* InterpreterWrapper::SetTensor(int i, PyObject* value) { } PyArrayObject* array = reinterpret_cast<PyArrayObject*>(array_safe.get()); - const TfLiteTensor* tensor = interpreter_->tensor(i); + TfLiteTensor* tensor = interpreter_->tensor(i); if (python_utils::TfLiteTypeFromPyArray(array) != tensor->type) { PyErr_Format(PyExc_ValueError, @@ -279,26 +267,41 @@ PyObject* InterpreterWrapper::SetTensor(int i, PyObject* value) { } if (PyArray_NDIM(array) != tensor->dims->size) { - PyErr_SetString(PyExc_ValueError, "Cannot set tensor: Dimension mismatch"); + PyErr_Format(PyExc_ValueError, + "Cannot set tensor: Dimension mismatch." + " Got %d" + " but expected %d for input %d.", + PyArray_NDIM(array), tensor->dims->size, i); return nullptr; } for (int j = 0; j < PyArray_NDIM(array); j++) { if (tensor->dims->data[j] != PyArray_SHAPE(array)[j]) { - PyErr_SetString(PyExc_ValueError, - "Cannot set tensor: Dimension mismatch"); + PyErr_Format(PyExc_ValueError, + "Cannot set tensor: Dimension mismatch." 
+ " Got %ld" + " but expected %d for dimension %d of input %d.", + PyArray_SHAPE(array)[j], tensor->dims->data[j], j, i); return nullptr; } } - size_t size = PyArray_NBYTES(array); - if (size != tensor->bytes) { - PyErr_Format(PyExc_ValueError, - "numpy array had %zu bytes but expected %zu bytes.", size, - tensor->bytes); - return nullptr; + if (tensor->type != kTfLiteString) { + size_t size = PyArray_NBYTES(array); + if (size != tensor->bytes) { + PyErr_Format(PyExc_ValueError, + "numpy array had %zu bytes but expected %zu bytes.", size, + tensor->bytes); + return nullptr; + } + memcpy(tensor->data.raw, PyArray_DATA(array), size); + } else { + DynamicBuffer dynamic_buffer; + if (!python_utils::FillStringBufferWithPyArray(value, &dynamic_buffer)) { + return nullptr; + } + dynamic_buffer.WriteToTensor(tensor, nullptr); } - memcpy(tensor->data.raw, PyArray_DATA(array), size); Py_RETURN_NONE; } @@ -345,19 +348,51 @@ PyObject* InterpreterWrapper::GetTensor(int i) const { std::vector<npy_intp> dims(tensor->dims->data, tensor->dims->data + tensor->dims->size); - // Make a buffer copy but we must tell Numpy It owns that data or else - // it will leak. - void* data = malloc(tensor->bytes); - if (!data) { - PyErr_SetString(PyExc_ValueError, "Malloc to copy tensor failed."); - return nullptr; + if (tensor->type != kTfLiteString) { + // Make a buffer copy but we must tell Numpy It owns that data or else + // it will leak. + void* data = malloc(tensor->bytes); + if (!data) { + PyErr_SetString(PyExc_ValueError, "Malloc to copy tensor failed."); + return nullptr; + } + memcpy(data, tensor->data.raw, tensor->bytes); + PyObject* np_array = + PyArray_SimpleNewFromData(dims.size(), dims.data(), type_num, data); + PyArray_ENABLEFLAGS(reinterpret_cast<PyArrayObject*>(np_array), + NPY_ARRAY_OWNDATA); + return PyArray_Return(reinterpret_cast<PyArrayObject*>(np_array)); + } else { + // Create a C-order array so the data is contiguous in memory. + const int32_t kCOrder = 0; + PyObject* py_object = + PyArray_EMPTY(dims.size(), dims.data(), NPY_OBJECT, kCOrder); + + if (py_object == nullptr) { + PyErr_SetString(PyExc_MemoryError, "Failed to allocate PyArray."); + return nullptr; + } + + PyArrayObject* py_array = reinterpret_cast<PyArrayObject*>(py_object); + PyObject** data = reinterpret_cast<PyObject**>(PyArray_DATA(py_array)); + auto num_strings = GetStringCount(tensor->data.raw); + for (int j = 0; j < num_strings; ++j) { + auto ref = GetString(tensor->data.raw, j); + + PyObject* bytes = PyBytes_FromStringAndSize(ref.str, ref.len); + if (bytes == nullptr) { + Py_DECREF(py_object); + PyErr_Format(PyExc_ValueError, + "Could not create PyBytes from string %d of input %d.", j, + i); + return nullptr; + } + // PyArray_EMPTY produces an array full of Py_None, which we must decref. + Py_DECREF(data[j]); + data[j] = bytes; + } + return py_object; } - memcpy(data, tensor->data.raw, tensor->bytes); - PyObject* np_array = - PyArray_SimpleNewFromData(dims.size(), dims.data(), type_num, data); - PyArray_ENABLEFLAGS(reinterpret_cast<PyArrayObject*>(np_array), - NPY_ARRAY_OWNDATA); - return PyArray_Return(reinterpret_cast<PyArrayObject*>(np_array)); } PyObject* InterpreterWrapper::tensor(PyObject* base_object, int i) { diff --git a/tensorflow/lite/python/interpreter_wrapper/numpy.cc b/tensorflow/lite/python/interpreter_wrapper/numpy.cc new file mode 100644 index 00000000000..ff5403d2a60 --- /dev/null +++ b/tensorflow/lite/python/interpreter_wrapper/numpy.cc @@ -0,0 +1,25 @@ +/* Copyright 2019 The TensorFlow Authors. 
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#define TFLITE_IMPORT_NUMPY  // See numpy.h for explanation.
+#include "tensorflow/lite/python/interpreter_wrapper/numpy.h"
+
+namespace tflite {
+namespace python {
+
+void ImportNumpy() { import_array1(); }
+
+}  // namespace python
+}  // namespace tflite
diff --git a/tensorflow/lite/python/interpreter_wrapper/numpy.h b/tensorflow/lite/python/interpreter_wrapper/numpy.h
new file mode 100644
index 00000000000..a3b013fcb27
--- /dev/null
+++ b/tensorflow/lite/python/interpreter_wrapper/numpy.h
@@ -0,0 +1,62 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_PYTHON_INTERPRETER_WRAPPER_NUMPY_H_
+#define TENSORFLOW_LITE_PYTHON_INTERPRETER_WRAPPER_NUMPY_H_
+
+#ifdef PyArray_Type
+#error "Numpy cannot be included before numpy.h."
+#endif
+
+// Disallow Numpy 1.7 deprecated symbols.
+#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
+
+// To handle PyArray_* calls, numpy defines a static lookup table called
+// PyArray_API, or PY_ARRAY_UNIQUE_SYMBOL, if defined. This causes the
+// PyArray_* pointers to be different for different translation units, unless
+// we take care to selectively define NO_IMPORT_ARRAY.
+//
+// Virtually every usage will define NO_IMPORT_ARRAY, and will have access to
+// the lookup table via:
+//   extern void **PyArray_API;
+// In numpy.cc we will define TFLITE_IMPORT_NUMPY, effectively disabling that
+// and instead using:
+//   void **PyArray_API;
+// which is initialized when ImportNumpy() is called.
+//
+// If we don't define PY_ARRAY_UNIQUE_SYMBOL then PyArray_API is a static
+// variable, which causes strange crashes when the pointers are used across
+// translation unit boundaries.
+//
+// For more info see https://sourceforge.net/p/numpy/mailman/message/5700519
+// See also tensorflow/python/lib/core/numpy.h for a similar approach.
+#define PY_ARRAY_UNIQUE_SYMBOL _tensorflow_numpy_api
+#ifndef TFLITE_IMPORT_NUMPY
+#define NO_IMPORT_ARRAY
+#endif
+
+#include <Python.h>
+
+#include "numpy/arrayobject.h"
+#include "numpy/ufuncobject.h"
+
+namespace tflite {
+namespace python {
+
+void ImportNumpy();
+
+}  // namespace python
+}  // namespace tflite
+
+#endif  // TENSORFLOW_LITE_PYTHON_INTERPRETER_WRAPPER_NUMPY_H_
diff --git a/tensorflow/lite/python/interpreter_wrapper/python_utils.cc b/tensorflow/lite/python/interpreter_wrapper/python_utils.cc
index 2dc604356ab..a052ca320f9 100644
--- a/tensorflow/lite/python/interpreter_wrapper/python_utils.cc
+++ b/tensorflow/lite/python/interpreter_wrapper/python_utils.cc
@@ -15,9 +15,19 @@ limitations under the License.
 
 #include "tensorflow/lite/python/interpreter_wrapper/python_utils.h"
 
+#include <memory>
+
+#include "tensorflow/lite/python/interpreter_wrapper/numpy.h"
+
 namespace tflite {
 namespace python_utils {
 
+struct PyObjectDereferencer {
+  void operator()(PyObject* py_object) const { Py_DECREF(py_object); }
+};
+
+using UniquePyObjectRef = std::unique_ptr<PyObject, PyObjectDereferencer>;
+
 int TfLiteTypeToPyArrayType(TfLiteType tf_lite_type) {
   switch (tf_lite_type) {
     case kTfLiteFloat32:
@@ -33,7 +43,7 @@ int TfLiteTypeToPyArrayType(TfLiteType tf_lite_type) {
     case kTfLiteInt64:
       return NPY_INT64;
     case kTfLiteString:
-      return NPY_OBJECT;
+      return NPY_STRING;
     case kTfLiteBool:
       return NPY_BOOL;
     case kTfLiteComplex64:
@@ -73,5 +83,82 @@ TfLiteType TfLiteTypeFromPyArray(PyArrayObject* array) {
   return kTfLiteNoType;
 }
 
+#if PY_VERSION_HEX >= 0x03030000
+bool FillStringBufferFromPyUnicode(PyObject* value,
+                                   DynamicBuffer* dynamic_buffer) {
+  Py_ssize_t len = -1;
+  char* buf = PyUnicode_AsUTF8AndSize(value, &len);
+  if (buf == NULL) {
+    PyErr_SetString(PyExc_ValueError, "PyUnicode_AsUTF8AndSize() failed.");
+    return false;
+  }
+  dynamic_buffer->AddString(buf, len);
+  return true;
+}
+#else
+bool FillStringBufferFromPyUnicode(PyObject* value,
+                                   DynamicBuffer* dynamic_buffer) {
+  UniquePyObjectRef utemp(PyUnicode_AsUTF8String(value));
+  if (!utemp) {
+    PyErr_SetString(PyExc_ValueError, "PyUnicode_AsUTF8String() failed.");
+    return false;
+  }
+  char* buf = nullptr;
+  Py_ssize_t len = -1;
+  if (PyBytes_AsStringAndSize(utemp.get(), &buf, &len) == -1) {
+    PyErr_SetString(PyExc_ValueError, "PyBytes_AsStringAndSize() failed.");
+    return false;
+  }
+  dynamic_buffer->AddString(buf, len);
+  return true;
+}
+#endif
+
+bool FillStringBufferFromPyString(PyObject* value,
+                                  DynamicBuffer* dynamic_buffer) {
+  if (PyUnicode_Check(value)) {
+    return FillStringBufferFromPyUnicode(value, dynamic_buffer);
+  }
+
+  char* buf = nullptr;
+  Py_ssize_t len = -1;
+  if (PyBytes_AsStringAndSize(value, &buf, &len) == -1) {
+    PyErr_SetString(PyExc_ValueError, "PyBytes_AsStringAndSize() failed.");
+    return false;
+  }
+  dynamic_buffer->AddString(buf, len);
+  return true;
+}
+
+bool FillStringBufferWithPyArray(PyObject* value,
+                                 DynamicBuffer* dynamic_buffer) {
+  PyArrayObject* array = reinterpret_cast<PyArrayObject*>(value);
+  switch (PyArray_TYPE(array)) {
+    case NPY_OBJECT:
+    case NPY_STRING:
+    case NPY_UNICODE: {
+      UniquePyObjectRef iter(PyArray_IterNew(value));
+      while (PyArray_ITER_NOTDONE(iter.get())) {
+        UniquePyObjectRef item(PyArray_GETITEM(
+            array, reinterpret_cast<char*>(PyArray_ITER_DATA(iter.get()))));
+
+        if (!FillStringBufferFromPyString(item.get(), dynamic_buffer)) {
+          return false;
+        }
+
+        PyArray_ITER_NEXT(iter.get());
+      }
+      return true;
+    }
+    default:
+      break;
+  }
+
+  PyErr_Format(PyExc_ValueError,
+               "Cannot use numpy array of type %d for string tensor.",
+               PyArray_TYPE(array));
+  return false;
+}
+
 }  // namespace python_utils
 }  // namespace tflite
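FillStringBufferWithPyArray above accepts NPY_OBJECT, NPY_STRING, and NPY_UNICODE arrays, so a string tensor can be fed from several numpy dtypes on the Python side. A small illustrative sketch follows; the helper name and its arguments are hypothetical, and whether a given dtype reaches the copy at all also depends on TfLiteTypeFromPyArray, which this diff does not change.

```python
import numpy as np

def feed_string_input(interpreter, input_index):
  """Feeds the same values using the dtypes handled by FillStringBufferWithPyArray."""
  for test_input in (
      np.array([b'a', b'b', b'c'], dtype=object),  # NPY_OBJECT
      np.array([b'a', b'b', b'c']),                # NPY_STRING ('S1')
      np.array([u'a', u'b', u'c']),                # NPY_UNICODE; UTF-8 encoded on copy
  ):
    interpreter.set_tensor(input_index, test_input)
    interpreter.invoke()
```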
"Cannot use numpy array of type %d for string tensor.", + PyArray_TYPE(array)); + return false; +} + } // namespace python_utils } // namespace tflite diff --git a/tensorflow/lite/python/interpreter_wrapper/python_utils.h b/tensorflow/lite/python/interpreter_wrapper/python_utils.h index 30a44226b8f..5ffd231a892 100644 --- a/tensorflow/lite/python/interpreter_wrapper/python_utils.h +++ b/tensorflow/lite/python/interpreter_wrapper/python_utils.h @@ -17,14 +17,8 @@ limitations under the License. #define TENSORFLOW_LITE_PYTHON_INTERPRETER_WRAPPER_PYTHON_UTILS_H_ #include "tensorflow/lite/context.h" - -// Disallow Numpy 1.7 deprecated symbols. -#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION - -#include <Python.h> - -#include "numpy/arrayobject.h" -#include "numpy/ufuncobject.h" +#include "tensorflow/lite/python/interpreter_wrapper/numpy.h" +#include "tensorflow/lite/string_util.h" namespace tflite { namespace python_utils { @@ -33,6 +27,9 @@ int TfLiteTypeToPyArrayType(TfLiteType tf_lite_type); TfLiteType TfLiteTypeFromPyArray(PyArrayObject* array); +bool FillStringBufferWithPyArray(PyObject* value, + DynamicBuffer* dynamic_buffer); + } // namespace python_utils } // namespace tflite #endif // TENSORFLOW_LITE_PYTHON_INTERPRETER_WRAPPER_PYTHON_UTILS_H_ diff --git a/tensorflow/lite/python/lite_test.py b/tensorflow/lite/python/lite_test.py index ff91d5e2970..ca6c5b8f13f 100644 --- a/tensorflow/lite/python/lite_test.py +++ b/tensorflow/lite/python/lite_test.py @@ -131,13 +131,13 @@ class FromSessionTest(test_util.TensorFlowTestCase): input_details = interpreter.get_input_details() self.assertEqual(1, len(input_details)) self.assertEqual('Placeholder', input_details[0]['name']) - self.assertEqual(np.object_, input_details[0]['dtype']) + self.assertEqual(np.string_, input_details[0]['dtype']) self.assertTrue(([4] == input_details[0]['shape']).all()) output_details = interpreter.get_output_details() self.assertEqual(1, len(output_details)) self.assertEqual('Reshape', output_details[0]['name']) - self.assertEqual(np.object_, output_details[0]['dtype']) + self.assertEqual(np.string_, output_details[0]['dtype']) self.assertTrue(([2, 2] == output_details[0]['shape']).all()) # TODO(b/122659643): Test setting/getting string data via the python # interpreter API after support has been added. diff --git a/tensorflow/lite/python/optimize/BUILD b/tensorflow/lite/python/optimize/BUILD index 74b573b81f3..8694ebf1f54 100644 --- a/tensorflow/lite/python/optimize/BUILD +++ b/tensorflow/lite/python/optimize/BUILD @@ -13,6 +13,7 @@ cc_library( deps = [ "//tensorflow/lite:framework", "//tensorflow/lite/kernels:builtin_ops", + "//tensorflow/lite/python/interpreter_wrapper:numpy", "//tensorflow/lite/python/interpreter_wrapper:python_error_reporter", "//tensorflow/lite/python/interpreter_wrapper:python_utils", "//tensorflow/lite/tools/optimize:calibration_reader", diff --git a/tensorflow/lite/python/optimize/calibration_wrapper.cc b/tensorflow/lite/python/optimize/calibration_wrapper.cc index d6fe1fbdad7..21f96f848c5 100644 --- a/tensorflow/lite/python/optimize/calibration_wrapper.cc +++ b/tensorflow/lite/python/optimize/calibration_wrapper.cc @@ -22,20 +22,13 @@ limitations under the License. 
#include "tensorflow/lite/interpreter.h" #include "tensorflow/lite/kernels/register.h" #include "tensorflow/lite/model.h" +#include "tensorflow/lite/python/interpreter_wrapper/numpy.h" #include "tensorflow/lite/python/interpreter_wrapper/python_error_reporter.h" #include "tensorflow/lite/python/interpreter_wrapper/python_utils.h" #include "tensorflow/lite/tools/optimize/calibration_reader.h" #include "tensorflow/lite/tools/optimize/calibrator.h" #include "tensorflow/lite/tools/optimize/quantize_model.h" -// Disallow Numpy 1.7 deprecated symbols. -#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION - -#include <Python.h> - -#include "numpy/arrayobject.h" -#include "numpy/ufuncobject.h" - #if PY_MAJOR_VERSION >= 3 #define PY_TO_CPPSTRING PyBytes_AsStringAndSize #define CPP_TO_PYSTRING PyBytes_FromStringAndSize diff --git a/tensorflow/lite/python/testdata/BUILD b/tensorflow/lite/python/testdata/BUILD index 4689c318b35..2fa08e53269 100644 --- a/tensorflow/lite/python/testdata/BUILD +++ b/tensorflow/lite/python/testdata/BUILD @@ -32,9 +32,20 @@ tf_to_tflite( ], ) +tf_to_tflite( + name = "gather_string", + src = "gather.pbtxt", + out = "gather_string.tflite", + options = [ + "--input_arrays=input,indices", + "--output_arrays=output", + ], +) + filegroup( name = "interpreter_test_data", srcs = [ + ":gather_string", ":permute_float", ":permute_uint8", ], diff --git a/tensorflow/lite/python/testdata/gather.pbtxt b/tensorflow/lite/python/testdata/gather.pbtxt new file mode 100644 index 00000000000..0b1193c475d --- /dev/null +++ b/tensorflow/lite/python/testdata/gather.pbtxt @@ -0,0 +1,93 @@ +node { + name: "input" + op: "Placeholder" + device: "/device:CPU:0" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 10 + } + } + } + } +} +node { + name: "indices" + op: "Placeholder" + device: "/device:CPU:0" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 3 + } + } + } + } +} +node { + name: "axis" + op: "Const" + device: "/device:CPU:0" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 0 + } + } + } +} +node { + name: "output" + op: "GatherV2" + input: "input" + input: "indices" + input: "axis" + device: "/device:CPU:0" + attr { + key: "Taxis" + value { + type: DT_INT32 + } + } + attr { + key: "Tindices" + value { + type: DT_INT64 + } + } + attr { + key: "Tparams" + value { + type: DT_STRING + } + } +} +versions { + producer: 27 +}