Add string support to generate_examples.py
PiperOrigin-RevId: 233127983
This commit is contained in:
parent
046f1b74a0
commit
4d6fd23bcb
tensorflow/lite
python/interpreter_wrapper
testing
@ -160,5 +160,21 @@ bool FillStringBufferWithPyArray(PyObject* value,
|
||||
return false;
|
||||
}
|
||||
|
||||
int ConvertFromPyString(PyObject* obj, char** data, Py_ssize_t* length) {
|
||||
#if PY_MAJOR_VERSION >= 3
|
||||
return PyBytes_AsStringAndSize(obj, data, length);
|
||||
#else
|
||||
return PyString_AsStringAndSize(obj, data, length);
|
||||
#endif
|
||||
}
|
||||
|
||||
PyObject* ConvertToPyString(const char* data, size_t length) {
|
||||
#if PY_MAJOR_VERSION >= 3
|
||||
return PyBytes_FromStringAndSize(data, length);
|
||||
#else
|
||||
return PyString_FromStringAndSize(data, length);
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace python_utils
|
||||
} // namespace tflite
|
||||
|
@ -23,6 +23,10 @@ limitations under the License.
|
||||
namespace tflite {
|
||||
namespace python_utils {
|
||||
|
||||
struct PyDecrefDeleter {
|
||||
void operator()(PyObject* p) const { Py_DECREF(p); }
|
||||
};
|
||||
|
||||
int TfLiteTypeToPyArrayType(TfLiteType tf_lite_type);
|
||||
|
||||
TfLiteType TfLiteTypeFromPyArray(PyArrayObject* array);
|
||||
@ -30,6 +34,9 @@ TfLiteType TfLiteTypeFromPyArray(PyArrayObject* array);
|
||||
bool FillStringBufferWithPyArray(PyObject* value,
|
||||
DynamicBuffer* dynamic_buffer);
|
||||
|
||||
int ConvertFromPyString(PyObject* obj, char** data, Py_ssize_t* length);
|
||||
PyObject* ConvertToPyString(const char* data, size_t length);
|
||||
|
||||
} // namespace python_utils
|
||||
} // namespace tflite
|
||||
#endif // TENSORFLOW_LITE_PYTHON_INTERPRETER_WRAPPER_PYTHON_UTILS_H_
|
||||
|
@ -10,6 +10,7 @@ load(
|
||||
"generated_test_models_all",
|
||||
)
|
||||
load("//tensorflow/lite:special_rules.bzl", "tflite_portable_test_suite")
|
||||
load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc")
|
||||
load(
|
||||
"//tensorflow:tensorflow.bzl",
|
||||
"tf_cc_binary",
|
||||
@ -78,6 +79,7 @@ py_binary(
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":generate_examples_report",
|
||||
":string_util_wrapper",
|
||||
"//tensorflow:tensorflow_py",
|
||||
"//tensorflow/python:graph_util",
|
||||
"//third_party/py/numpy",
|
||||
@ -392,4 +394,29 @@ tf_cc_binary(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "string_util_lib",
|
||||
srcs = ["string_util.cc"],
|
||||
hdrs = ["string_util.h"],
|
||||
deps = [
|
||||
"//tensorflow/lite:string_util",
|
||||
"//tensorflow/lite/python/interpreter_wrapper:numpy",
|
||||
"//tensorflow/lite/python/interpreter_wrapper:python_utils",
|
||||
"//third_party/py/numpy:headers",
|
||||
"//third_party/python_runtime:headers",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
tf_py_wrap_cc(
|
||||
name = "string_util_wrapper",
|
||||
srcs = [
|
||||
"string_util.i",
|
||||
],
|
||||
deps = [
|
||||
":string_util_lib",
|
||||
"//third_party/python_runtime:headers",
|
||||
],
|
||||
)
|
||||
|
||||
tflite_portable_test_suite()
|
||||
|
@ -36,6 +36,7 @@ import operator
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import string
|
||||
import sys
|
||||
import tempfile
|
||||
import traceback
|
||||
@ -52,6 +53,7 @@ import tensorflow as tf
|
||||
from google.protobuf import text_format
|
||||
# TODO(aselle): switch to TensorFlow's resource_loader
|
||||
from tensorflow.lite.testing import generate_examples_report as report_lib
|
||||
from tensorflow.lite.testing import string_util_wrapper
|
||||
from tensorflow.python.framework import graph_util as tf_graph_util
|
||||
from tensorflow.python.ops import rnn
|
||||
|
||||
@ -163,6 +165,16 @@ def toco_options(data_types,
|
||||
return s
|
||||
|
||||
|
||||
def format_result(t):
|
||||
"""Convert a tensor to a format that can be used in test specs."""
|
||||
if np.issubdtype(t.dtype, np.number):
|
||||
# Output 9 digits after the point to ensure the precision is good enough.
|
||||
values = ["{:.9f}".format(value) for value in list(t.flatten())]
|
||||
return ",".join(values)
|
||||
else:
|
||||
return string_util_wrapper.SerializeAsHexString(t.flatten())
|
||||
|
||||
|
||||
def write_examples(fp, examples):
|
||||
"""Given a list `examples`, write a text format representation.
|
||||
|
||||
@ -179,9 +191,7 @@ def write_examples(fp, examples):
|
||||
"""Write tensor in file format supported by TFLITE example."""
|
||||
fp.write("dtype,%s\n" % x.dtype)
|
||||
fp.write("shape," + ",".join(map(str, x.shape)) + "\n")
|
||||
# Output 9 digits after the point to ensure the precision is good enough.
|
||||
values = ["{:.9f}".format(value) for value in list(x.flatten())]
|
||||
fp.write("values," + ",".join(values) + "\n")
|
||||
fp.write("values," + format_result(x) + "\n")
|
||||
|
||||
fp.write("test_cases,%d\n" % len(examples))
|
||||
for example in examples:
|
||||
@ -214,11 +224,9 @@ def write_test_cases(fp, model_name, examples):
|
||||
fp.write("invoke {\n")
|
||||
|
||||
for t in example["inputs"]:
|
||||
values = ["{:.9f}".format(value) for value in list(t.flatten())]
|
||||
fp.write(" input: \"" + ",".join(values) + "\"\n")
|
||||
fp.write(" input: \"" + format_result(t) + "\"\n")
|
||||
for t in example["outputs"]:
|
||||
values = ["{:.9f}".format(value) for value in list(t.flatten())]
|
||||
fp.write(" output: \"" + ",".join(values) + "\"\n")
|
||||
fp.write(" output: \"" + format_result(t) + "\"\n")
|
||||
fp.write("}\n")
|
||||
|
||||
|
||||
@ -230,6 +238,7 @@ _TF_TYPE_INFO = {
|
||||
tf.int16: (np.int16, "QUANTIZED_INT16"),
|
||||
tf.int64: (np.int64, "INT64"),
|
||||
tf.bool: (np.bool, "BOOL"),
|
||||
tf.string: (np.string_, "STRING"),
|
||||
}
|
||||
|
||||
|
||||
@ -245,6 +254,10 @@ def create_tensor_data(dtype, shape, min_value=-100, max_value=100):
|
||||
value = np.random.randint(min_value, max_value+1, shape)
|
||||
elif dtype == tf.bool:
|
||||
value = np.random.choice([True, False], size=shape)
|
||||
elif dtype == np.string_:
|
||||
# Not the best strings, but they will do for some basic testing.
|
||||
letters = list(string.ascii_uppercase)
|
||||
return np.random.choice(letters, size=shape).astype(dtype)
|
||||
return np.dtype(dtype).type(value) if np.isscalar(value) else value.astype(
|
||||
dtype)
|
||||
|
||||
@ -1294,16 +1307,25 @@ def make_squared_difference_tests(zip_path):
|
||||
def make_gather_tests(zip_path):
|
||||
"""Make a set of tests to do gather."""
|
||||
|
||||
test_parameters = [{
|
||||
# TODO(mgubin): add string tests when they are supported by Toco.
|
||||
# TODO(mgubin): add tests for Nd indices when they are supported by
|
||||
# TfLite.
|
||||
"params_dtype": [tf.float32, tf.int32, tf.int64],
|
||||
"params_shape": [[10], [1, 2, 20]],
|
||||
"indices_dtype": [tf.int32, tf.int64],
|
||||
"indices_shape": [[3], [5]],
|
||||
"axis": [-1, 0, 1],
|
||||
}]
|
||||
test_parameters = [
|
||||
{
|
||||
# TODO(b/110347007): add tests for Nd indices when they are supported
|
||||
# by TfLite.
|
||||
"params_dtype": [tf.float32, tf.int32, tf.int64],
|
||||
"params_shape": [[10], [1, 2, 20]],
|
||||
"indices_dtype": [tf.int32, tf.int64],
|
||||
"indices_shape": [[3], [5]],
|
||||
"axis": [-1, 0, 1],
|
||||
},
|
||||
{
|
||||
# TODO(b/123895910): add Nd support for strings.
|
||||
"params_dtype": [tf.string],
|
||||
"params_shape": [[8]],
|
||||
"indices_dtype": [tf.int32],
|
||||
"indices_shape": [[3]],
|
||||
"axis": [0],
|
||||
}
|
||||
]
|
||||
|
||||
def build_graph(parameters):
|
||||
"""Build the gather op testing graph."""
|
||||
|
45
tensorflow/lite/testing/string_util.cc
Normal file
45
tensorflow/lite/testing/string_util.cc
Normal file
@ -0,0 +1,45 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/lite/testing/string_util.h"
|
||||
|
||||
#include "absl/strings/escaping.h"
|
||||
#include "tensorflow/lite/python/interpreter_wrapper/numpy.h"
|
||||
#include "tensorflow/lite/python/interpreter_wrapper/python_utils.h"
|
||||
#include "tensorflow/lite/string_util.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace testing {
|
||||
namespace python {
|
||||
|
||||
PyObject* SerializeAsHexString(PyObject* value) {
|
||||
DynamicBuffer dynamic_buffer;
|
||||
if (!python_utils::FillStringBufferWithPyArray(value, &dynamic_buffer)) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
char* char_buffer = nullptr;
|
||||
size_t size = dynamic_buffer.WriteToBuffer(&char_buffer);
|
||||
string s = absl::BytesToHexString({char_buffer, size});
|
||||
free(char_buffer);
|
||||
|
||||
return python_utils::ConvertToPyString(s.data(), s.size());
|
||||
}
|
||||
|
||||
} // namespace python
|
||||
} // namespace testing
|
||||
} // namespace tflite
|
33
tensorflow/lite/testing/string_util.h
Normal file
33
tensorflow/lite/testing/string_util.h
Normal file
@ -0,0 +1,33 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_LITE_TESTING_STRING_UTIL_H_
|
||||
#define TENSORFLOW_LITE_TESTING_STRING_UTIL_H_
|
||||
|
||||
#include <Python.h>
|
||||
#include <string>
|
||||
|
||||
namespace tflite {
|
||||
namespace testing {
|
||||
namespace python {
|
||||
|
||||
// Take a python string array, convert it to TF Lite dynamic buffer format and
|
||||
// serialize it as a HexString.
|
||||
PyObject* SerializeAsHexString(PyObject* value);
|
||||
|
||||
} // namespace python
|
||||
} // namespace testing
|
||||
} // namespace tflite
|
||||
|
||||
#endif // TENSORFLOW_LITE_TESTING_STRING_UTIL_H_
|
31
tensorflow/lite/testing/string_util.i
Normal file
31
tensorflow/lite/testing/string_util.i
Normal file
@ -0,0 +1,31 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
%{
|
||||
|
||||
#define SWIG_FILE_WITH_INIT
|
||||
#include "tensorflow/lite/testing/string_util.h"
|
||||
|
||||
%}
|
||||
|
||||
namespace tflite {
|
||||
namespace testing {
|
||||
namespace python {
|
||||
|
||||
PyObject* SerializeAsHexString(PyObject* string_tensor);
|
||||
|
||||
} // namespace python
|
||||
} // namespace testing
|
||||
} // namespace tflite
|
Loading…
Reference in New Issue
Block a user