Tool for generating on-device image test inputs

PiperOrigin-RevId: 292015086
Change-Id: I6dc3b4dddcaadf87027fb51ed31b691ddfd40e7e
This commit is contained in:
Pete Warden 2020-01-28 14:33:36 -08:00 committed by TensorFlower Gardener
parent c8357711e3
commit 9c6e8d5d19
3 changed files with 228 additions and 0 deletions

View File

@ -38,6 +38,34 @@ py_test(
],
)
py_binary(
name = "convert_image_to_csv",
srcs = ["convert_image_to_csv.py"],
python_version = "PY3",
srcs_version = "PY2AND3",
deps = [
"//tensorflow:tensorflow_py_no_contrib",
"//tensorflow/python:platform",
"//tensorflow/python/keras",
"//third_party/py/numpy",
"@six_archive//:six",
],
)
py_test(
name = "convert_image_to_csv_test",
srcs = ["convert_image_to_csv_test.py"],
data = ["//tensorflow/core:image_testdata"],
python_version = "PY3",
srcs_version = "PY2AND3",
deps = [
":convert_image_to_csv",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform",
],
)
tf_cc_binary(
name = "generate_op_registrations",
srcs = ["gen_op_registration_main.cc"],

View File

@ -0,0 +1,115 @@
# Lint as: python2, python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""This tool converts an image file into a CSV data array.
Designed to help create test inputs that can be shared between Python and
on-device test cases to investigate accuracy issues.
Example usage:
python convert_image_to_csv.py some_image.jpg --width=16 --height=20 \
--want_grayscale
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import sys
from tensorflow.python.client import session
from tensorflow.python.framework import ops
from tensorflow.python.framework.errors_impl import NotFoundError
from tensorflow.python.ops import image_ops
from tensorflow.python.ops import io_ops
from tensorflow.python.platform import app
def get_image(width, height, want_grayscale, filepath):
"""Returns an image loaded into an np.ndarray with dims [height, width, (3 or 1)].
Args:
width: Width to rescale the image to.
height: Height to rescale the image to.
want_grayscale: Whether the result should be converted to grayscale.
filepath: Path of the image file..
Returns:
np.ndarray of shape (height, width, channels) where channels is 1 if
want_grayscale is true, otherwise 3.
"""
with ops.Graph().as_default():
with session.Session():
file_data = io_ops.read_file(filepath)
channels = 1 if want_grayscale else 3
image_tensor = image_ops.decode_image(file_data,
channels=channels).eval()
resized_tensor = image_ops.resize_images_v2(
image_tensor, (height, width)).eval()
return resized_tensor
def array_to_int_csv(array_data):
"""Converts all elements in a numerical array to a comma-separated string.
Args:
array_data: Numerical array to convert.
Returns:
String containing array values as integers, separated by commas.
"""
flattened_array = array_data.flatten()
array_as_strings = [item.astype(int).astype(str) for item in flattened_array]
return ','.join(array_as_strings)
def run_main(_):
"""Application run loop."""
parser = argparse.ArgumentParser(
description='Loads JPEG or PNG input files, resizes them, optionally'
' converts to grayscale, and writes out as comma-separated variables,'
' one image per row.')
parser.add_argument(
'image_file_names',
type=str,
nargs='+',
help='List of paths to the input images.')
parser.add_argument(
'--width', type=int, default=96, help='Width to scale images to.')
parser.add_argument(
'--height', type=int, default=96, help='Height to scale images to.')
parser.add_argument(
'--want_grayscale',
action='store_true',
help='Whether to convert the image to monochrome.')
args = parser.parse_args()
for image_file_name in args.image_file_names:
try:
image_data = get_image(args.width, args.height, args.want_grayscale,
image_file_name)
print(array_to_int_csv(image_data))
except NotFoundError:
sys.stderr.write('Image file not found at {0}\n'.format(image_file_name))
sys.exit(1)
def main():
app.run(main=run_main, argv=sys.argv[:1])
if __name__ == '__main__':
main()

View File

@ -0,0 +1,85 @@
# Lint as: python2, python3
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests image file conversion utilities."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import numpy as np
from tensorflow.lite.tools import convert_image_to_csv
from tensorflow.python.framework import test_util
from tensorflow.python.framework.errors_impl import NotFoundError
from tensorflow.python.platform import resource_loader
from tensorflow.python.platform import test
PREFIX_PATH = resource_loader.get_path_to_datafile("../../core/lib/")
class ConvertImageToCsvTest(test_util.TensorFlowTestCase):
def testGetImageRaisesMissingFile(self):
image_path = os.path.join(PREFIX_PATH, "jpeg", "testdata", "no_such.jpg")
with self.assertRaises(NotFoundError):
_ = convert_image_to_csv.get_image(64, 96, False, image_path)
def testGetImageSizeIsCorrect(self):
image_path = os.path.join(PREFIX_PATH, "jpeg", "testdata", "small.jpg")
image_data = convert_image_to_csv.get_image(64, 96, False, image_path)
self.assertEqual((96, 64, 3), image_data.shape)
def testGetImageConvertsToGrayscale(self):
image_path = os.path.join(PREFIX_PATH, "jpeg", "testdata", "medium.jpg")
image_data = convert_image_to_csv.get_image(40, 20, True, image_path)
self.assertEqual((20, 40, 1), image_data.shape)
def testGetImageCanLoadPng(self):
image_path = os.path.join(PREFIX_PATH, "png", "testdata", "lena_rgba.png")
image_data = convert_image_to_csv.get_image(10, 10, False, image_path)
self.assertEqual((10, 10, 3), image_data.shape)
def testGetImageConvertsGrayscaleToColor(self):
image_path = os.path.join(PREFIX_PATH, "png", "testdata", "lena_gray.png")
image_data = convert_image_to_csv.get_image(23, 19, False, image_path)
self.assertEqual((19, 23, 3), image_data.shape)
def testGetImageColorValuesInRange(self):
image_path = os.path.join(PREFIX_PATH, "jpeg", "testdata", "small.jpg")
image_data = convert_image_to_csv.get_image(47, 31, False, image_path)
self.assertLessEqual(0, np.min(image_data))
self.assertGreaterEqual(255, np.max(image_data))
def testGetImageGrayscaleValuesInRange(self):
image_path = os.path.join(PREFIX_PATH, "jpeg", "testdata", "small.jpg")
image_data = convert_image_to_csv.get_image(27, 33, True, image_path)
self.assertLessEqual(0, np.min(image_data))
self.assertGreaterEqual(255, np.max(image_data))
def testArrayToIntCsv(self):
csv_string = convert_image_to_csv.array_to_int_csv(
np.array([[1, 2], [3, 4]]))
self.assertEqual("1,2,3,4", csv_string)
def testArrayToIntCsvRounding(self):
csv_string = convert_image_to_csv.array_to_int_csv(
np.array([[1.0, 2.0], [3.0, 4.0]]))
self.assertEqual("1,2,3,4", csv_string)
if __name__ == "__main__":
test.main()