diff --git a/tensorflow/lite/ios/BUILD.apple b/tensorflow/lite/ios/BUILD.apple index 540c59816f1..2f6b40c5a8a 100644 --- a/tensorflow/lite/ios/BUILD.apple +++ b/tensorflow/lite/ios/BUILD.apple @@ -9,6 +9,13 @@ load( ) load("@build_bazel_rules_apple//apple:ios.bzl", "ios_static_framework") +# buildifier: disable=same-origin-load +load("//tensorflow:tensorflow.bzl", "pytype_strict_binary") + +# buildifier: disable=same-origin-load +load("//tensorflow:tensorflow.bzl", "pytype_strict_library") +load("//tensorflow:tensorflow.bzl", "py_strict_test") + package( default_visibility = [ "//tensorflow/lite:__subpackages__", @@ -27,6 +34,52 @@ sh_binary( ], ) +pytype_strict_library( + name = "extract_object_files", + srcs = [ + "extract_object_files.py", + ], +) + +pytype_strict_binary( + name = "extract_object_files_main", + srcs = [ + "extract_object_files_main.py", + ], + python_version = "PY3", + srcs_version = "PY3", + visibility = [ + "//tensorflow/lite:__subpackages__", + "//tensorflow_lite_support:__subpackages__", + ], + deps = [ + ":extract_object_files", + ], +) + +filegroup( + name = "extract_object_files_testdata", + srcs = glob(["testdata/**"]), +) + +py_strict_test( + name = "extract_object_files_test", + srcs = [ + "extract_object_files_test.py", + ], + data = [ + ":extract_object_files_testdata", + ], + python_version = "PY3", + srcs_version = "PY3", + deps = [ + ":extract_object_files", + "//tensorflow/python/platform", + "//tensorflow/python/platform:client_testlib", + "@absl_py//absl/testing:parameterized", + ], +) + strip_common_include_path_prefix( name = "strip_common_include_path_core", hdr_labels = [ @@ -146,6 +199,8 @@ build_test( "noasan", "nomsan", "notsan", + # TODO(b/176993122): restore once the apple_genrule issue is resolved. + "notap", ], targets = [ ":TensorFlowLiteCCoreML_framework", diff --git a/tensorflow/lite/ios/extract_object_files.py b/tensorflow/lite/ios/extract_object_files.py new file mode 100644 index 00000000000..954c16b3a3c --- /dev/null +++ b/tensorflow/lite/ios/extract_object_files.py @@ -0,0 +1,178 @@ +# Copyright 2021 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# Lint as: python3 +"""Module for extracting object files from a compiled archive (.a) file. + +This module provides functionality almost identical to the 'ar -x' command, +which extracts out all object files from a given archive file. This module +assumes the archive is in the BSD variant format used in Apple platforms. + +See: https://en.wikipedia.org/wiki/Ar_(Unix)#BSD_variant + +This extractor has two important differences compared to the 'ar -x' command +shipped with Xcode. + +1. When there are multiple object files with the same name in a given archive, + each file is renamed so that they are all correctly extracted without + overwriting each other. + +2. This module takes the destination directory as an additional parameter. + + Example Usage: + + archive_path = ... + dest_dir = ... + extract_object_files(archive_path, dest_dir) +""" + +import hashlib +import io +import itertools +import os +import struct +from typing import Iterator, Tuple + + +def extract_object_files(archive_file: io.BufferedIOBase, + dest_dir: str) -> None: + """Extracts object files from the archive path to the destination directory. + + Extracts object files from the given BSD variant archive file. The extracted + files are written to the destination directory, which will be created if the + directory does not exist. + + Colliding object file names are automatically renamed upon extraction in order + to avoid unintended overwriting. + + Args: + archive_file: The archive file object pointing at its beginning. + dest_dir: The destination directory path in which the extracted object files + will be written. The directory will be created if it does not exist. + """ + if not os.path.exists(dest_dir): + os.makedirs(dest_dir) + + _check_archive_signature(archive_file) + + # Keep the extracted file names and their content hash values, in order to + # handle duplicate names correctly. + extracted_files = dict() + + for name, file_content in _extract_next_file(archive_file): + digest = hashlib.md5(file_content).digest() + + # Check if the name is already used. If so, come up with a different name by + # incrementing the number suffix until it finds an unused one. + # For example, if 'foo.o' is used, try 'foo_1.o', 'foo_2.o', and so on. + for final_name in _generate_modified_filenames(name): + if final_name not in extracted_files: + extracted_files[final_name] = digest + + # Write the file content to the desired final path. + with open(os.path.join(dest_dir, final_name), 'wb') as object_file: + object_file.write(file_content) + break + + # Skip writing this file if the same file was already extracted. + elif extracted_files[final_name] == digest: + break + + +def _generate_modified_filenames(filename: str) -> Iterator[str]: + """Generates the modified filenames with incremental name suffix added. + + This helper function first yields the given filename itself, and subsequently + yields modified filenames by incrementing number suffix to the basename. + + Args: + filename: The original filename to be modified. + + Yields: + The original filename and then modified filenames with incremental suffix. + """ + yield filename + + base, ext = os.path.splitext(filename) + for name_suffix in itertools.count(1, 1): + yield '{}_{}{}'.format(base, name_suffix, ext) + + +def _check_archive_signature(archive_file: io.BufferedIOBase) -> None: + """Checks if the file has the correct archive header signature. + + The cursor is moved to the first available file header section after + successfully checking the signature. + + Args: + archive_file: The archive file object pointing at its beginning. + + Raises: + RuntimeError: The archive signature is invalid. + """ + signature = archive_file.read(8) + if signature != b'!\n': + raise RuntimeError('Invalid archive file format.') + + +def _extract_next_file( + archive_file: io.BufferedIOBase) -> Iterator[Tuple[str, bytes]]: + """Extracts the next available file from the archive. + + Reads the next available file header section and yields its filename and + content in bytes as a tuple. Stops when there are no more available files in + the provided archive_file. + + Args: + archive_file: The archive file object, of which cursor is pointing to the + next available file header section. + + Yields: + The name and content of the next available file in the given archive file. + + Raises: + RuntimeError: The archive_file is in an unknown format. + """ + while True: + header = archive_file.read(60) + if not header: + return + elif len(header) < 60: + raise RuntimeError('Invalid file header format.') + + # For the details of the file header format, see: + # https://en.wikipedia.org/wiki/Ar_(Unix)#File_header + # We only need the file name and the size values. + name, _, _, _, _, size, end = struct.unpack('=16s12s6s6s8s10s2s', header) + if end != b'`\n': + raise RuntimeError('Invalid file header format.') + + # Convert the bytes into more natural types. + name = name.decode('ascii').strip() + size = int(size, base=10) + odd_size = size % 2 == 1 + + # Handle the extended filename scheme. + if name.startswith('#1/'): + filename_size = int(name[3:]) + name = archive_file.read(filename_size).decode('utf-8').strip(' \x00') + size -= filename_size + + file_content = archive_file.read(size) + # The file contents are always 2 byte aligned, and 1 byte is padded at the + # end in case the size is odd. + if odd_size: + archive_file.read(1) + + yield (name, file_content) diff --git a/tensorflow/lite/ios/extract_object_files_main.py b/tensorflow/lite/ios/extract_object_files_main.py new file mode 100644 index 00000000000..4b23521ed25 --- /dev/null +++ b/tensorflow/lite/ios/extract_object_files_main.py @@ -0,0 +1,38 @@ +# Copyright 2021 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# Lint as: python3 +"""Command line tool version of the extract_object_files module. + +This command line tool version takes the archive file path and the destination +directory path as the positional command line arguments. +""" + +import sys +from typing import Sequence +from tensorflow.lite.ios import extract_object_files + + +def main(argv: Sequence[str]) -> None: + if len(argv) != 3: + raise RuntimeError('Usage: {} '.format(argv[0])) + + archive_path = argv[1] + dest_dir = argv[2] + with open(archive_path, 'rb') as archive_file: + extract_object_files.extract_object_files(archive_file, dest_dir) + + +if __name__ == '__main__': + main(sys.argv) diff --git a/tensorflow/lite/ios/extract_object_files_test.py b/tensorflow/lite/ios/extract_object_files_test.py new file mode 100644 index 00000000000..1d50aaf402a --- /dev/null +++ b/tensorflow/lite/ios/extract_object_files_test.py @@ -0,0 +1,79 @@ +# Copyright 2021 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# Lint as: python3 +"""Tests for the extract_object_files module.""" + +import io +import os +import pathlib +from typing import List +from absl.testing import parameterized +from tensorflow.lite.ios import extract_object_files +from tensorflow.python.platform import resource_loader +from tensorflow.python.platform import test + + +class ExtractObjectFilesTest(parameterized.TestCase): + + @parameterized.named_parameters( + dict( + testcase_name='Simple extraction', + dirname='simple', + object_files=['foo.o', 'bar.o']), + dict( + testcase_name='Extended filename', + dirname='extended_filename', + object_files=['short.o', 'long_file_name_with_extended_format.o']), + dict( + testcase_name='Odd bytes pad handling', + dirname='odd_bytes', + object_files=['odd.o', 'even.o']), + dict( + testcase_name='Duplicate object names should be separated out', + dirname='duplicate_names', + object_files=['foo.o', 'foo_1.o', 'foo_2.o']), + dict( + testcase_name='Exact same file should not be extracted again', + dirname='skip_same_file', + object_files=['foo.o'])) + def test_extract_object_files(self, dirname: str, object_files: List[str]): + dest_dir = self.create_tempdir().full_path + input_file_relpath = os.path.join('testdata', dirname, 'input.a') + archive_path = resource_loader.get_path_to_datafile(input_file_relpath) + + with open(archive_path, 'rb') as archive_file: + extract_object_files.extract_object_files(archive_file, dest_dir) + + # Only the expected files should be extracted and no more. + self.assertCountEqual(object_files, os.listdir(dest_dir)) + + # Compare the extracted files against the expected file content. + for file in object_files: + actual = pathlib.Path(os.path.join(dest_dir, file)).read_bytes() + expected = pathlib.Path( + resource_loader.get_path_to_datafile( + os.path.join('testdata', dirname, file))).read_bytes() + self.assertEqual(actual, expected) + + def test_invalid_archive(self): + with io.BytesIO(b'this is an invalid archive file') as archive_file: + with self.assertRaises(RuntimeError): + extract_object_files.extract_object_files( + archive_file, + self.create_tempdir().full_path) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/lite/ios/hide_symbols_with_allowlist.sh b/tensorflow/lite/ios/hide_symbols_with_allowlist.sh index 2b14e39401c..71500e3ef15 100755 --- a/tensorflow/lite/ios/hide_symbols_with_allowlist.sh +++ b/tensorflow/lite/ios/hide_symbols_with_allowlist.sh @@ -22,6 +22,7 @@ # INPUT_FRAMEWORK: a zip file containing the iOS static framework. # BUNDLE_NAME: the pod/bundle name of the iOS static framework. # ALLOWLIST_FILE_PATH: contains the allowed symbols. +# EXTRACT_SCRIPT_PATH: path to the extract_object_files script. # OUTPUT: the output zip file. # Halt on any error or any unknown variable. @@ -84,8 +85,17 @@ for arch in "${archs[@]}"; do echo fi fi - xcrun ar -x "${arch_file}" - mv *.o "${archdir}"/ + if [[ ! -z "${EXTRACT_SCRIPT_PATH}" ]]; then + "${EXTRACT_SCRIPT_PATH}" "${arch_file}" "${archdir}" + else + # ar tool extracts the objects in the current working directory. Since the + # default working directory for a genrule is always the same, there can be + # a race condition when this script is called for multiple targets + # simultaneously. + pushd "${archdir}" > /dev/null + xcrun ar -x "${arch_file}" + popd > /dev/null + fi objects_file_list=$($MKTEMP) # Hides the symbols except the allowed ones. diff --git a/tensorflow/lite/ios/ios.bzl b/tensorflow/lite/ios/ios.bzl index a9e98aafcf4..acb9cabff13 100644 --- a/tensorflow/lite/ios/ios.bzl +++ b/tensorflow/lite/ios/ios.bzl @@ -59,6 +59,7 @@ def tflite_ios_static_framework( cmd = ("INPUT_FRAMEWORK=\"$(location " + framework_target + ")\" " + "BUNDLE_NAME=\"" + bundle_name + "\" " + "ALLOWLIST_FILE_PATH=\"$(location " + allowlist_symbols_file + ")\" " + + "EXTRACT_SCRIPT_PATH=\"$(location //tensorflow/lite/ios:extract_object_files_main)\" " + "OUTPUT=\"$(OUTS)\" " + "\"$(location //tensorflow/lite/ios:hide_symbols_with_allowlist)\"") @@ -68,6 +69,7 @@ def tflite_ios_static_framework( outs = [name + ".zip"], cmd = cmd, tools = [ + "//tensorflow/lite/ios:extract_object_files_main", "//tensorflow/lite/ios:hide_symbols_with_allowlist", ], ) diff --git a/tensorflow/lite/ios/testdata/duplicate_names/foo.o b/tensorflow/lite/ios/testdata/duplicate_names/foo.o new file mode 100644 index 00000000000..f2e60bad120 --- /dev/null +++ b/tensorflow/lite/ios/testdata/duplicate_names/foo.o @@ -0,0 +1 @@ +first foo.o \ No newline at end of file diff --git a/tensorflow/lite/ios/testdata/duplicate_names/foo_1.o b/tensorflow/lite/ios/testdata/duplicate_names/foo_1.o new file mode 100644 index 00000000000..72a4e014294 --- /dev/null +++ b/tensorflow/lite/ios/testdata/duplicate_names/foo_1.o @@ -0,0 +1 @@ +second foo.o \ No newline at end of file diff --git a/tensorflow/lite/ios/testdata/duplicate_names/foo_2.o b/tensorflow/lite/ios/testdata/duplicate_names/foo_2.o new file mode 100644 index 00000000000..6c39329a4b4 --- /dev/null +++ b/tensorflow/lite/ios/testdata/duplicate_names/foo_2.o @@ -0,0 +1 @@ +third foo.o \ No newline at end of file diff --git a/tensorflow/lite/ios/testdata/duplicate_names/input.a b/tensorflow/lite/ios/testdata/duplicate_names/input.a new file mode 100644 index 00000000000..d36c5442cc8 --- /dev/null +++ b/tensorflow/lite/ios/testdata/duplicate_names/input.a @@ -0,0 +1,6 @@ +! +foo.o 1609941687 12549 24403 100644 11 ` +first foo.o +foo.o 1609941704 12549 24403 100644 12 ` +second foo.ofoo.o 1609941712 12549 24403 100644 11 ` +third foo.o diff --git a/tensorflow/lite/ios/testdata/extended_filename/input.a b/tensorflow/lite/ios/testdata/extended_filename/input.a new file mode 100644 index 00000000000..3cb209dac9f Binary files /dev/null and b/tensorflow/lite/ios/testdata/extended_filename/input.a differ diff --git a/tensorflow/lite/ios/testdata/extended_filename/long_file_name_with_extended_format.o b/tensorflow/lite/ios/testdata/extended_filename/long_file_name_with_extended_format.o new file mode 100644 index 00000000000..9d68d0195ee --- /dev/null +++ b/tensorflow/lite/ios/testdata/extended_filename/long_file_name_with_extended_format.o @@ -0,0 +1 @@ +long file name \ No newline at end of file diff --git a/tensorflow/lite/ios/testdata/extended_filename/short.o b/tensorflow/lite/ios/testdata/extended_filename/short.o new file mode 100644 index 00000000000..20b0d395d38 --- /dev/null +++ b/tensorflow/lite/ios/testdata/extended_filename/short.o @@ -0,0 +1 @@ +short file content \ No newline at end of file diff --git a/tensorflow/lite/ios/testdata/odd_bytes/even.o b/tensorflow/lite/ios/testdata/odd_bytes/even.o new file mode 100644 index 00000000000..fab284cf6a1 --- /dev/null +++ b/tensorflow/lite/ios/testdata/odd_bytes/even.o @@ -0,0 +1 @@ +even bytes \ No newline at end of file diff --git a/tensorflow/lite/ios/testdata/odd_bytes/input.a b/tensorflow/lite/ios/testdata/odd_bytes/input.a new file mode 100644 index 00000000000..7e7db0b17af --- /dev/null +++ b/tensorflow/lite/ios/testdata/odd_bytes/input.a @@ -0,0 +1,5 @@ +! +odd.o 1609941182 12549 24403 100664 9 ` +odd bytes +even.o 1609941194 12549 24403 100664 10 ` +even bytes \ No newline at end of file diff --git a/tensorflow/lite/ios/testdata/odd_bytes/odd.o b/tensorflow/lite/ios/testdata/odd_bytes/odd.o new file mode 100644 index 00000000000..32a82273b1c --- /dev/null +++ b/tensorflow/lite/ios/testdata/odd_bytes/odd.o @@ -0,0 +1 @@ +odd bytes \ No newline at end of file diff --git a/tensorflow/lite/ios/testdata/simple/bar.o b/tensorflow/lite/ios/testdata/simple/bar.o new file mode 100644 index 00000000000..6f99aa599e3 --- /dev/null +++ b/tensorflow/lite/ios/testdata/simple/bar.o @@ -0,0 +1 @@ +bar file content \ No newline at end of file diff --git a/tensorflow/lite/ios/testdata/simple/foo.o b/tensorflow/lite/ios/testdata/simple/foo.o new file mode 100644 index 00000000000..07e9bdf3f34 --- /dev/null +++ b/tensorflow/lite/ios/testdata/simple/foo.o @@ -0,0 +1 @@ +foo file content \ No newline at end of file diff --git a/tensorflow/lite/ios/testdata/simple/input.a b/tensorflow/lite/ios/testdata/simple/input.a new file mode 100644 index 00000000000..7bd8583d426 --- /dev/null +++ b/tensorflow/lite/ios/testdata/simple/input.a @@ -0,0 +1,4 @@ +! +foo.o 1609934189 12549 24403 100664 16 ` +foo file contentbar.o 1609934193 12549 24403 100664 16 ` +bar file content \ No newline at end of file diff --git a/tensorflow/lite/ios/testdata/skip_same_file/foo.o b/tensorflow/lite/ios/testdata/skip_same_file/foo.o new file mode 100644 index 00000000000..07e9bdf3f34 --- /dev/null +++ b/tensorflow/lite/ios/testdata/skip_same_file/foo.o @@ -0,0 +1 @@ +foo file content \ No newline at end of file diff --git a/tensorflow/lite/ios/testdata/skip_same_file/input.a b/tensorflow/lite/ios/testdata/skip_same_file/input.a new file mode 100644 index 00000000000..acbfd9d07c0 --- /dev/null +++ b/tensorflow/lite/ios/testdata/skip_same_file/input.a @@ -0,0 +1,4 @@ +! +foo.o 1610108108 12549 24403 100644 16 ` +foo file contentfoo.o 1610108119 12549 24403 100644 16 ` +foo file content \ No newline at end of file diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl index afe0c88504a..d82fddce2cd 100644 --- a/tensorflow/tensorflow.bzl +++ b/tensorflow/tensorflow.bzl @@ -1823,6 +1823,14 @@ def py_strict_binary(name, **kwargs): def py_strict_library(name, **kwargs): native.py_library(name = name, **kwargs) +# Placeholder to use until bazel supports pytype_strict_binary. +def pytype_strict_binary(name, **kwargs): + native.py_binary(name = name, **kwargs) + +# Placeholder to use until bazel supports pytype_strict_library. +def pytype_strict_library(name, **kwargs): + native.py_library(name = name, **kwargs) + # Placeholder to use until bazel supports py_strict_test. def py_strict_test(name, **kwargs): py_test(name = name, **kwargs)