Add string support to generate_examples.py

PiperOrigin-RevId: 233127983
This commit is contained in:
A. Unique TensorFlower 2019-02-08 14:33:47 -08:00 committed by TensorFlower Gardener
parent 046f1b74a0
commit 4d6fd23bcb
7 changed files with 198 additions and 17 deletions

View File

@ -160,5 +160,21 @@ bool FillStringBufferWithPyArray(PyObject* value,
return false;
}
// Exposes the raw byte buffer and length of the Python byte string `obj`
// through `data` and `length`. Dispatches to the `str` API on Python 2 and to
// the `bytes` API on Python 3, forwarding that call's status code unchanged.
int ConvertFromPyString(PyObject* obj, char** data, Py_ssize_t* length) {
#if PY_MAJOR_VERSION < 3
  return PyString_AsStringAndSize(obj, data, length);
#else
  return PyBytes_AsStringAndSize(obj, data, length);
#endif
}
// Builds a Python byte string (`str` on Python 2, `bytes` on Python 3) from
// the first `length` bytes at `data` and returns whatever the underlying
// CPython constructor returns.
PyObject* ConvertToPyString(const char* data, size_t length) {
  // The CPython constructors take a signed size; spell the conversion out.
  const Py_ssize_t py_length = static_cast<Py_ssize_t>(length);
#if PY_MAJOR_VERSION < 3
  return PyString_FromStringAndSize(data, py_length);
#else
  return PyBytes_FromStringAndSize(data, py_length);
#endif
}
} // namespace python_utils
} // namespace tflite

View File

@ -23,6 +23,10 @@ limitations under the License.
namespace tflite {
namespace python_utils {
struct PyDecrefDeleter {
void operator()(PyObject* p) const { Py_DECREF(p); }
};
int TfLiteTypeToPyArrayType(TfLiteType tf_lite_type);
TfLiteType TfLiteTypeFromPyArray(PyArrayObject* array);
@ -30,6 +34,9 @@ TfLiteType TfLiteTypeFromPyArray(PyArrayObject* array);
bool FillStringBufferWithPyArray(PyObject* value,
DynamicBuffer* dynamic_buffer);
// Retrieves the byte buffer and length of the Python string `obj` into `data`
// and `length`. Wraps PyBytes_AsStringAndSize on Python 3 and
// PyString_AsStringAndSize on Python 2, and returns that call's result.
int ConvertFromPyString(PyObject* obj, char** data, Py_ssize_t* length);
// Creates a Python string object from `length` bytes starting at `data`.
// Wraps PyBytes_FromStringAndSize on Python 3 and PyString_FromStringAndSize
// on Python 2, and returns that call's result.
PyObject* ConvertToPyString(const char* data, size_t length);
} // namespace python_utils
} // namespace tflite
#endif // TENSORFLOW_LITE_PYTHON_INTERPRETER_WRAPPER_PYTHON_UTILS_H_

View File

@ -10,6 +10,7 @@ load(
"generated_test_models_all",
)
load("//tensorflow/lite:special_rules.bzl", "tflite_portable_test_suite")
load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc")
load(
"//tensorflow:tensorflow.bzl",
"tf_cc_binary",
@ -78,6 +79,7 @@ py_binary(
srcs_version = "PY2AND3",
deps = [
":generate_examples_report",
":string_util_wrapper",
"//tensorflow:tensorflow_py",
"//tensorflow/python:graph_util",
"//third_party/py/numpy",
@ -392,4 +394,29 @@ tf_cc_binary(
],
)
# C++ library built from string_util.cc/.h. It depends on the TF Lite string
# utilities, the interpreter wrapper's numpy/python helpers, and the
# Python/numpy headers.
cc_library(
name = "string_util_lib",
srcs = ["string_util.cc"],
hdrs = ["string_util.h"],
deps = [
"//tensorflow/lite:string_util",
"//tensorflow/lite/python/interpreter_wrapper:numpy",
"//tensorflow/lite/python/interpreter_wrapper:python_utils",
"//third_party/py/numpy:headers",
"//third_party/python_runtime:headers",
"@com_google_absl//absl/strings",
],
)
# Python wrapper around :string_util_lib, generated from the string_util.i
# interface file.
tf_py_wrap_cc(
name = "string_util_wrapper",
srcs = [
"string_util.i",
],
deps = [
":string_util_lib",
"//third_party/python_runtime:headers",
],
)
tflite_portable_test_suite()

View File

@ -36,6 +36,7 @@ import operator
import os
import random
import re
import string
import sys
import tempfile
import traceback
@ -52,6 +53,7 @@ import tensorflow as tf
from google.protobuf import text_format
# TODO(aselle): switch to TensorFlow's resource_loader
from tensorflow.lite.testing import generate_examples_report as report_lib
from tensorflow.lite.testing import string_util_wrapper
from tensorflow.python.framework import graph_util as tf_graph_util
from tensorflow.python.ops import rnn
@ -163,6 +165,16 @@ def toco_options(data_types,
return s
def format_result(t):
  """Convert a tensor to a format that can be used in test specs.

  Args:
    t: A numpy array holding the tensor values.

  Returns:
    For numeric and boolean arrays, the flattened values formatted with nine
    digits after the decimal point and joined with commas. For any other dtype
    (e.g. strings), the result of `string_util_wrapper.SerializeAsHexString`
    applied to the flattened array.
  """
  # np.bool_ is not a subtype of np.number, so it has to be checked explicitly;
  # otherwise boolean tensors would be routed to the string serializer instead
  # of being written as 1.000000000 / 0.000000000.
  if (np.issubdtype(t.dtype, np.number) or
      np.issubdtype(t.dtype, np.bool_)):
    # Output 9 digits after the point to ensure the precision is good enough.
    return ",".join("{:.9f}".format(value) for value in t.flatten())
  return string_util_wrapper.SerializeAsHexString(t.flatten())
def write_examples(fp, examples):
"""Given a list `examples`, write a text format representation.
@ -179,9 +191,7 @@ def write_examples(fp, examples):
"""Write tensor in file format supported by TFLITE example."""
fp.write("dtype,%s\n" % x.dtype)
fp.write("shape," + ",".join(map(str, x.shape)) + "\n")
# Output 9 digits after the point to ensure the precision is good enough.
values = ["{:.9f}".format(value) for value in list(x.flatten())]
fp.write("values," + ",".join(values) + "\n")
fp.write("values," + format_result(x) + "\n")
fp.write("test_cases,%d\n" % len(examples))
for example in examples:
@ -214,11 +224,9 @@ def write_test_cases(fp, model_name, examples):
fp.write("invoke {\n")
for t in example["inputs"]:
values = ["{:.9f}".format(value) for value in list(t.flatten())]
fp.write(" input: \"" + ",".join(values) + "\"\n")
fp.write(" input: \"" + format_result(t) + "\"\n")
for t in example["outputs"]:
values = ["{:.9f}".format(value) for value in list(t.flatten())]
fp.write(" output: \"" + ",".join(values) + "\"\n")
fp.write(" output: \"" + format_result(t) + "\"\n")
fp.write("}\n")
@ -230,6 +238,7 @@ _TF_TYPE_INFO = {
tf.int16: (np.int16, "QUANTIZED_INT16"),
tf.int64: (np.int64, "INT64"),
tf.bool: (np.bool, "BOOL"),
tf.string: (np.string_, "STRING"),
}
@ -245,6 +254,10 @@ def create_tensor_data(dtype, shape, min_value=-100, max_value=100):
value = np.random.randint(min_value, max_value+1, shape)
elif dtype == tf.bool:
value = np.random.choice([True, False], size=shape)
elif dtype == np.string_:
# Not the best strings, but they will do for some basic testing.
letters = list(string.ascii_uppercase)
return np.random.choice(letters, size=shape).astype(dtype)
return np.dtype(dtype).type(value) if np.isscalar(value) else value.astype(
dtype)
@ -1294,16 +1307,25 @@ def make_squared_difference_tests(zip_path):
def make_gather_tests(zip_path):
"""Make a set of tests to do gather."""
test_parameters = [{
# TODO(mgubin): add string tests when they are supported by Toco.
# TODO(mgubin): add tests for Nd indices when they are supported by
# TfLite.
"params_dtype": [tf.float32, tf.int32, tf.int64],
"params_shape": [[10], [1, 2, 20]],
"indices_dtype": [tf.int32, tf.int64],
"indices_shape": [[3], [5]],
"axis": [-1, 0, 1],
}]
test_parameters = [
{
# TODO(b/110347007): add tests for Nd indices when they are supported
# by TfLite.
"params_dtype": [tf.float32, tf.int32, tf.int64],
"params_shape": [[10], [1, 2, 20]],
"indices_dtype": [tf.int32, tf.int64],
"indices_shape": [[3], [5]],
"axis": [-1, 0, 1],
},
{
# TODO(b/123895910): add Nd support for strings.
"params_dtype": [tf.string],
"params_shape": [[8]],
"indices_dtype": [tf.int32],
"indices_shape": [[3]],
"axis": [0],
}
]
def build_graph(parameters):
"""Build the gather op testing graph."""

View File

@ -0,0 +1,45 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>
#include "tensorflow/lite/testing/string_util.h"
#include "absl/strings/escaping.h"
#include "tensorflow/lite/python/interpreter_wrapper/numpy.h"
#include "tensorflow/lite/python/interpreter_wrapper/python_utils.h"
#include "tensorflow/lite/string_util.h"
namespace tflite {
namespace testing {
namespace python {
// Packs the Python string array `value` into a TF Lite DynamicBuffer, dumps
// that buffer into a flat byte blob and returns the blob hex-encoded as a
// Python string object. Returns nullptr when `value` cannot be packed.
PyObject* SerializeAsHexString(PyObject* value) {
  DynamicBuffer buffer;
  const bool filled =
      python_utils::FillStringBufferWithPyArray(value, &buffer);
  if (!filled) {
    return nullptr;
  }

  // The block filled in by WriteToBuffer is released with free() as soon as
  // its contents have been hex-encoded.
  char* raw_bytes = nullptr;
  const size_t num_bytes = buffer.WriteToBuffer(&raw_bytes);
  const string hex = absl::BytesToHexString({raw_bytes, num_bytes});
  free(raw_bytes);

  return python_utils::ConvertToPyString(hex.data(), hex.size());
}
} // namespace python
} // namespace testing
} // namespace tflite

View File

@ -0,0 +1,33 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_TESTING_STRING_UTIL_H_
#define TENSORFLOW_LITE_TESTING_STRING_UTIL_H_
#include <Python.h>
#include <string>
namespace tflite {
namespace testing {
namespace python {
// Take a python string array, convert it to TF Lite dynamic buffer format and
// serialize it as a HexString.
// Returns a Python string object holding the hex encoding, or nullptr if the
// array could not be converted into a dynamic buffer.
PyObject* SerializeAsHexString(PyObject* value);
} // namespace python
} // namespace testing
} // namespace tflite
#endif // TENSORFLOW_LITE_TESTING_STRING_UTIL_H_

View File

@ -0,0 +1,31 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
%{
#define SWIG_FILE_WITH_INIT
#include "tensorflow/lite/testing/string_util.h"
%}
namespace tflite {
namespace testing {
namespace python {
// Exposes SerializeAsHexString (declared in string_util.h) to Python.
// NOTE(review): the C++ header names this parameter `value`.
PyObject* SerializeAsHexString(PyObject* string_tensor);
} // namespace python
} // namespace testing
} // namespace tflite