From 3376402afdb00ac37e024de70b59fae3ca31c6ce Mon Sep 17 00:00:00 2001
From: Benoit Jacob
Date: Fri, 3 Apr 2020 11:09:17 -0700
Subject: [PATCH] Reference ruy from its new location as a separate GitHub project.

PiperOrigin-RevId: 304653289
Change-Id: I2a3ce84177665a4d7f455d455c0e7d71d48a68e9
---
 .../lite/experimental/ruy/CONTRIBUTING.md     |   28 -
 tensorflow/lite/experimental/ruy/README.md    |   24 -
 tensorflow/lite/experimental/ruy/WORKSPACE    |   17 -
 tensorflow/lite/experimental/ruy/ruy/BUILD    |  954 --
 .../lite/experimental/ruy/ruy/allocator.cc    |   51 -
 .../lite/experimental/ruy/ruy/allocator.h     |  185 -
 .../experimental/ruy/ruy/allocator_test.cc    |  103 -
 .../lite/experimental/ruy/ruy/benchmark.cc    |  196 -
 .../lite/experimental/ruy/ruy/block_map.cc    |  486 -
 .../lite/experimental/ruy/ruy/block_map.h     |  161 -
 .../experimental/ruy/ruy/block_map_test.cc    |  263 -
 .../experimental/ruy/ruy/blocking_counter.cc  |   49 -
 .../experimental/ruy/ruy/blocking_counter.h   |   62 -
 .../lite/experimental/ruy/ruy/build_defs.bzl  |   40 -
 .../lite/experimental/ruy/ruy/check_macros.h  |  138 -
 .../experimental/ruy/ruy/check_macros_test.cc |  153 -
 tensorflow/lite/experimental/ruy/ruy/common.h |   73 -
 .../lite/experimental/ruy/ruy/context.cc      |  109 -
 .../lite/experimental/ruy/ruy/context.h       |  109 -
 .../lite/experimental/ruy/ruy/context_test.cc |   63 -
 .../experimental/ruy/ruy/cpu_cache_size.h     |   81 -
 .../lite/experimental/ruy/ruy/detect_arm.cc   |   73 -
 .../lite/experimental/ruy/ruy/detect_arm.h    |   29 -
 .../lite/experimental/ruy/ruy/detect_x86.cc   |  101 -
 .../lite/experimental/ruy/ruy/detect_x86.h    |   49 -
 .../lite/experimental/ruy/ruy/dispatch.h      |  482 -
 .../lite/experimental/ruy/ruy/example.cc      |  136 -
 .../experimental/ruy/ruy/example_advanced.cc  |   83 -
 .../ruy/ruy/have_built_path_for.h             |   32 -
 .../ruy/ruy/have_built_path_for_avx2.cc       |   35 -
 .../ruy/ruy/have_built_path_for_avx512.cc     |   35 -
 .../ruy/ruy/have_built_path_for_avxvnni.cc    |   39 -
 .../ruy/ruy/have_built_path_for_sse42.cc      |   39 -
 .../experimental/ruy/ruy/internal_matrix.h    |  388 -
 tensorflow/lite/experimental/ruy/ruy/kernel.h |   31 -
 .../lite/experimental/ruy/ruy/kernel_arm.h    |  211 -
 .../lite/experimental/ruy/ruy/kernel_arm32.cc | 2499 ------
 .../lite/experimental/ruy/ruy/kernel_arm64.cc | 7835 -----------------
 .../lite/experimental/ruy/ruy/kernel_avx2.cc  | 1664 ----
 .../experimental/ruy/ruy/kernel_avx512.cc     | 1820 ----
 .../experimental/ruy/ruy/kernel_avxvnni.cc    |  435 -
 .../lite/experimental/ruy/ruy/kernel_common.h |  481 -
 .../lite/experimental/ruy/ruy/kernel_sse42.cc |  428 -
 .../lite/experimental/ruy/ruy/kernel_x86.h    |  222 -
 tensorflow/lite/experimental/ruy/ruy/matrix.h |  182 -
 .../lite/experimental/ruy/ruy/opt_set.h       |   51 -
 tensorflow/lite/experimental/ruy/ruy/pack.h   |   98 -
 .../lite/experimental/ruy/ruy/pack_arm.cc     | 1936 ----
 .../lite/experimental/ruy/ruy/pack_arm.h      |  497 --
 .../lite/experimental/ruy/ruy/pack_avx2.cc    |  816 --
 .../lite/experimental/ruy/ruy/pack_avx512.cc  |  693 --
 .../lite/experimental/ruy/ruy/pack_avxvnni.cc |  478 -
 .../lite/experimental/ruy/ruy/pack_common.h   |  246 -
 .../lite/experimental/ruy/ruy/pack_sse42.cc   |  471 -
 .../lite/experimental/ruy/ruy/pack_x86.h      |  461 -
 tensorflow/lite/experimental/ruy/ruy/path.h   |  162 -
 .../lite/experimental/ruy/ruy/platform.h      |  156 -
 tensorflow/lite/experimental/ruy/ruy/pmu.cc   |  281 -
 tensorflow/lite/experimental/ruy/ruy/pmu.h    |   44 -
 .../lite/experimental/ruy/ruy/prepack.h       |  108 -
 .../experimental/ruy/ruy/prepacked_cache.cc   |   86 -
 .../experimental/ruy/ruy/prepacked_cache.h    |  130 -
 .../ruy/ruy/prepacked_cache_test.cc           |  210 -
 .../lite/experimental/ruy/ruy/profiler/BUILD  |   60 -
 .../experimental/ruy/ruy/profiler/README.md   |  149 -
 .../ruy/ruy/profiler/instrumentation.cc       |  130 -
 .../ruy/ruy/profiler/instrumentation.h        |  203 -
 .../experimental/ruy/ruy/profiler/profiler.cc |  109 -
 .../experimental/ruy/ruy/profiler/profiler.h  |  106 -
 .../experimental/ruy/ruy/profiler/test.cc     |  167 -
 .../ruy/profiler/test_instrumented_library.cc |   59 -
 .../ruy/profiler/test_instrumented_library.h  |   23 -
 .../experimental/ruy/ruy/profiler/treeview.cc |  248 -
 .../experimental/ruy/ruy/profiler/treeview.h  |  130 -
 tensorflow/lite/experimental/ruy/ruy/ruy.h    |   42 -
 .../lite/experimental/ruy/ruy/ruy_advanced.h  |   69 -
 .../lite/experimental/ruy/ruy/ruy_test.bzl    |   34 -
 .../experimental/ruy/ruy/ruy_test_ext.bzl     |    7 -
 .../lite/experimental/ruy/ruy/side_pair.h     |   64 -
 .../lite/experimental/ruy/ruy/size_util.h     |   93 -
 .../experimental/ruy/ruy/size_util_test.cc    |  101 -
 tensorflow/lite/experimental/ruy/ruy/spec.h   |  118 -
 tensorflow/lite/experimental/ruy/ruy/test.h   | 2125 -----
 .../lite/experimental/ruy/ruy/test_fast.cc    |  110 -
 .../lite/experimental/ruy/ruy/test_slow.cc    |   71 -
 .../ruy/ruy/test_special_specs.cc             |  163 -
 .../lite/experimental/ruy/ruy/thread_pool.cc  |  200 -
 .../lite/experimental/ruy/ruy/thread_pool.h   |  102 -
 tensorflow/lite/experimental/ruy/ruy/time.h   |   81 -
 tensorflow/lite/experimental/ruy/ruy/trace.cc |  325 -
 tensorflow/lite/experimental/ruy/ruy/trace.h  |   73 -
 tensorflow/lite/experimental/ruy/ruy/trmul.cc |  401 -
 tensorflow/lite/experimental/ruy/ruy/trmul.h  |   38 -
 .../lite/experimental/ruy/ruy/trmul_params.h  |   67 -
 tensorflow/lite/experimental/ruy/ruy/tune.cc  |  161 -
 tensorflow/lite/experimental/ruy/ruy/tune.h   |  163 -
 .../lite/experimental/ruy/ruy/tune_test.cc    |   53 -
 .../lite/experimental/ruy/ruy/tune_tool.cc    |   56 -
 tensorflow/lite/experimental/ruy/ruy/wait.cc  |   69 -
 tensorflow/lite/experimental/ruy/ruy/wait.h   |   73 -
 .../lite/experimental/ruy/ruy/wait_test.cc    |  117 -
 tensorflow/lite/kernels/BUILD                 |   18 +-
 .../lite/kernels/cpu_backend_context.cc       |    2 +-
 tensorflow/lite/kernels/cpu_backend_context.h |    2 +-
 .../kernels/cpu_backend_gemm_custom_gemv.h    |    2 +-
 .../lite/kernels/cpu_backend_gemm_gemmlowp.h  |    2 +-
 .../lite/kernels/cpu_backend_gemm_ruy.h       |    4 +-
 .../lite/kernels/cpu_backend_gemm_test.cc     |    2 +-
 .../lite/kernels/cpu_backend_threadpool.h     |    4 +-
 tensorflow/lite/kernels/internal/BUILD        |   14 +-
 .../internal/depthwiseconv_quantized_test.cc  |    2 +-
 .../depthwiseconv_3x3_filter_common.h         |    2 +-
 .../internal/optimized/depthwiseconv_float.h  |    2 +-
 .../internal/optimized/depthwiseconv_uint8.h  |    2 +-
 .../depthwiseconv_uint8_3x3_filter.h          |    2 +-
 .../kernels/internal/optimized/im2col_utils.h |    2 +-
 .../internal/optimized/integer_ops/add.h      |    2 +-
 .../internal/optimized/integer_ops/conv.h     |    2 +-
 .../optimized/integer_ops/depthwise_conv.h    |    2 +-
 .../integer_ops/depthwise_conv_3x3_filter.h   |    2 +-
 .../integer_ops/depthwise_conv_hybrid.h       |    2 +-
 .../depthwise_conv_hybrid_3x3_filter.h        |    2 +-
 .../optimized/integer_ops/fully_connected.h   |    2 +-
 .../internal/optimized/integer_ops/mul.h      |    2 +-
 .../internal/optimized/integer_ops/pooling.h  |    2 +-
 .../internal/optimized/neon_tensor_utils.cc   |    4 +-
 .../internal/optimized/optimized_ops.h        |    2 +-
 .../internal/reference/integer_ops/mul.h      |    2 +-
 .../lite/kernels/internal/reference/reduce.h  |    2 +-
 .../internal/reference/reference_ops.h        |    2 +-
 .../kernels/internal/reference/requantize.h   |    2 +-
 .../lite/kernels/internal/reference/sub.h     |    2 +-
 tensorflow/lite/kernels/lstm_eval.cc          |    2 +-
 tensorflow/lite/kernels/rfft2d.cc             |    2 +-
 .../examples/person_detection/Makefile.inc    |    1 +
 tensorflow/lite/micro/tools/make/Makefile     |   14 +-
 .../tools/make/third_party_downloads.inc      |    3 +
 tensorflow/lite/tools/benchmark/BUILD         |    2 +-
 .../tools/benchmark/benchmark_tflite_model.cc |    2 +-
 tensorflow/lite/tools/make/Makefile           |    4 +-
 .../lite/tools/make/download_dependencies.sh  |    2 +
 tensorflow/tools/ci_build/ci_sanity.sh        |    2 +
 tensorflow/tools/pip_package/BUILD            |    1 +
 tensorflow/workspace.bzl                      |    2 +
 third_party/ruy/BUILD                         |    8 +
 third_party/ruy/workspace.bzl                 |   15 +
 146 files changed, 97 insertions(+), 34017 deletions(-)
 delete mode 100644 tensorflow/lite/experimental/ruy/CONTRIBUTING.md
 delete mode 100644 tensorflow/lite/experimental/ruy/README.md
 delete mode 100644 tensorflow/lite/experimental/ruy/WORKSPACE
 delete mode 100644 tensorflow/lite/experimental/ruy/ruy/BUILD
 delete mode 100644 tensorflow/lite/experimental/ruy/ruy/allocator.cc
 delete mode 100644 tensorflow/lite/experimental/ruy/ruy/allocator.h
 delete mode 100644 tensorflow/lite/experimental/ruy/ruy/allocator_test.cc
 delete mode 100644 tensorflow/lite/experimental/ruy/ruy/benchmark.cc
 delete mode 100644 tensorflow/lite/experimental/ruy/ruy/block_map.cc
 delete mode 100644 tensorflow/lite/experimental/ruy/ruy/block_map.h
 delete mode 100644 tensorflow/lite/experimental/ruy/ruy/block_map_test.cc
 delete mode 100644 tensorflow/lite/experimental/ruy/ruy/blocking_counter.cc
 delete mode 100644 tensorflow/lite/experimental/ruy/ruy/blocking_counter.h
 delete mode 100644 tensorflow/lite/experimental/ruy/ruy/build_defs.bzl
 delete mode 100644 tensorflow/lite/experimental/ruy/ruy/check_macros.h
 delete mode 100644 tensorflow/lite/experimental/ruy/ruy/check_macros_test.cc
 delete mode 100644 tensorflow/lite/experimental/ruy/ruy/common.h
 delete mode 100644 tensorflow/lite/experimental/ruy/ruy/context.cc
 delete mode 100644 tensorflow/lite/experimental/ruy/ruy/context.h
 delete mode 100644 tensorflow/lite/experimental/ruy/ruy/context_test.cc
 delete mode 100644 tensorflow/lite/experimental/ruy/ruy/cpu_cache_size.h
 delete mode 100644 tensorflow/lite/experimental/ruy/ruy/detect_arm.cc
 delete mode 100644 tensorflow/lite/experimental/ruy/ruy/detect_arm.h
 delete mode 100644 tensorflow/lite/experimental/ruy/ruy/detect_x86.cc
 delete mode 100644 tensorflow/lite/experimental/ruy/ruy/detect_x86.h
 delete mode 100644 tensorflow/lite/experimental/ruy/ruy/dispatch.h
 delete mode 100644 tensorflow/lite/experimental/ruy/ruy/example.cc
 delete mode 100644 tensorflow/lite/experimental/ruy/ruy/example_advanced.cc
 delete mode 100644 tensorflow/lite/experimental/ruy/ruy/have_built_path_for.h
 delete mode 100644 tensorflow/lite/experimental/ruy/ruy/have_built_path_for_avx2.cc
 delete mode 100644 tensorflow/lite/experimental/ruy/ruy/have_built_path_for_avx512.cc
 delete mode 100644 tensorflow/lite/experimental/ruy/ruy/have_built_path_for_avxvnni.cc
 delete mode 100644 tensorflow/lite/experimental/ruy/ruy/have_built_path_for_sse42.cc
 delete mode 100644 tensorflow/lite/experimental/ruy/ruy/internal_matrix.h
 delete mode 100644 tensorflow/lite/experimental/ruy/ruy/kernel.h
 delete mode 100644 tensorflow/lite/experimental/ruy/ruy/kernel_arm.h
 delete mode 100644 tensorflow/lite/experimental/ruy/ruy/kernel_arm32.cc
 delete mode 100644 tensorflow/lite/experimental/ruy/ruy/kernel_arm64.cc
 delete mode 100644 tensorflow/lite/experimental/ruy/ruy/kernel_avx2.cc
 delete mode 100644 tensorflow/lite/experimental/ruy/ruy/kernel_avx512.cc
 delete mode 100644 tensorflow/lite/experimental/ruy/ruy/kernel_avxvnni.cc
 delete mode 100644 tensorflow/lite/experimental/ruy/ruy/kernel_common.h
 delete mode 100644 tensorflow/lite/experimental/ruy/ruy/kernel_sse42.cc
 delete mode 100644 tensorflow/lite/experimental/ruy/ruy/kernel_x86.h
 delete mode 100644 tensorflow/lite/experimental/ruy/ruy/matrix.h
 delete mode 100644 tensorflow/lite/experimental/ruy/ruy/opt_set.h
 delete mode 100644 tensorflow/lite/experimental/ruy/ruy/pack.h
 delete mode 100644 tensorflow/lite/experimental/ruy/ruy/pack_arm.cc
 delete mode 100644 tensorflow/lite/experimental/ruy/ruy/pack_arm.h
 delete mode 100644 tensorflow/lite/experimental/ruy/ruy/pack_avx2.cc
 delete mode 100644 tensorflow/lite/experimental/ruy/ruy/pack_avx512.cc
 delete mode 100644 tensorflow/lite/experimental/ruy/ruy/pack_avxvnni.cc
 delete mode 100644 tensorflow/lite/experimental/ruy/ruy/pack_common.h
 delete mode 100644 tensorflow/lite/experimental/ruy/ruy/pack_sse42.cc
 delete mode 100644 tensorflow/lite/experimental/ruy/ruy/pack_x86.h
 delete mode 100644 tensorflow/lite/experimental/ruy/ruy/path.h
 delete mode 100644 tensorflow/lite/experimental/ruy/ruy/platform.h
 delete mode 100644 tensorflow/lite/experimental/ruy/ruy/pmu.cc
 delete mode 100644 tensorflow/lite/experimental/ruy/ruy/pmu.h
 delete mode 100644 tensorflow/lite/experimental/ruy/ruy/prepack.h
 delete mode 100644 tensorflow/lite/experimental/ruy/ruy/prepacked_cache.cc
 delete mode 100644 tensorflow/lite/experimental/ruy/ruy/prepacked_cache.h
 delete mode 100644 tensorflow/lite/experimental/ruy/ruy/prepacked_cache_test.cc
 delete mode 100644 tensorflow/lite/experimental/ruy/ruy/profiler/BUILD
 delete mode 100644 tensorflow/lite/experimental/ruy/ruy/profiler/README.md
 delete mode 100644 tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.cc
 delete mode 100644 tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h
 delete mode 100644 tensorflow/lite/experimental/ruy/ruy/profiler/profiler.cc
 delete mode 100644 tensorflow/lite/experimental/ruy/ruy/profiler/profiler.h
 delete mode 100644 tensorflow/lite/experimental/ruy/ruy/profiler/test.cc
 delete mode 100644 tensorflow/lite/experimental/ruy/ruy/profiler/test_instrumented_library.cc
 delete mode 100644 tensorflow/lite/experimental/ruy/ruy/profiler/test_instrumented_library.h
 delete mode 100644 tensorflow/lite/experimental/ruy/ruy/profiler/treeview.cc
 delete mode 100644 tensorflow/lite/experimental/ruy/ruy/profiler/treeview.h
 delete mode 100644 tensorflow/lite/experimental/ruy/ruy/ruy.h
 delete mode 100644 tensorflow/lite/experimental/ruy/ruy/ruy_advanced.h
 delete mode 100644 tensorflow/lite/experimental/ruy/ruy/ruy_test.bzl
 delete mode 100644 tensorflow/lite/experimental/ruy/ruy/ruy_test_ext.bzl
 delete mode 100644 tensorflow/lite/experimental/ruy/ruy/side_pair.h
 delete mode 100644 tensorflow/lite/experimental/ruy/ruy/size_util.h
 delete mode 100644 tensorflow/lite/experimental/ruy/ruy/size_util_test.cc
 delete mode 100644 tensorflow/lite/experimental/ruy/ruy/spec.h
 delete mode 100644 tensorflow/lite/experimental/ruy/ruy/test.h
 delete mode 100644 tensorflow/lite/experimental/ruy/ruy/test_fast.cc
 delete mode 100644 tensorflow/lite/experimental/ruy/ruy/test_slow.cc
 delete mode 100644 tensorflow/lite/experimental/ruy/ruy/test_special_specs.cc
 delete mode 100644 tensorflow/lite/experimental/ruy/ruy/thread_pool.cc
 delete mode 100644 tensorflow/lite/experimental/ruy/ruy/thread_pool.h
 delete mode 100644 tensorflow/lite/experimental/ruy/ruy/time.h
 delete mode 100644 tensorflow/lite/experimental/ruy/ruy/trace.cc
 delete mode 100644 tensorflow/lite/experimental/ruy/ruy/trace.h
 delete mode 100644 tensorflow/lite/experimental/ruy/ruy/trmul.cc
 delete mode 100644 tensorflow/lite/experimental/ruy/ruy/trmul.h
 delete mode 100644 tensorflow/lite/experimental/ruy/ruy/trmul_params.h
 delete mode 100644 tensorflow/lite/experimental/ruy/ruy/tune.cc
 delete mode 100644 tensorflow/lite/experimental/ruy/ruy/tune.h
 delete mode 100644 tensorflow/lite/experimental/ruy/ruy/tune_test.cc
 delete mode 100644 tensorflow/lite/experimental/ruy/ruy/tune_tool.cc
 delete mode 100644 tensorflow/lite/experimental/ruy/ruy/wait.cc
 delete mode 100644 tensorflow/lite/experimental/ruy/ruy/wait.h
 delete mode 100644 tensorflow/lite/experimental/ruy/ruy/wait_test.cc
 create mode 100644 third_party/ruy/BUILD
 create mode 100644 third_party/ruy/workspace.bzl

diff --git a/tensorflow/lite/experimental/ruy/CONTRIBUTING.md b/tensorflow/lite/experimental/ruy/CONTRIBUTING.md
deleted file mode 100644
index 654a071648d..00000000000
--- a/tensorflow/lite/experimental/ruy/CONTRIBUTING.md
+++ /dev/null
@@ -1,28 +0,0 @@
-# How to Contribute
-
-We'd love to accept your patches and contributions to this project. There are
-just a few small guidelines you need to follow.
-
-## Contributor License Agreement
-
-Contributions to this project must be accompanied by a Contributor License
-Agreement. You (or your employer) retain the copyright to your contribution;
-this simply gives us permission to use and redistribute your contributions as
-part of the project. Head over to to see
-your current agreements on file or to sign a new one.
-
-You generally only need to submit a CLA once, so if you've already submitted one
-(even if it was for a different project), you probably don't need to do it
-again.
-
-## Code reviews
-
-All submissions, including submissions by project members, require review. We
-use GitHub pull requests for this purpose. Consult
-[GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
-information on using pull requests.
-
-## Community Guidelines
-
-This project follows [Google's Open Source Community
-Guidelines](https://opensource.google/conduct/).
diff --git a/tensorflow/lite/experimental/ruy/README.md b/tensorflow/lite/experimental/ruy/README.md
deleted file mode 100644
index 09b85927d09..00000000000
--- a/tensorflow/lite/experimental/ruy/README.md
+++ /dev/null
@@ -1,24 +0,0 @@
-# The ruy matrix multiplication library
-
-This is not an officially supported Google product.
-
-ruy is a matrix multiplication library. Its focus is to cover the matrix
-multiplication needs of neural network inference engines. Its initial user has
-been TensorFlow Lite, where it is used by default on the ARM CPU architecture.
-
-ruy supports both floating-point and 8bit-integer-quantized matrices.
-
-## Efficiency
-
-ruy is designed to achieve maximal performance not just on very large sizes, as
-is the focus of many established libraries, but on whatever are the actual sizes
-and shapes of matrices most critical in current TensorFlow Lite applications.
-This often means quite small sizes, e.g. 100x100 or even 50x50, and all sorts of
-rectangular shapes.
-
-ruy is currently only optimized for the ARM architectures (both 64-bit and
-32-bit code). Optimization for the Intel x86 architecture is in progress.
-
-ruy is currently optimized only for the following combination of storage orders:
-LHS = row-major, RHS = column-major, destination = column-major. All other
-combinations of storage orders fall back to slow reference code at the moment.
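[Editor's illustration, not part of the patch] The storage-order combination the README above calls the optimized one (row-major LHS, column-major RHS, column-major destination) is easiest to see in a small end-to-end call. Below is a minimal sketch in the spirit of the example.cc removed by this patch, assuming the ruy API of that era (Matrix, MakeSimpleLayout, BasicSpec, Mul) and the pre-move header path; treat the exact signatures as assumptions rather than a definitive reference.

#include <iostream>

#include "tensorflow/lite/experimental/ruy/ruy/ruy.h"

int main() {
  // Row-major 2x2 LHS; column-major 2x2 RHS and destination: the
  // combination described in the README as the optimized one.
  const float lhs_data[] = {1, 2, 3, 4};
  const float rhs_data[] = {1, 2, 3, 4};
  float dst_data[4];

  ruy::Context context;

  ruy::Matrix<float> lhs;
  ruy::MakeSimpleLayout(2, 2, ruy::Order::kRowMajor, &lhs.layout);
  lhs.data = lhs_data;

  ruy::Matrix<float> rhs;
  ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, &rhs.layout);
  rhs.data = rhs_data;

  ruy::Matrix<float> dst;
  ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, &dst.layout);
  dst.data = dst_data;

  // A default-constructed BasicSpec means a plain matrix multiplication,
  // with no fused bias addition or clamping.
  ruy::BasicSpec<float, float> spec;
  ruy::Mul<ruy::kAllPaths>(lhs, rhs, spec, &context, &dst);

  // Column-major destination: dst_data holds the result column by column.
  for (int i = 0; i < 4; i++) {
    std::cout << dst_data[i] << (i == 3 ? "\n" : " ");
  }
  return 0;
}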
diff --git a/tensorflow/lite/experimental/ruy/WORKSPACE b/tensorflow/lite/experimental/ruy/WORKSPACE deleted file mode 100644 index 8364d8047b1..00000000000 --- a/tensorflow/lite/experimental/ruy/WORKSPACE +++ /dev/null @@ -1,17 +0,0 @@ -# Copyright 2020 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Workspace file for the Ruy project. - -workspace(name = "com_google_ruy") diff --git a/tensorflow/lite/experimental/ruy/ruy/BUILD b/tensorflow/lite/experimental/ruy/ruy/BUILD deleted file mode 100644 index c808c3ec063..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/BUILD +++ /dev/null @@ -1,954 +0,0 @@ -# Ruy is not BLAS - -load(":build_defs.bzl", "ruy_copts_avx2", "ruy_copts_avxvnni", "ruy_copts_base", "ruy_copts_skylake", "ruy_copts_sse42") -load(":ruy_test_ext.bzl", "ruy_test_ext_defines", "ruy_test_ext_deps") -load(":ruy_test.bzl", "ruy_benchmark", "ruy_test") - -package( - default_visibility = ["//visibility:public"], - licenses = ["notice"], # Apache 2.0 -) - -config_setting( - name = "windows", - values = {"cpu": "x64_windows"}, -) - -config_setting( - name = "armeabi-v7a", - values = {"cpu": "armeabi-v7a"}, -) - -config_setting( - name = "x86_64", - values = {"cpu": "k8"}, -) - -config_setting( - name = "optimized", - values = { - "compilation_mode": "opt", - }, - visibility = ["//visibility:public"], -) - -cc_library( - name = "platform", - hdrs = ["platform.h"], - copts = ruy_copts_base(), -) - -cc_library( - name = "check_macros", - hdrs = ["check_macros.h"], - copts = ruy_copts_base(), -) - -cc_test( - name = "check_macros_test", - srcs = ["check_macros_test.cc"], - copts = ruy_copts_base(), - deps = [ - ":check_macros", - "@com_google_googletest//:gtest", - ], -) - -cc_library( - name = "opt_set", - hdrs = ["opt_set.h"], - copts = ruy_copts_base(), -) - -cc_library( - name = "time", - hdrs = ["time.h"], - copts = ruy_copts_base(), -) - -cc_library( - name = "wait", - srcs = ["wait.cc"], - hdrs = ["wait.h"], - copts = ruy_copts_base(), - deps = [":time"], -) - -cc_test( - name = "wait_test", - srcs = ["wait_test.cc"], - deps = [ - ":platform", - ":wait", - "@com_google_googletest//:gtest", - ], -) - -cc_library( - name = "size_util", - hdrs = ["size_util.h"], - copts = ruy_copts_base(), - deps = [":check_macros"], -) - -cc_test( - name = "size_util_test", - srcs = ["size_util_test.cc"], - deps = [ - ":size_util", - "@com_google_googletest//:gtest", - ], -) - -cc_library( - name = "tune", - srcs = [ - "tune.cc", - ], - hdrs = [ - "tune.h", - ], - copts = ruy_copts_base(), - deps = [ - ":opt_set", - ":platform", - ":time", - ], -) - -cc_library( - name = "prepacked_cache", - srcs = [ - "prepacked_cache.cc", - ], - hdrs = [ - "prepacked_cache.h", - ], - copts = ruy_copts_base(), - deps = [ - ":allocator", - ":matrix", - ":opt_set", - ":platform", - ":time", - "//tensorflow/lite/experimental/ruy/ruy/profiler:instrumentation", - ], -) - -cc_test( - name = "tune_test", - srcs = ["tune_test.cc"], - deps = [ - ":tune", - "@com_google_googletest//:gtest", - ], -) - 
-cc_test( - name = "prepacked_cache_test", - srcs = ["prepacked_cache_test.cc"], - deps = [ - ":prepacked_cache", - ":ruy", - ":time", - "@com_google_googletest//:gtest", - ], -) - -cc_binary( - name = "tune_tool", - srcs = ["tune_tool.cc"], - deps = [ - ":tune", - ], -) - -cc_library( - name = "allocator", - srcs = [ - "allocator.cc", - ], - hdrs = [ - "allocator.h", - ], - copts = ruy_copts_base(), - deps = [ - ":check_macros", - ":size_util", - ], -) - -cc_test( - name = "allocator_test", - srcs = ["allocator_test.cc"], - deps = [ - ":allocator", - "@com_google_googletest//:gtest", - ], -) - -cc_library( - name = "side_pair", - hdrs = ["side_pair.h"], - copts = ruy_copts_base(), - deps = [":check_macros"], -) - -cc_library( - name = "block_map", - srcs = [ - "block_map.cc", - ], - hdrs = [ - "block_map.h", - ], - copts = ruy_copts_base(), - deps = [ - ":check_macros", - ":opt_set", - ":path", - ":side_pair", - ":size_util", - "//tensorflow/lite/experimental/ruy/ruy/profiler:instrumentation", - ], -) - -cc_test( - name = "block_map_test", - srcs = ["block_map_test.cc"], - deps = [ - ":block_map", - ":cpu_cache_size", - ":path", - ":side_pair", - "@com_google_googletest//:gtest", - ], -) - -cc_library( - name = "blocking_counter", - srcs = [ - "blocking_counter.cc", - ], - hdrs = [ - "blocking_counter.h", - ], - copts = ruy_copts_base(), - deps = [ - ":check_macros", - ":wait", - ], -) - -cc_library( - name = "thread_pool", - srcs = [ - "thread_pool.cc", - ], - hdrs = [ - "thread_pool.h", - ], - copts = ruy_copts_base(), - deps = [ - ":blocking_counter", - ":check_macros", - ":wait", - ], -) - -cc_library( - name = "detect_arm", - srcs = [ - "detect_arm.cc", - ], - hdrs = [ - "detect_arm.h", - ], - copts = ruy_copts_base(), -) - -cc_library( - name = "detect_x86", - srcs = [ - "detect_x86.cc", - ], - hdrs = [ - "detect_x86.h", - ], - copts = ruy_copts_base(), - deps = [ - ":platform", - ], -) - -cc_library( - name = "path", - hdrs = ["path.h"], - copts = ruy_copts_base(), - deps = [ - ":platform", - ":size_util", - ], -) - -cc_library( - name = "cpu_cache_size", - hdrs = ["cpu_cache_size.h"], - copts = ruy_copts_base(), - deps = [ - ":path", - ":platform", - ], -) - -cc_library( - name = "trace", - srcs = [ - "trace.cc", - ], - hdrs = [ - "trace.h", - ], - copts = ruy_copts_base(), - deps = [ - ":block_map", - ":check_macros", - ":side_pair", - ":time", - ], -) - -cc_library( - name = "matrix", - hdrs = ["matrix.h"], - copts = ruy_copts_base(), - deps = [":check_macros"], -) - -cc_library( - name = "spec", - hdrs = ["spec.h"], - copts = ruy_copts_base(), - deps = [ - ":cpu_cache_size", - ":matrix", - ], -) - -cc_library( - name = "internal_matrix", - hdrs = ["internal_matrix.h"], - copts = ruy_copts_base(), - deps = [ - ":check_macros", - ":common", - ":matrix", - ":size_util", - ], -) - -cc_library( - name = "common", - hdrs = [ - "common.h", - ], - copts = ruy_copts_base(), - deps = [ - ":check_macros", - ":matrix", - ":opt_set", - ":path", - ":platform", - ], -) - -cc_library( - name = "kernel_common", - hdrs = [ - "kernel.h", - "kernel_arm.h", - "kernel_common.h", - "kernel_x86.h", - ], - copts = ruy_copts_base(), - deps = [ - ":check_macros", - ":common", - ":internal_matrix", - ":matrix", - ":opt_set", - ":path", - ":platform", - ":side_pair", - ":size_util", - ":spec", - ":tune", - "//tensorflow/lite/experimental/ruy/ruy/profiler:instrumentation", - ], -) - -cc_library( - name = "pack_common", - hdrs = [ - "pack.h", - "pack_arm.h", - "pack_common.h", - "pack_x86.h", - ], - copts = 
ruy_copts_base(), - deps = [ - ":check_macros", - ":common", - ":internal_matrix", - ":matrix", - ":opt_set", - ":path", - ":platform", - ":tune", - "//tensorflow/lite/experimental/ruy/ruy/profiler:instrumentation", - ], -) - -cc_library( - name = "kernel_arm", - srcs = [ - "kernel_arm32.cc", - "kernel_arm64.cc", - ], - copts = ruy_copts_base(), - deps = [ - ":common", - ":kernel_common", - ":opt_set", - ":platform", - "//tensorflow/lite/experimental/ruy/ruy/profiler:instrumentation", - ], -) - -cc_library( - name = "pack_arm", - srcs = [ - "pack_arm.cc", - ], - copts = ruy_copts_base(), - deps = [ - ":common", - ":opt_set", - ":pack_common", - ":platform", - "//tensorflow/lite/experimental/ruy/ruy/profiler:instrumentation", - ], -) - -# AVX-512 compilation units. -# -# These must use the same compiler options. -RUY_COPTS_BUILT_FOR_AVX512 = ruy_copts_base() + ruy_copts_skylake() - -cc_library( - name = "kernel_avx512", - srcs = [ - "kernel_avx512.cc", - ], - copts = RUY_COPTS_BUILT_FOR_AVX512, - deps = [ - ":check_macros", - ":kernel_common", - ":opt_set", - ":platform", - "//tensorflow/lite/experimental/ruy/ruy/profiler:instrumentation", - ], -) - -cc_library( - name = "pack_avx512", - srcs = [ - "pack_avx512.cc", - ], - copts = RUY_COPTS_BUILT_FOR_AVX512, - deps = [ - ":check_macros", - ":matrix", - ":opt_set", - ":pack_common", - ":path", - ":platform", - "//tensorflow/lite/experimental/ruy/ruy/profiler:instrumentation", - ], -) - -cc_library( - name = "have_built_path_for_avx512", - srcs = [ - "have_built_path_for_avx512.cc", - ], - hdrs = [ - "have_built_path_for.h", - ], - copts = RUY_COPTS_BUILT_FOR_AVX512, - deps = [ - ":opt_set", - ":platform", - ], -) -# End: AVX-512 compilation units. - -# AVX2 compilation units. -# -# These must use the same compiler options. -RUY_COPTS_BUILT_FOR_AVX2 = ruy_copts_base() + ruy_copts_avx2() - -cc_library( - name = "kernel_avx2", - srcs = [ - "kernel_avx2.cc", - ], - copts = RUY_COPTS_BUILT_FOR_AVX2, - deps = [ - ":check_macros", - ":kernel_common", - ":opt_set", - ":platform", - "//tensorflow/lite/experimental/ruy/ruy/profiler:instrumentation", - ], -) - -cc_library( - name = "pack_avx2", - srcs = [ - "pack_avx2.cc", - ], - copts = RUY_COPTS_BUILT_FOR_AVX2, - deps = [ - ":check_macros", - ":matrix", - ":opt_set", - ":pack_common", - ":path", - ":platform", - "//tensorflow/lite/experimental/ruy/ruy/profiler:instrumentation", - ], -) - -cc_library( - name = "have_built_path_for_avx2", - srcs = [ - "have_built_path_for_avx2.cc", - ], - hdrs = [ - "have_built_path_for.h", - ], - copts = RUY_COPTS_BUILT_FOR_AVX2, - deps = [ - ":opt_set", - ":platform", - ], -) -# End: AVX2 compilation units. - -# SSE42 compilation units. -# -# TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. -# Optimization is not finished. In particular the dimensions of the kernel -# blocks can be changed as desired. -# -# These must use the same compiler options. 
-RUY_COPTS_BUILT_FOR_SSE42 = ruy_copts_base() + ruy_copts_sse42() - -cc_library( - name = "kernel_sse42", - srcs = [ - "kernel_sse42.cc", - ], - copts = RUY_COPTS_BUILT_FOR_SSE42, - deps = [ - ":check_macros", - ":kernel_common", - ":opt_set", - ":platform", - "//tensorflow/lite/experimental/ruy/ruy/profiler:instrumentation", - ], -) - -cc_library( - name = "pack_sse42", - srcs = [ - "pack_sse42.cc", - ], - copts = RUY_COPTS_BUILT_FOR_SSE42, - deps = [ - ":check_macros", - ":matrix", - ":opt_set", - ":pack_common", - ":path", - ":platform", - "//tensorflow/lite/experimental/ruy/ruy/profiler:instrumentation", - ], -) - -cc_library( - name = "have_built_path_for_sse42", - srcs = [ - "have_built_path_for_sse42.cc", - ], - hdrs = [ - "have_built_path_for.h", - ], - copts = RUY_COPTS_BUILT_FOR_SSE42, - deps = [ - ":opt_set", - ":platform", - ], -) -# End: SSE42 compilation units. - -# AVX-VNNI compilation units. -# -# TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. -# Optimization is not finished. In particular the dimensions of the kernel -# blocks can be changed as desired. -# -# These must use the same compiler options. -RUY_COPTS_BUILT_FOR_AVX_VNNI = ruy_copts_base() + ruy_copts_avxvnni() - -cc_library( - name = "kernel_avxvnni", - srcs = [ - "kernel_avxvnni.cc", - ], - copts = RUY_COPTS_BUILT_FOR_AVX_VNNI, - deps = [ - ":check_macros", - ":kernel_common", - ":opt_set", - ":platform", - "//tensorflow/lite/experimental/ruy/ruy/profiler:instrumentation", - ], -) - -cc_library( - name = "pack_avxvnni", - srcs = [ - "pack_avxvnni.cc", - ], - copts = RUY_COPTS_BUILT_FOR_AVX_VNNI, - deps = [ - ":check_macros", - ":matrix", - ":opt_set", - ":pack_common", - ":path", - ":platform", - "//tensorflow/lite/experimental/ruy/ruy/profiler:instrumentation", - ], -) - -cc_library( - name = "have_built_path_for_avxvnni", - srcs = [ - "have_built_path_for_avxvnni.cc", - ], - hdrs = [ - "have_built_path_for.h", - ], - copts = RUY_COPTS_BUILT_FOR_AVX_VNNI, - deps = [ - ":opt_set", - ":platform", - ], -) -# End: AVX-VNNI compilation units. 
- -cc_library( - name = "kernel", - hdrs = [ - "kernel.h", - "kernel_common.h", - ], - copts = ruy_copts_base(), - deps = [ - ":check_macros", - ":common", - ":internal_matrix", - ":kernel_arm", # fixdeps: keep - ":kernel_avx2", # fixdeps: keep - ":kernel_avx512", # fixdeps: keep - ":kernel_avxvnni", # fixdeps: keep - ":kernel_common", - ":kernel_sse42", # fixdeps: keep - ":matrix", - ":opt_set", - ":path", - ":platform", - ":side_pair", - ":size_util", - ":spec", - ":tune", - "//tensorflow/lite/experimental/ruy/ruy/profiler:instrumentation", - ], -) - -cc_library( - name = "pack", - hdrs = [ - "pack.h", - "pack_common.h", - ], - copts = ruy_copts_base(), - deps = [ - ":check_macros", - ":common", - ":internal_matrix", - ":matrix", - ":opt_set", - ":pack_arm", # fixdeps: keep - ":pack_avx2", # fixdeps: keep - ":pack_avx512", # fixdeps: keep - ":pack_avxvnni", # fixdeps: keep - ":pack_common", - ":pack_sse42", # fixdeps: keep - ":path", - ":platform", - ":tune", - "//tensorflow/lite/experimental/ruy/ruy/profiler:instrumentation", - ], -) - -cc_library( - name = "have_built_path_for", - hdrs = [ - "have_built_path_for.h", - ], - deps = [ - ":have_built_path_for_avx2", - ":have_built_path_for_avx512", - ":have_built_path_for_avxvnni", - ":have_built_path_for_sse42", - ":platform", - ], -) - -cc_library( - name = "context", - srcs = [ - "context.cc", - ], - hdrs = [ - "context.h", - ], - copts = ruy_copts_base(), - deps = [ - ":allocator", - ":check_macros", - ":detect_arm", - ":detect_x86", - ":have_built_path_for", - ":path", - ":platform", - ":prepacked_cache", - ":thread_pool", - ":trace", - ":tune", - ], -) - -cc_test( - name = "context_test", - srcs = ["context_test.cc"], - deps = [ - ":context", - ":path", - ":platform", - "@com_google_googletest//:gtest", - ], -) - -cc_library( - name = "trmul_params", - hdrs = ["trmul_params.h"], - copts = ruy_copts_base(), - deps = [ - ":internal_matrix", - ":side_pair", - ":tune", - ], -) - -cc_library( - name = "trmul", - srcs = ["trmul.cc"], - hdrs = ["trmul.h"], - copts = ruy_copts_base(), - deps = [ - ":allocator", - ":block_map", - ":check_macros", - ":common", - ":context", - ":internal_matrix", - ":matrix", - ":opt_set", - ":side_pair", - ":size_util", - ":spec", - ":thread_pool", - ":trace", - ":trmul_params", - ":tune", - "//tensorflow/lite/experimental/ruy/ruy/profiler:instrumentation", - ], -) - -# The main library. -cc_library( - name = "ruy", - srcs = [ - "dispatch.h", - "prepack.h", - ], - hdrs = [ - "ruy.h", - "ruy_advanced.h", - ], - copts = ruy_copts_base(), - deps = [ - ":check_macros", - ":common", - ":context", - ":internal_matrix", - ":kernel", - ":matrix", - ":opt_set", - ":pack", - ":path", - ":prepacked_cache", - ":side_pair", - ":size_util", - ":spec", - ":trmul", - ":trmul_params", - ":tune", - "//tensorflow/lite/experimental/ruy/ruy/profiler:instrumentation", - ], -) - -# Usage examples. -cc_binary( - name = "example", - srcs = ["example.cc"], - deps = [":ruy"], -) - -# Usage examples of the advanced API. -cc_binary( - name = "example_advanced", - srcs = ["example_advanced.cc"], - deps = [":ruy"], -) - -# Small library to query PMU counters, for benchmark only -cc_library( - name = "pmu", - testonly = True, - srcs = ["pmu.cc"], - hdrs = ["pmu.h"], - copts = ruy_copts_base(), - deps = [":check_macros"], -) - -# Testing framework. 
-cc_library( - name = "test_lib", - testonly = True, - hdrs = ["test.h"], - copts = ruy_copts_base(), - # need defines, not copts, because it's controlling a header, test.h - defines = ruy_test_ext_defines(), - linkopts = select({ - ":windows": [], - "//conditions:default": ["-lm"], - }), - deps = [ - ":matrix", - ":pmu", - ":ruy", - ":spec", - ":time", - "@com_google_googletest//:gtest", - ":platform", - "//tensorflow/lite/experimental/ruy/ruy/profiler:profiler", - ] + ruy_test_ext_deps(), -) - -ruy_benchmark( - name = "benchmark", - srcs = ["benchmark.cc"], - copts = ruy_copts_base(), - lhs_rhs_accum_dst = [ - ("f32", "f32", "f32", "f32"), - ("u8", "u8", "i32", "u8"), - ("i8", "i8", "i32", "u8"), - ("i8", "i8", "i32", "i8"), - ("u8", "u8", "i32", "i16"), - ("i8", "i8", "i32", "i32"), - ], - deps = [ - "//tensorflow/lite/experimental/ruy/ruy:test_lib", - "//tensorflow/lite/experimental/ruy/ruy/profiler:instrumentation", - ], -) - -ruy_test( - name = "test_fast", - srcs = ["test_fast.cc"], - copts = ruy_copts_base(), - lhs_rhs_accum_dst = [ - ("f32", "f32", "f32", "f32"), - ("f64", "f32", "f64", "f32"), - ("f32", "f64", "f64", "f64"), - ("u8", "u8", "i32", "u8"), - ("i8", "i8", "i32", "i8"), - ("i8", "u8", "i32", "i8"), - ("u8", "u8", "i32", "i16"), - ("i8", "i8", "i32", "i32"), - ("i8", "u8", "i32", "i32"), - ], - deps = [ - "//tensorflow/lite/experimental/ruy/ruy:test_lib", - "@com_google_googletest//:gtest_main", - ], -) - -ruy_test( - name = "test_slow", - srcs = ["test_slow.cc"], - copts = ruy_copts_base(), - lhs_rhs_accum_dst = [ - ("f32", "f32", "f32", "f32"), - ("u8", "u8", "i32", "u8"), - ("i8", "i8", "i32", "i8"), - ("u8", "u8", "i32", "i16"), - ("i8", "i8", "i32", "i32"), - ], - tags = ["slow"], - deps = [ - "//tensorflow/lite/experimental/ruy/ruy:test_lib", - "@com_google_googletest//:gtest_main", - ], -) - -ruy_test( - name = "test_special_specs", - srcs = ["test_special_specs.cc"], - copts = ruy_copts_base(), - lhs_rhs_accum_dst = [ - ("f32", "f32", "f32", "f32"), - ("u8", "u8", "i32", "u8"), - ("u8", "u8", "i32", "i16"), - ], - deps = [ - "//tensorflow/lite/experimental/ruy/ruy:test_lib", - "@com_google_googletest//:gtest_main", - ], -) diff --git a/tensorflow/lite/experimental/ruy/ruy/allocator.cc b/tensorflow/lite/experimental/ruy/ruy/allocator.cc deleted file mode 100644 index 2c507561f2f..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/allocator.cc +++ /dev/null @@ -1,51 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/lite/experimental/ruy/ruy/allocator.h" - -#include -#include - -#ifdef _WIN32 -#include -#endif - -namespace ruy { - -namespace detail { - -void *SystemAlignedAlloc(std::ptrdiff_t num_bytes) { -#ifdef _WIN32 - return _aligned_malloc(num_bytes, kMinimumBlockAlignment); -#else - void *ptr; - if (posix_memalign(&ptr, kMinimumBlockAlignment, num_bytes)) { - return nullptr; - } - return ptr; -#endif -} - -void SystemAlignedFree(void *ptr) { -#ifdef _WIN32 - _aligned_free(ptr); -#else - free(ptr); -#endif -} - -} // namespace detail - -} // namespace ruy diff --git a/tensorflow/lite/experimental/ruy/ruy/allocator.h b/tensorflow/lite/experimental/ruy/ruy/allocator.h deleted file mode 100644 index 56aa0eef8f9..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/allocator.h +++ /dev/null @@ -1,185 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_ALLOCATOR_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_ALLOCATOR_H_ - -#include -#include -#include -#include - -#include "tensorflow/lite/experimental/ruy/ruy/check_macros.h" -#include "tensorflow/lite/experimental/ruy/ruy/size_util.h" - -namespace ruy { - -namespace detail { - -inline void* VoidPtrAdd(void* p, std::ptrdiff_t offset) { - RUY_DCHECK(p); - std::uintptr_t addr = reinterpret_cast(p) + offset; - return reinterpret_cast(addr); -} - -// Minimum alignment for blocks. -// -// Considerations: -// - This needs to be at least the alignment of any usual data type. -// - It's useful that this is at least the size of a cache line to limit -// possible cache side effects (if only on performance behavior). -// - It's useful that this is at least the size of SIMD registers, as -// some SIMD instruction sets have at least performance behavior -// differences (e.g. NEON) or even different requirements (e.g. SSE) -// based on that. -// - It's useful that this is at least the size of an "exclusive reservation -// granule" on ARM, meaning that if we use this Allocator to allocate -// an atomic variable, there will be no side effects from other things -// contending for exclusive/atomic memory accesses to it. While the -// ARM reference manual mentions that this granule size may be as large -// as 2048 bytes, in practice we observe it to be 64 bytes. It can -// be queried cheaply, at runtime, from userspace, if needed. -static constexpr std::ptrdiff_t kMinimumBlockAlignment = 64; - -// Primitive allocation functions obtaining aligned memory from the -// operating system. -void* SystemAlignedAlloc(std::ptrdiff_t num_bytes); -void SystemAlignedFree(void* ptr); - -// Specialized allocator designed to converge to a steady-state where all -// allocations are bump-ptr allocations from an already-allocated buffer. 
-// -// To support these constraints, this allocator only supports two -// operations. -// - AllocateAlignedBytes: allocates a pointer to storage of a specified -// size, which must be aligned to kMinimumBlockAlignment. -// - FreeAll: frees all previous allocations (but retains the internal -// buffer to minimize future calls into the system allocator). -// -// This class is specialized for supporting just those two operations -// under this specific steady-state usage pattern. Extending this class -// with new allocation interfaces that don't fit that pattern is probably not -// the right choice. Instead, build a new class on top of -// SystemAlignedAlloc/SystemAlignedFree. -// -// All operations happen on aligned blocks for simplicity. -class AlignedAllocator { - public: - void operator=(const AlignedAllocator&) = delete; - ~AlignedAllocator() { - FreeAll(); - SystemAlignedFree(ptr_); - } - - void* AllocateAlignedBytes(std::ptrdiff_t num_bytes) { - RUY_DCHECK_GT(num_bytes, 0); - RUY_DCHECK((num_bytes & (kMinimumBlockAlignment - 1)) == 0); - if (void* p = AllocateFast(num_bytes)) { - return p; - } - return AllocateSlow(num_bytes); - } - - void FreeAll() { - current_ = 0; - if (fallback_blocks_.empty()) { - return; - } - - // No rounding-up of the size means linear instead of logarithmic - // bound on the number of allocation in some worst-case calling patterns. - // This is considered worth it because minimizing memory usage is important - // and actual calling patterns in applications that we care about still - // reach the no-further-allocations steady state in a small finite number - // of iterations. - std::ptrdiff_t new_size = size_ + fallback_blocks_total_size_; - SystemAlignedFree(ptr_); - ptr_ = SystemAlignedAlloc(new_size); - size_ = new_size; - - for (void* p : fallback_blocks_) { - SystemAlignedFree(p); - } - fallback_blocks_.clear(); - fallback_blocks_total_size_ = 0; - } - - private: - void* AllocateFast(std::ptrdiff_t num_bytes) { - if (current_ + num_bytes > size_) { - return nullptr; - } - void* ret = VoidPtrAdd(ptr_, current_); - current_ += num_bytes; - return ret; - } - - void* AllocateSlow(std::ptrdiff_t num_bytes) { - void* p = SystemAlignedAlloc(num_bytes); - fallback_blocks_total_size_ += num_bytes; - fallback_blocks_.push_back(p); - return p; - } - - // Theory of operation: - // - // - ptr_, current_, and size_ implement a basic bump-ptr allocator. - // - // - in AllocateAlignedBytes, the fast path is just a bump-ptr - // allocation. If our bump-ptr allocator doesn't have enough space for an - // allocation, then we allocate a block from the system allocator to - // service the allocation request. We save that block in fallback_blocks_ - // and track the total size of the fallback blocks in - // fallback_blocks_total_size_. - // - // - in FreeAll, the fast path just resets the bump-ptr allocator. If - // there are any fallback blocks, we free them and reallocate the - // bump-ptr allocator's buffer so that the next sequence of allocations - // will hopefully not need any fallback blocks. - void* ptr_ = nullptr; - std::ptrdiff_t current_ = 0; - std::ptrdiff_t size_ = 0; - std::vector fallback_blocks_; - std::ptrdiff_t fallback_blocks_total_size_ = 0; -}; - -} // namespace detail - -// The main Allocator class, with a convenient interface for allocating a -// typed buffer. 
-class Allocator { - public: - void* AllocateBytes(std::ptrdiff_t num_bytes) { - if (num_bytes == 0) { - return nullptr; - } - return aligned.AllocateAlignedBytes( - round_up_pot(num_bytes, detail::kMinimumBlockAlignment)); - } - template - void Allocate(std::ptrdiff_t count, Pointer* out) { - using T = typename std::pointer_traits::element_type; - *out = static_cast(AllocateBytes(count * sizeof(T))); - } - - void FreeAll() { aligned.FreeAll(); } - - private: - detail::AlignedAllocator aligned; -}; - -} // namespace ruy - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_ALLOCATOR_H_ diff --git a/tensorflow/lite/experimental/ruy/ruy/allocator_test.cc b/tensorflow/lite/experimental/ruy/ruy/allocator_test.cc deleted file mode 100644 index 1584b86b4cc..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/allocator_test.cc +++ /dev/null @@ -1,103 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/lite/experimental/ruy/ruy/allocator.h" - -#include - -namespace ruy { -namespace { - -TEST(AllocatorTest, ReturnsValidMemory) { - Allocator allocator; - int *p; - allocator.Allocate(1, &p); - ASSERT_NE(p, nullptr); - - // If this is bogus memory, ASan will cause this test to fail. - *p = 42; - - allocator.FreeAll(); -} - -TEST(AllocatorTest, NoLeak) { - Allocator allocator; - // Allocate and free some ridiculously large total amount of memory, so - // that a leak will hopefully cause some sort of resource exhaustion. - // - // Despite the large number of allocations, this test is actually quite - // fast, since our fast-path allocation logic is very fast. - constexpr int kNumAllocations = 100 * 1024; - constexpr int kAllocationSize = 1024 * 1024; - for (int i = 0; i < kNumAllocations; i++) { - char *p; - allocator.Allocate(kAllocationSize, &p); - allocator.FreeAll(); - } -} - -TEST(AllocatorTest, IncreasingSizes) { - Allocator allocator; - // Allocate sizes that increase by small amounts across FreeAll calls. - for (int i = 1; i < 100 * 1024; i++) { - char *p; - allocator.Allocate(i, &p); - allocator.FreeAll(); - } -} - -TEST(AllocatorTest, ManySmallAllocations) { - Allocator allocator; - // Allocate many small allocations between FreeAll calls. - for (int i = 0; i < 10 * 1024; i += 100) { - for (int j = 0; j < i; j++) { - char *p; - allocator.Allocate(1, &p); - } - allocator.FreeAll(); - } -} - -TEST(AllocatorTest, DestructorHandlesMainBumpPtr) { - // This is a white-box test. - Allocator allocator; - allocator.AllocateBytes(1); - allocator.FreeAll(); - // After the call to FreeAll, the allocator will consolidate all of the memory - // into the main bump-ptr allocator's block, which we then expect to be freed - // in the destructor. - // - // We have no test assertions -- we primarily expect that this trigger a leak - // checker and cause the test to fail. -} - -TEST(AllocatorTest, DestructorHandlesFallbackBlocks) { - // This is a white-box test. 
- Allocator allocator; - // Since we just created the allocator, this will allocate a fallback block, - // which we then expect to be freed in the destructor. - // - // We have no test assertions -- we primarily expect that this trigger a leak - // checker and cause the test to fail. - allocator.AllocateBytes(1); -} - -} // namespace -} // namespace ruy - -int main(int argc, char **argv) { - ::testing::InitGoogleTest(&argc, argv); - return RUN_ALL_TESTS(); -} diff --git a/tensorflow/lite/experimental/ruy/ruy/benchmark.cc b/tensorflow/lite/experimental/ruy/ruy/benchmark.cc deleted file mode 100644 index 406345cec06..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/benchmark.cc +++ /dev/null @@ -1,196 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include -#include - -#include "tensorflow/lite/experimental/ruy/ruy/test.h" - -namespace ruy { - -using LhsScalar = RUY_TEST_LHSSCALAR; -using RhsScalar = RUY_TEST_RHSSCALAR; -using AccumScalar = RUY_TEST_ACCUMSCALAR; -using DstScalar = RUY_TEST_DSTSCALAR; -using TestSetType = - TestSet>; - -struct BenchmarkShape { - int rows; - int depth; - int cols; - int symm_lhs; - int symm_rhs; -}; - -template -std::vector>> BenchmarkRCC( - const BenchmarkShape& shape) { - TestSetType test_set; - test_set.rows = shape.rows; - test_set.depth = shape.depth; - test_set.cols = shape.cols; - test_set.lhs_order = Order::kRowMajor; - test_set.rhs_order = Order::kColMajor; - test_set.dst_order = Order::kColMajor; - test_set.layout_style = LayoutStyle::kPackedLinear; - test_set.benchmark = true; - const int asymmetry_lhs = shape.symm_lhs ? 0 : 1; - const int asymmetry_rhs = shape.symm_rhs ? 
0 : 1; - test_set.lhs_zero_point = SymmetricZeroPoint() + asymmetry_lhs; - test_set.rhs_zero_point = SymmetricZeroPoint() + asymmetry_rhs; - test_set.use_specified_zero_points = true; - test_set.perchannel = GetBoolEnvVarOrFalse("PERCHANNEL"); - test_set.benchmark_prepack_lhs = GetBoolEnvVarOrFalse("PREPACK_LHS"); - test_set.benchmark_prepack_rhs = GetBoolEnvVarOrFalse("PREPACK_RHS"); - test_set.Run(); - return std::move(test_set.results); -} - -std::vector ParseCommaSeparatedInts( - const std::string& comma_separated_ints) { - std::vector result; - for (std::size_t pos = 0; pos < comma_separated_ints.size();) { - std::size_t delim_pos = comma_separated_ints.find(',', pos); - if (delim_pos == std::string::npos) { - delim_pos = comma_separated_ints.size(); - } - result.push_back( - std::stoi(comma_separated_ints.substr(pos, delim_pos - pos))); - pos = delim_pos + 1; - } - return result; -} - -void Benchmark() { - const bool symm_lhs = std::is_floating_point::value || - GetBoolEnvVarOrFalse("SYMM_LHS"); - const bool symm_rhs = std::is_floating_point::value || - GetBoolEnvVarOrFalse("SYMM_RHS"); - const bool benchmark_cubic = GetBoolEnvVarOrFalse("RUY_BENCHMARK_CUBIC") || - GetBoolEnvVarOrFalse("RUY_BENCHMARK_CUBIC_LIST"); - const int explicit_rows = GetIntEnvVarOrZero("ROWS"); - const int explicit_cols = GetIntEnvVarOrZero("COLS"); - const int explicit_depth = GetIntEnvVarOrZero("DEPTH"); - - std::vector shapes; - - if (benchmark_cubic) { - std::vector sizes; - const char* benchmark_cubic_list_env = getenv("RUY_BENCHMARK_CUBIC_LIST"); - if (benchmark_cubic_list_env) { - sizes = ParseCommaSeparatedInts(benchmark_cubic_list_env); - } else { - // Often 8 is used for this multiplier, but to check teeny sizes one can - // use 1. - static constexpr int cubic_size_multiplier = 8; - for (int i = 2 * cubic_size_multiplier; - i <= (512 * cubic_size_multiplier); i *= 2) { - sizes.push_back(i); - if (i < (512 * cubic_size_multiplier)) { - sizes.push_back(i * 3 / 2); - } - } - } - for (int i : sizes) { - BenchmarkShape shape; - // Even in cubic mode, one may still override an individual dimension - // to allow testing a batch of rectangular sizes. - shape.rows = explicit_rows ? explicit_rows : i; - shape.cols = explicit_cols ? explicit_cols : i; - shape.depth = explicit_depth ? 
explicit_depth : i; - shape.symm_lhs = symm_lhs; - shape.symm_rhs = symm_rhs; - shapes.push_back(shape); - } - } else { - BenchmarkShape shape; - shape.rows = explicit_rows; - shape.cols = explicit_cols; - shape.depth = explicit_depth; - if (!shape.rows || !shape.depth || !shape.cols) { - fprintf(stderr, - "Please specify positive sizes with these env vars: ROWS, DEPTH, " - "COLS.\n"); - exit(1); - } - shape.symm_lhs = symm_lhs; - shape.symm_rhs = symm_rhs; - shapes.push_back(shape); - } - - for (int i = 0; i < shapes.size(); i++) { - const auto& shape = shapes[i]; - const auto& results = BenchmarkRCC(shape); - if (i == 0) { - if (benchmark_cubic) { - printf("size"); - for (const auto& result : results) { - if (results.size() > 1) { - printf(",%s:Gop/s", PathName(*result).c_str()); - } else { - printf(",Gop/s"); - } - if (GetBoolEnvVarOrFalse("RUY_BENCHMARK_PMU")) { - printf( - ",l1_refill,l2_refill,l3_refill,l1tlb_refill,l2tlb_refill," - "mispred,frontend_stall,backend_stall"); - } - } - printf("\n"); - } else { - printf("path,shape,Gop/s\n"); - } - fflush(stdout); - } - if (benchmark_cubic) { - printf("%d", shape.rows); - for (const auto& result : results) { - printf(",%.4g", 2.0e-9 * shape.rows * shape.cols * shape.depth / - result->latency); - if (GetBoolEnvVarOrFalse("RUY_BENCHMARK_PMU")) { - printf(",%.3g,%.3g,%.3g,%.3g,%.3g,%.3g,%.3g,%.3g", - result->l1_refill_rate, result->l2_refill_rate, - result->l3_refill_rate, result->l1tlb_refill_rate, - result->l2tlb_refill_rate, result->mispred_rate, - result->frontend_stall_rate, result->backend_stall_rate); - } - } - printf("\n"); - fflush(stdout); - } else { - for (const auto& result : results) { - printf( - "%s,%dx%dx%d,%.4g", PathName(*result).c_str(), shape.rows, - shape.depth, shape.cols, - 2.0e-9 * shape.rows * shape.cols * shape.depth / result->latency); - if (GetBoolEnvVarOrFalse("RUY_BENCHMARK_PMU")) { - printf(",%.3g,%.3g,%.3g,%.3g,%.3g,%.3g,%.3g,%.3g", - result->l1_refill_rate, result->l2_refill_rate, - result->l3_refill_rate, result->l1tlb_refill_rate, - result->l2tlb_refill_rate, result->mispred_rate, - result->frontend_stall_rate, result->backend_stall_rate); - } - printf("\n"); - } - fflush(stdout); - } - } -} - -} // namespace ruy - -int main() { ruy::Benchmark(); } diff --git a/tensorflow/lite/experimental/ruy/ruy/block_map.cc b/tensorflow/lite/experimental/ruy/ruy/block_map.cc deleted file mode 100644 index 32781d82ad3..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/block_map.cc +++ /dev/null @@ -1,486 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/lite/experimental/ruy/ruy/block_map.h" - -#include -#include - -#ifdef RUY_MAKEBLOCKMAP_DEBUG -#include -#include -#include -#endif - -#include "tensorflow/lite/experimental/ruy/ruy/check_macros.h" -#include "tensorflow/lite/experimental/ruy/ruy/opt_set.h" -#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" -#include "tensorflow/lite/experimental/ruy/ruy/size_util.h" - -namespace ruy { - -namespace { - -void DecodeTraversalLinear(int size_log2, std::uint32_t square_index, - SidePair* local_pos) { - (*local_pos)[Side::kLhs] = square_index & ((1 << size_log2) - 1); - (*local_pos)[Side::kRhs] = square_index >> size_log2; -} - -void DecodeTraversalFractalZ(std::uint32_t square_index, - SidePair* local_pos) { - const std::uint32_t n1 = square_index; - const std::uint32_t n2 = (n1 & 0x99999999u) | ((n1 & 0x44444444u) >> 1) | - ((n1 & 0x22222222u) << 1); - const std::uint32_t n4 = (n2 & 0xc3c3c3c3u) | ((n2 & 0x30303030u) >> 2) | - ((n2 & 0x0c0c0c0cu) << 2); - const std::uint32_t n8 = (n4 & 0xf00ff00fu) | ((n4 & 0x0f000f00u) >> 4) | - ((n4 & 0x00f000f0u) << 4); - const std::uint32_t n16 = (n8 & 0xff0000ffu) | ((n8 & 0x00ff0000u) >> 8) | - ((n8 & 0x0000ff00u) << 8); - (*local_pos)[Side::kLhs] = n16 & 0xffff; - (*local_pos)[Side::kRhs] = n16 >> 16; -} - -void DecodeTraversalFractalU(std::uint32_t square_index, - SidePair* local_pos) { - DecodeTraversalFractalZ(square_index, local_pos); - // Change fractal z-order to u-order - (*local_pos)[Side::kLhs] ^= (*local_pos)[Side::kRhs]; -} - -// Code inspired by the sample code in -// https://en.wikipedia.org/wiki/Hilbert_curve -// The main optimization is to avoid hard-to-predict conditional branches -// based on the bits of the square_index parameter. -void DecodeTraversalFractalHilbert(int size_log2, std::uint32_t square_index, - SidePair* local_pos) { - std::uint32_t t = square_index; - std::uint32_t x = 0; - std::uint32_t y = 0; - // Easy-to-predict for loop, the number of iterations is the same for - // an entire GEMM. - for (int sb = 0; sb < size_log2; sb++) { - std::uint32_t s = 1 << sb; - bool rx = t & 2; - bool ry = (t & 1) ^ rx; - std::uint32_t tmp = rx ? (s - 1 - x) : x; - x = ry ? x : rx ? (s - 1 - y) : y; - y = ry ? (y + s) : tmp; - x = rx ? 
(x + s) : x; - t >>= 2; - } - (*local_pos)[Side::kLhs] = y; - (*local_pos)[Side::kRhs] = x; -} - -} // end anonymous namespace - -void GetBlockByIndex(const BlockMap& block_map, int index, - SidePair* block) { - profiler::ScopeLabel label("GetBlockByIndex"); - const std::uint32_t index_u32 = index; - - const std::uint32_t num_blocks_per_local_curve = - 1u << (2 * block_map.num_blocks_base_log2); - const std::uint32_t square_index = - index_u32 & (num_blocks_per_local_curve - 1); - - const int size_log2 = block_map.num_blocks_base_log2; - SidePair local_pos; - switch (block_map.traversal_order) { - case BlockMapTraversalOrder::kFractalZ: - DecodeTraversalFractalZ(square_index, &local_pos); - break; - case BlockMapTraversalOrder::kFractalU: - DecodeTraversalFractalU(square_index, &local_pos); - break; - case BlockMapTraversalOrder::kFractalHilbert: - DecodeTraversalFractalHilbert(size_log2, square_index, &local_pos); - break; - default: - RUY_DCHECK(block_map.traversal_order == BlockMapTraversalOrder::kLinear); - DecodeTraversalLinear(size_log2, square_index, &local_pos); - break; - } - - const std::uint32_t rectangular_index = - index_u32 >> 2 * block_map.num_blocks_base_log2; - for (Side side : {Side::kLhs, Side::kRhs}) { - const std::uint32_t mask = (1u << block_map.rectangularness_log2[side]) - 1; - const int rectangular_offset = (rectangular_index & mask) - << block_map.num_blocks_base_log2; - (*block)[side] = local_pos[side] + rectangular_offset; - } -} - -BlockMapTraversalOrder GetTraversalOrder(int rows, int cols, int depth, - int lhs_scalar_size, - int rhs_scalar_size, - int local_data_cache_size, - int shared_data_cache_size) { - const int kFractalOptSets = - RUY_OPT_FRACTAL_Z | RUY_OPT_FRACTAL_U | RUY_OPT_FRACTAL_HILBERT; - const int working_set_size = - (lhs_scalar_size * rows + rhs_scalar_size * cols) * depth; - if (RUY_OPT_ENABLED(kFractalOptSets) && - (working_set_size > local_data_cache_size)) { - if (RUY_OPT_ENABLED(RUY_OPT_FRACTAL_HILBERT) && - (working_set_size > shared_data_cache_size)) { - return BlockMapTraversalOrder::kFractalHilbert; - } else if (RUY_OPT_ENABLED(RUY_OPT_FRACTAL_U)) { - return BlockMapTraversalOrder::kFractalU; - } else { - return BlockMapTraversalOrder::kFractalZ; - } - } else { - return BlockMapTraversalOrder::kLinear; - } -} - -namespace { - -int floor_log2_quotient(int num, int denom) { - if (num <= denom) { - return 0; - } - int log2_quotient = floor_log2(num) - ceil_log2(denom); - if ((denom << (log2_quotient + 1)) <= num) { - log2_quotient++; - } - return log2_quotient; -} - -// Computes the rectangularness of the matrix shape (rows, cols). This is -// essentially just the log2 of the quotient (rows / cols). The kernel_rows and -// kernel_cols only get into the picture for clamping bounds but don't affect -// the generic computation. -void GetRectangularness(int rows, int cols, int kernel_rows, int kernel_cols, - int* rows_rectangularness_log2, - int* cols_rectangularness_log2) { - *rows_rectangularness_log2 = 0; - *cols_rectangularness_log2 = 0; - - // In GEMV-ish cases, that is when kernel blocks are as narrow as the kernel - // itself, we risk having too small kernel blocks for good kernel - // amortization. We avoid that by limiting recangularness so that kernel - // blocks are not too tiny at least in that dimension. Specifically, we try to - // have at least (2^min_kernel_inner_loop_runs_log2) kernels fitting in each - // kernel block along the large dimension. 
- const int min_kernel_inner_loop_runs_log2 = 3; - if (rows > cols) { - int cols_of_kernel_inner_loop_runs_log2 = - ceil_log2(cols) - pot_log2(kernel_cols); - int min_rows_of_kernel_inner_loop_runs_log2 = - std::max(0, min_kernel_inner_loop_runs_log2 - - cols_of_kernel_inner_loop_runs_log2); - *rows_rectangularness_log2 = - std::min(floor_log2_quotient(rows, cols), - std::max(0, floor_log2(rows) - pot_log2(kernel_rows) - - min_rows_of_kernel_inner_loop_runs_log2)); - // Sanity check that we did not over-estimate rows_rectangularness_log2. - RUY_DCHECK_GE(rows >> *rows_rectangularness_log2, cols); - } else if (cols > rows) { - int rows_of_kernel_inner_loop_runs_log2 = - ceil_log2(rows) - pot_log2(kernel_rows); - int min_cols_of_kernel_inner_loop_runs_log2 = - std::max(0, min_kernel_inner_loop_runs_log2 - - rows_of_kernel_inner_loop_runs_log2); - *cols_rectangularness_log2 = - std::min(floor_log2_quotient(cols, rows), - std::max(0, floor_log2(cols) - pot_log2(kernel_cols) - - min_cols_of_kernel_inner_loop_runs_log2)); - // Sanity check that we did not over-estimate cols_rectangularness_log2. - RUY_DCHECK_GE(cols >> *cols_rectangularness_log2, rows); - } - RUY_DCHECK(!*rows_rectangularness_log2 || !*cols_rectangularness_log2); -} - -// Computes a 'multithreading score'. When multithreading, we need there to -// be at least as many tiles as there are threads, and hopefully -// substantially more than that, so we benefit from ruy's ability to -// dispatch fine-grained workloads to threads. -int GetMultithreadingScore(int block_size_log2, int rows, int cols, - int tentative_thread_count) { - const int num_full_blocks_of_rows = rows >> block_size_log2; - const int num_full_blocks_of_cols = cols >> block_size_log2; - const int candidate_num_full_blocks_log2 = floor_log2( - std::max(1, num_full_blocks_of_rows * num_full_blocks_of_cols)); - - // The values here have been tuned on ARM Cortex-A55. - // We expect this to have to be tuned differently for other CPUs. - if (tentative_thread_count == 1) { - return 0; - } else { - const int blocks_per_thread_log2 = - candidate_num_full_blocks_log2 - ceil_log2(tentative_thread_count); - if (blocks_per_thread_log2 < 0) { - return -64; - } else if (blocks_per_thread_log2 == 0) { - return -16; - } else if (blocks_per_thread_log2 == 1) { - return -8; - } else if (blocks_per_thread_log2 == 2) { - return 0; - } else if (blocks_per_thread_log2 == 3) { - return 8; - } else { - return 16; - } - } -} - -// Computes a 'cache locality score'. -int GetCacheLocalityScore(int block_size_log2, int rows, int cols, int depth, - int kernel_rows_log2, int kernel_cols_log2, - int lhs_scalar_size, int rhs_scalar_size, Path path, - int local_data_cache_size) { - // In the narrow case (e.g. matrix*vector), each byte of the big operand - // matrix (either LHS or RHS) is traversed only once, so any notion of data - // locality is irrelevant. Ignore the 'cache locality score' by forcing it to - // be 0 in that case. - if (rows <= (1 << kernel_rows_log2) || cols <= (1 << kernel_cols_log2)) { - return 0; - } - const int block_rows = std::min(1 << block_size_log2, rows); - const int block_cols = std::min(1 << block_size_log2, cols); - const int total_read_bytes = - (lhs_scalar_size * block_rows + rhs_scalar_size * block_cols) * depth; - const int total_read_bytes_log2 = ceil_log2(total_read_bytes); - const int nonlocality_log2 = - total_read_bytes_log2 - floor_log2(local_data_cache_size); - // The values here have been tuned on ARM Cortex-A55. 
- // We expect this to have to be tuned differently for other CPUs. - if (nonlocality_log2 < -1) { - return 64; - } else if (nonlocality_log2 == -1) { - return 56; - } else if (nonlocality_log2 == 0) { - return 48; - } else if (nonlocality_log2 == 1) { - return 32; - } else if (nonlocality_log2 == 2) { - return 16; - } else if (nonlocality_log2 == 3) { - return 0; - } else { - return -64; - } -} - -// Compute a 'kernel amortization score'. This is the notion that very small -// tiles result in more overhead outside of kernels, more complex memory -// access patterns and less benefits from ruy's fat kernels, so we reward -// larger blocks more than smaller ones. -int GetKernelAmortizationScore(int block_size_log2, int rows, int cols, - int kernel_rows_log2, int kernel_cols_log2) { - const int block_rows = std::min(1 << block_size_log2, rows); - const int block_cols = std::min(1 << block_size_log2, cols); - const int kernels_per_block_log2 = - floor_log2(block_rows * block_cols) - kernel_rows_log2 - kernel_cols_log2; - RUY_DCHECK_GE(kernels_per_block_log2, 0); - // The values here have been tuned on ARM Cortex-A55. - // We expect this to have to be tuned differently for other CPUs. - if (kernels_per_block_log2 == 0) { - return 0; - } else if (kernels_per_block_log2 == 1) { - return 8; - } else if (kernels_per_block_log2 == 2) { - return 16; - } else if (kernels_per_block_log2 == 3) { - return 24; - } else if (kernels_per_block_log2 == 4) { - return 32; - } else if (kernels_per_block_log2 == 5) { - return 40; - } else if (kernels_per_block_log2 == 6) { - return 48; - } else if (kernels_per_block_log2 == 7) { - return 56; - } else { - return 64; - } -} - -} // namespace - -void MakeBlockMap(int rows, int cols, int depth, int kernel_rows, - int kernel_cols, int lhs_scalar_size, int rhs_scalar_size, - int tentative_thread_count, Path path, - int local_data_cache_size, int shared_data_cache_size, - BlockMap* block_map) { - profiler::ScopeLabel label("MakeBlockMap"); - -#ifdef RUY_MAKEBLOCKMAP_DEBUG -#if RUY_MAKEBLOCKMAP_DEBUG >= 2 - static constexpr bool debug_everytime = true; -#else - static constexpr bool debug_everytime = false; -#endif - static bool firsttime = true; - if (firsttime || debug_everytime) { - fprintf(stderr, - "MakeBlockMap(rows=%d, cols=%d, depth=%d, kernel_rows=%d, " - "kernel_cols=%d, lhs_scalar_size=%d, rhs_scalar_size=%d, " - "tentative_thread_count=%d)\n", - rows, cols, depth, kernel_rows, kernel_cols, lhs_scalar_size, - rhs_scalar_size, tentative_thread_count); - } -#endif - - RUY_DCHECK_GE(rows, kernel_rows); - RUY_DCHECK_GE(cols, kernel_cols); - RUY_DCHECK_EQ(rows % kernel_rows, 0); - RUY_DCHECK_EQ(cols % kernel_cols, 0); - - block_map->traversal_order = - GetTraversalOrder(rows, cols, depth, lhs_scalar_size, rhs_scalar_size, - local_data_cache_size, shared_data_cache_size); - - int rows_rectangularness_log2 = 0; - int cols_rectangularness_log2 = 0; - GetRectangularness(rows, cols, kernel_rows, kernel_cols, - &rows_rectangularness_log2, &cols_rectangularness_log2); - - const int kernel_rows_log2 = pot_log2(kernel_rows); - const int kernel_cols_log2 = pot_log2(kernel_cols); - const int kernel_size_log2 = std::max(kernel_cols_log2, kernel_rows_log2); - - const int size = std::min(rows, cols); - const int size_log2 = std::max(kernel_size_log2, floor_log2(size)); - - RUY_DCHECK_GE(size_log2, kernel_size_log2); - - // We are going to try candidate values for block_size_log2 ranging from - // kernel_size_log2 to (kernel_size_log2 + kMaxKernelsPerBlockLog2). 
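To get a feel for the scales involved, the three scoring functions above can be evaluated by hand for a single candidate. With illustrative parameters: rows = cols = depth = 256, an 8x8 kernel, 1-byte LHS/RHS scalars, 4 tentative threads, a 128 KiB local data cache (the kNeonDotprod figure from cpu_cache_size.h), and candidate block_size_log2 = 5 (32x32 blocks):

* multithreading: 8x8 = 64 full blocks, floor_log2(64) = 6, so blocks_per_thread_log2 = 6 - ceil_log2(4) = 4, score +16;
* cache locality: total_read_bytes = (32 + 32) * 256 = 16384, ceil_log2 = 14, nonlocality_log2 = 14 - 17 = -3 < -1, score +64;
* kernel amortization: kernels_per_block_log2 = floor_log2(32 * 32) - 3 - 3 = 4, score +32.

That candidate totals 112; the loop below computes the same sum for every candidate block size and keeps the best-scoring one.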
- // For each of them we will compute a 'score' by adding individual scores - // for a few different considerations, all of which is entirely empirical. - // The values (and possibly the logic) around here are all subject to tuning - // based on benchmarks on different hardware. The current values are based - // on benchmarking on Qualcomm S855 (big and little cores), arm64, - // kNeonDotprod, 8bit quantized path. Don't read too much into it, go ahead - // and tune this as needed to achieve good performance elsewhere. Use - // the unit test, block_map_test, to encode values that should be preserved - // on specific architectures. Use RUY_MAKEBLOCKMAP_DEBUG to help tuning this. - static constexpr int kMaxKernelsPerBlockLog2 = 6; - const int max_block_size_log2 = - std::min(size_log2, kernel_size_log2 + kMaxKernelsPerBlockLog2); - int best_score = std::numeric_limits::min(); - int best_score_block_size_log2 = -1; - for (int block_size_log2 = kernel_size_log2; - block_size_log2 <= max_block_size_log2; block_size_log2++) { - const int multithreading_score = GetMultithreadingScore( - block_size_log2, rows, cols, tentative_thread_count); - const int cache_locality_score = GetCacheLocalityScore( - block_size_log2, rows, cols, depth, kernel_rows_log2, kernel_cols_log2, - lhs_scalar_size, rhs_scalar_size, path, local_data_cache_size); - const int kernel_amortization_score = GetKernelAmortizationScore( - block_size_log2, rows, cols, kernel_rows_log2, kernel_cols_log2); - const int score = - multithreading_score + cache_locality_score + kernel_amortization_score; -#ifdef RUY_MAKEBLOCKMAP_DEBUG - if (firsttime || debug_everytime) { - fprintf(stderr, - "block_size_log2=%d: score=%d multithreading_score=%d " - "cache_locality_score=%d kernel_amortization_score=%d\n", - block_size_log2, score, multithreading_score, - cache_locality_score, kernel_amortization_score); - } -#endif - if (score >= best_score) { - best_score = score; - best_score_block_size_log2 = block_size_log2; - } - } - -#ifdef RUY_MAKEBLOCKMAP_DEBUG - if (firsttime || debug_everytime) { - fprintf(stderr, "best_score_block_size_log2=%d\n", - best_score_block_size_log2); - } - - static const char* explicit_block_size_log2_env = - getenv("RUY_MAKEBLOCKMAP_EXPLICIT_BLOCK_SIZE_LOG2"); - if (explicit_block_size_log2_env) { - best_score_block_size_log2 = std::stoi(explicit_block_size_log2_env); - if (firsttime || debug_everytime) { - fprintf(stderr, "Overridden best_score_block_size_log2=%d\n", - best_score_block_size_log2); - } - } - firsttime = false; -#endif - - int num_blocks_base_log2 = size_log2 - best_score_block_size_log2; - RUY_DCHECK_GE(num_blocks_base_log2, 0); - - const int num_blocks_of_rows_log2 = - num_blocks_base_log2 + rows_rectangularness_log2; - const int num_blocks_of_cols_log2 = - num_blocks_base_log2 + cols_rectangularness_log2; - - const int smallr = - round_down_pot(rows >> num_blocks_of_rows_log2, kernel_rows); - const int smallc = - round_down_pot(cols >> num_blocks_of_cols_log2, kernel_cols); - const int missr = - round_up_pot(rows - (smallr << num_blocks_of_rows_log2), kernel_rows) >> - pot_log2(kernel_rows); - const int missc = - round_up_pot(cols - (smallc << num_blocks_of_cols_log2), kernel_cols) >> - pot_log2(kernel_cols); - - block_map->dims[Side::kLhs] = rows; - block_map->dims[Side::kRhs] = cols; - block_map->kernel_dims[Side::kLhs] = kernel_rows; - block_map->kernel_dims[Side::kRhs] = kernel_cols; - block_map->num_blocks_base_log2 = num_blocks_base_log2; - block_map->rectangularness_log2[Side::kLhs] = 
rows_rectangularness_log2; - block_map->rectangularness_log2[Side::kRhs] = cols_rectangularness_log2; - block_map->small_block_dims[Side::kLhs] = smallr; - block_map->small_block_dims[Side::kRhs] = smallc; - block_map->large_blocks[Side::kLhs] = missr; - block_map->large_blocks[Side::kRhs] = missc; - // Done last: NumBlocks needs some of the block_map fields to be already set. - block_map->thread_count = - std::min(tentative_thread_count, NumBlocks(*block_map)); -} - -void GetBlockMatrixCoords(Side side, const BlockMap& block_map, int block, - int* start, int* end) { - profiler::ScopeLabel label("GetBlockMatrixCoords"); - *start = block * block_map.small_block_dims[side] + - std::min(block, block_map.large_blocks[side]) * - block_map.kernel_dims[side]; - *end = - *start + block_map.small_block_dims[side] + - (block < block_map.large_blocks[side] ? block_map.kernel_dims[side] : 0); - - RUY_DCHECK_EQ(0, *start % block_map.kernel_dims[side]); - RUY_DCHECK_EQ(0, *end % block_map.kernel_dims[side]); - RUY_DCHECK_LE(*end, block_map.dims[side]); - RUY_DCHECK_LT(*start, *end); - RUY_DCHECK_GE(*start, 0); -} - -void GetBlockMatrixCoords(const BlockMap& block_map, const SidePair& block, - SidePair* start, SidePair* end) { - for (Side side : {Side::kLhs, Side::kRhs}) { - GetBlockMatrixCoords(side, block_map, block[side], &(*start)[side], - &(*end)[side]); - } -} - -} // namespace ruy diff --git a/tensorflow/lite/experimental/ruy/ruy/block_map.h b/tensorflow/lite/experimental/ruy/ruy/block_map.h deleted file mode 100644 index 0fa4c9d5d60..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/block_map.h +++ /dev/null @@ -1,161 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_BLOCK_MAP_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_BLOCK_MAP_H_ - -#include "tensorflow/lite/experimental/ruy/ruy/path.h" -#include "tensorflow/lite/experimental/ruy/ruy/side_pair.h" - -namespace ruy { - -enum class BlockMapTraversalOrder { - // Plain old row-by-row or column-by-column traversal. - kLinear, - // Fractal Z-order curve, https://en.wikipedia.org/wiki/Z-order_curve - kFractalZ, - // Variant of Z-order doing a U instead of a Z. - kFractalU, - // Hilbert curve, https://en.wikipedia.org/wiki/Hilbert_curve - kFractalHilbert -}; - -// A BlockMap describes a tiling of a matrix, typically the destination matrix -// of a matrix multiplication computation. As is standard in matrix -// multiplication, a tile is called a "block". -// -// Ruy subdivides work by blocks of the destination matrix: each thread fully -// computes a block at once, then moves on to another block; each block is -// produced by a single thread. -// -// This ensures that the workloads for each block are mutually independent, -// which reduces synchronization requirements. -// -// Typically, a matrix multiplication will early on create a BlockMap by -// calling MakeBlockMap. 
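The start/end arithmetic in GetBlockMatrixCoords above hands any leftover rows or columns to the first large_blocks[side] blocks, one extra kernel-sized slice each, so only leading blocks are ever oversized. A tiny standalone illustration with made-up numbers (72 rows, kernel_rows = 8, four blocks along the rows axis); plain divisions stand in for round_down_pot/round_up_pot, which is equivalent here because kernel dims are powers of two:

#include <algorithm>
#include <cstdio>

int main() {
  const int rows = 72, kernel_rows = 8, num_blocks_of_rows_log2 = 2;
  // small_block_dims: rows per block, rounded down to a multiple of the kernel.
  const int small = ((rows >> num_blocks_of_rows_log2) / kernel_rows) * kernel_rows;  // 16
  // large_blocks: how many leading blocks get one extra kernel_rows slice.
  const int leftover = rows - (small << num_blocks_of_rows_log2);  // 8
  const int large = (leftover + kernel_rows - 1) / kernel_rows;    // 1
  for (int block = 0; block < (1 << num_blocks_of_rows_log2); ++block) {
    // Same formulas as GetBlockMatrixCoords.
    const int start = block * small + std::min(block, large) * kernel_rows;
    const int end = start + small + (block < large ? kernel_rows : 0);
    std::printf("block %d: rows [%d, %d)\n", block, start, end);
  }
  // Prints [0,24) [24,40) [40,56) [56,72): all 72 rows are covered exactly,
  // every block height is a multiple of kernel_rows, and only block 0 is larger.
  return 0;
}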
It will then query the number of blocks in that -// BlockMap by calling NumBlocks. It will then create a single atomic integer -// counter indexing these blocks, called the 'index', and will distribute -// work to its N threads by ensuring that each thread works on disjoint sets -// of index values. For a given index value, the thread will call -// GetBlockByIndex to get the corresponding block, then GetBlockMatrixCoords -// to find the actual row and column numbers of this block. -// -// There are two nested levels of subdivision. On a local level, the matrix is -// tiled into a square NxN grid where N is a power of two, specifically: -// N = 2^num_blocks_base_log2. -// -// At a larger scale, around these blocks, there may be one further -// level of subdivision, in only one dimension: either along rows or along -// columns. That is used to handle arbitrarily rectangular matrices. The -// aforementioned high-level block grid is square, so it does not readily fit -// well very rectangular matrices. -// -// Taking together these two nested levels of subdivision, the effective -// tiling is by -// 2^(num_blocks_base_log2 + rows_rectangularness_log2) -// blocks in the row dimension, and by -// 2^(num_blocks_base_log2 + cols_rectangularness_log2) -// blocks in the column dimension. See NumBlocksOfRows, NumBlocksOfCols. -// -// Either rows_rectangularness_log2 or cols_rectangularness_log2 must be zero. -// -// Finally, this BlockMap is designed to operate under alignment constraints: -// two fields, kernel_rows and kernel_cols, describe the requested alignment -// of the effective grid in both dimensions. The idea is to feed matrix -// multiplication kernels with tiles that fit their width as much as possible. -// Of course, if rows (resp. cols) is not a multiple of kernel_rows (resp. -// kernel_cols) then some tile will have to have unaligned size. BlockMap -// will only allow that to happen in the last position along each axis, so -// as to minimize the overhead incurred onto the matrix multiplication kernels. -struct BlockMap { - // The number of threads to use (to distribute the blocks to). - int thread_count; - // The order in which to traverse the matrix of which this BlockMap represents - // a tiling (hereafter "the matrix"). - BlockMapTraversalOrder traversal_order; - // The dimensions of the block_map, that is, of the destination - // matrix rounded up to next multiples of kernel_dims. - SidePair dims; - // Log2 of the minimum number of subdivisions of the grid along either axis. - int num_blocks_base_log2; - // Log2 of the additional subdivision of the rows/columns axis. - SidePair rectangularness_log2; - // Requested alignment of the subdivisions of the grid along the rows/columns - // axis. - SidePair kernel_dims; - // Internal helper. Minimum number of rows/columns in each block. - SidePair small_block_dims; - // Internal helper. Number of blocks along each dimension that need to have - // their size in that dimension be given by (small_block_dims + kernel_dims) - // instead of just small_block_dims. - SidePair large_blocks; -}; - -// Returns the traversal order to be used for the given matrix multiplication -// parameters. -BlockMapTraversalOrder GetTraversalOrder(int rows, int cols, int depth, - int lhs_scalar_size, - int rhs_scalar_size, - int local_data_cache_size, - int shared_data_cache_size); - -// Create a BlockMap suitable for tiling the destination matrix in a -// matrix multiplication with the given parameters. 
-void MakeBlockMap(int rows, int cols, int depth, int kernel_rows, - int kernel_cols, int lhs_scalar_size, int rhs_scalar_size, - int tentative_thread_count, Path path, - int local_data_cache_size, int shared_data_cache_size, - BlockMap* block_map); - -// Maps an integer index to a block position in the grid. -void GetBlockByIndex(const BlockMap& block_map, int index, - SidePair* block); - -// Given a block position in the grid, returns its actual -// position in the matrix that the BlockMap refers to in the dimension -// referred to by `side`: along rows if side==kLhs, along columns if -// side==kRhs. -void GetBlockMatrixCoords(Side side, const BlockMap& block_map, int block, - int* start, int* end); - -// Given a block position in the grid, returns its actual -// position in the matrix that the BlockMap refers to in terms of -// actual row/column indices. -void GetBlockMatrixCoords(const BlockMap& block_map, const SidePair& block, - SidePair* start, SidePair* end); - -// Returns the number of grid subdivisions along the rows dimension (if -// side == kLhs) or columns dimension (if side == kRhs). -inline int NumBlocksPerSide(Side side, const BlockMap& block_map) { - return 1 << (block_map.num_blocks_base_log2 + - block_map.rectangularness_log2[side]); -} - -// Returns the overall number of blocks in -// the BlockMap. The valid index values to pass to GetBlockByIndex are the -// integers from 0 to N-1 where N is the value returned here. -// -// Note that it is always true that -// NumBlocks == NumBlocksOfRows * NumBlocksOfCols -// because either rows_rectangularness_log2 or cols_rectangularness_log2 is 0. -inline int NumBlocks(const BlockMap& block_map) { - return 1 << (2 * block_map.num_blocks_base_log2 + - block_map.rectangularness_log2[Side::kLhs] + - block_map.rectangularness_log2[Side::kRhs]); -} - -} // namespace ruy - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_BLOCK_MAP_H_ diff --git a/tensorflow/lite/experimental/ruy/ruy/block_map_test.cc b/tensorflow/lite/experimental/ruy/ruy/block_map_test.cc deleted file mode 100644 index cdd7ee0e01f..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/block_map_test.cc +++ /dev/null @@ -1,263 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/lite/experimental/ruy/ruy/block_map.h" - -#include -#include -#include -#include -#include - -#include -#include "tensorflow/lite/experimental/ruy/ruy/cpu_cache_size.h" -#include "tensorflow/lite/experimental/ruy/ruy/path.h" -#include "tensorflow/lite/experimental/ruy/ruy/side_pair.h" - -namespace ruy { -namespace { - -#if RUY_PLATFORM(NEON_64) - -// Unless otherwise specified, these tests have been tuned on ARM Cortex-A55. 
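For orientation before the tuning tests, here is how the pieces declared in block_map.h fit together from a caller's point of view. This is a schematic, single-threaded sketch using the pre-move include paths from this diff and illustrative kernel/cache parameters; in ruy proper, the matrix multiplication implementation hands the index values out to worker threads through a shared atomic counter, as the header comment describes:

#include "tensorflow/lite/experimental/ruy/ruy/block_map.h"
#include "tensorflow/lite/experimental/ruy/ruy/cpu_cache_size.h"

// Schematic walk over all blocks of a destination matrix, in the order given
// by the traversal curve. rows and cols are assumed to be multiples of the
// kernel dims (MakeBlockMap DCHECKs that); all numeric parameters are
// illustrative.
void VisitAllBlocks(int rows, int cols, int depth) {
  ruy::BlockMap block_map;
  ruy::MakeBlockMap(rows, cols, depth, /*kernel_rows=*/8, /*kernel_cols=*/8,
                    /*lhs_scalar_size=*/1, /*rhs_scalar_size=*/1,
                    /*tentative_thread_count=*/4, ruy::Path::kNeonDotprod,
                    ruy::LocalDataCacheSize(ruy::Path::kNeonDotprod),
                    ruy::SharedDataCacheSize(ruy::Path::kNeonDotprod),
                    &block_map);
  for (int index = 0; index < ruy::NumBlocks(block_map); ++index) {
    ruy::SidePair<int> block, start, end;
    ruy::GetBlockByIndex(block_map, index, &block);
    ruy::GetBlockMatrixCoords(block_map, block, &start, &end);
    // A worker would now compute the destination tile spanning
    // rows [start[ruy::Side::kLhs], end[ruy::Side::kLhs]) and
    // cols [start[ruy::Side::kRhs], end[ruy::Side::kRhs]).
  }
}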
-void MakeBlockMapTuningTest(int rows, int cols, int depth, int kernel_rows, - int kernel_cols, int lhs_scalar_size, - int rhs_scalar_size, int tentative_thread_count, - Path path, int expected_num_blocks_base_log2, - int expected_rectangularness_log2) { - BlockMap block_map; - MakeBlockMap(rows, cols, depth, kernel_rows, kernel_cols, lhs_scalar_size, - rhs_scalar_size, tentative_thread_count, path, - LocalDataCacheSize(path), SharedDataCacheSize(path), &block_map); - EXPECT_EQ(block_map.num_blocks_base_log2, expected_num_blocks_base_log2); - EXPECT_EQ(std::min(block_map.rectangularness_log2[Side::kLhs], - block_map.rectangularness_log2[Side::kRhs]), - 0); - EXPECT_EQ(std::max(block_map.rectangularness_log2[Side::kLhs], - block_map.rectangularness_log2[Side::kRhs]), - expected_rectangularness_log2); -} - -TEST(BlockMapTest, MakeBlockMapTuningTest8bitCubicShapesOneThreadNeonDotprod) { - MakeBlockMapTuningTest(32, 32, 32, 8, 8, 1, 1, /* tentative_thread_count */ 1, - Path::kNeonDotprod, - /* expected_num_blocks_base_log2 */ 0, - /* expected_rectangularness_log2 */ 0); - MakeBlockMapTuningTest(48, 48, 48, 8, 8, 1, 1, /* tentative_thread_count */ 1, - Path::kNeonDotprod, - /* expected_num_blocks_base_log2 */ 0, - /* expected_rectangularness_log2 */ 0); - MakeBlockMapTuningTest(64, 64, 64, 8, 8, 1, 1, /* tentative_thread_count */ 1, - Path::kNeonDotprod, - /* expected_num_blocks_base_log2 */ 0, - /* expected_rectangularness_log2 */ 0); - MakeBlockMapTuningTest(96, 96, 96, 8, 8, 1, 1, /* tentative_thread_count */ 1, - Path::kNeonDotprod, - /* expected_num_blocks_base_log2 */ 0, - /* expected_rectangularness_log2 */ 0); - MakeBlockMapTuningTest(128, 128, 128, 8, 8, 1, 1, - /* tentative_thread_count */ 1, Path::kNeonDotprod, - /* expected_num_blocks_base_log2 */ 0, - /* expected_rectangularness_log2 */ 0); - MakeBlockMapTuningTest(192, 192, 192, 8, 8, 1, 1, - /* tentative_thread_count */ 1, Path::kNeonDotprod, - /* expected_num_blocks_base_log2 */ 0, - /* expected_rectangularness_log2 */ 0); - MakeBlockMapTuningTest(256, 256, 256, 8, 8, 1, 1, - /* tentative_thread_count */ 1, Path::kNeonDotprod, - /* expected_num_blocks_base_log2 */ 1, - /* expected_rectangularness_log2 */ 0); - MakeBlockMapTuningTest(384, 384, 384, 8, 8, 1, 1, - /* tentative_thread_count */ 1, Path::kNeonDotprod, - /* expected_num_blocks_base_log2 */ 1, - /* expected_rectangularness_log2 */ 0); -} - -TEST(BlockMapTest, - MakeBlockMapTuningTest8bitCubicShapesFourThreadsNeonDotprod) { - MakeBlockMapTuningTest(32, 32, 32, 8, 8, 1, 1, /* tentative_thread_count */ 4, - Path::kNeonDotprod, - /* expected_num_blocks_base_log2 */ 1, - /* expected_rectangularness_log2 */ 0); - MakeBlockMapTuningTest(48, 48, 48, 8, 8, 1, 1, /* tentative_thread_count */ 4, - Path::kNeonDotprod, - /* expected_num_blocks_base_log2 */ 1, - /* expected_rectangularness_log2 */ 0); - MakeBlockMapTuningTest(64, 64, 64, 8, 8, 1, 1, /* tentative_thread_count */ 4, - Path::kNeonDotprod, - /* expected_num_blocks_base_log2 */ 1, - /* expected_rectangularness_log2 */ 0); - MakeBlockMapTuningTest(96, 96, 96, 8, 8, 1, 1, /* tentative_thread_count */ 4, - Path::kNeonDotprod, - /* expected_num_blocks_base_log2 */ 1, - /* expected_rectangularness_log2 */ 0); - MakeBlockMapTuningTest(128, 128, 128, 8, 8, 1, 1, - /* tentative_thread_count */ 4, Path::kNeonDotprod, - /* expected_num_blocks_base_log2 */ 1, - /* expected_rectangularness_log2 */ 0); - MakeBlockMapTuningTest(192, 192, 192, 8, 8, 1, 1, - /* tentative_thread_count */ 4, Path::kNeonDotprod, - /* 
expected_num_blocks_base_log2 */ 1, - /* expected_rectangularness_log2 */ 0); - MakeBlockMapTuningTest(256, 256, 256, 8, 8, 1, 1, - /* tentative_thread_count */ 4, Path::kNeonDotprod, - /* expected_num_blocks_base_log2 */ 2, - /* expected_rectangularness_log2 */ 0); - MakeBlockMapTuningTest(384, 384, 384, 8, 8, 1, 1, - /* tentative_thread_count */ 4, Path::kNeonDotprod, - /* expected_num_blocks_base_log2 */ 2, - /* expected_rectangularness_log2 */ 0); -} - -TEST(BlockMapTest, MakeBlockMapTuningTest32bit) { - MakeBlockMapTuningTest(256, 256, 256, 8, 8, 4, 4, - /* tentative_thread_count */ 4, Path::kNeonDotprod, - /* expected_num_blocks_base_log2 */ 3, - /* expected_rectangularness_log2 */ 0); - MakeBlockMapTuningTest(4096, 4096, 4096, 8, 8, 4, 4, - /* tentative_thread_count */ 4, Path::kNeonDotprod, - /* expected_num_blocks_base_log2 */ 7, - /* expected_rectangularness_log2 */ 0); -} - -TEST(BlockMapTest, MakeBlockMapTuningTestRectangular) { - MakeBlockMapTuningTest(256, 16, 256, 8, 8, 1, 1, - /* tentative_thread_count */ 1, Path::kNeonDotprod, - /* expected_num_blocks_base_log2 */ 0, - /* expected_rectangularness_log2 */ 3); - MakeBlockMapTuningTest(24, 2400, 256, 8, 8, 1, 1, - /* tentative_thread_count */ 1, Path::kNeonDotprod, - /* expected_num_blocks_base_log2 */ 0, - /* expected_rectangularness_log2 */ 6); -} - -#endif - -int L1Distance(const SidePair& a, const SidePair& b) { - return std::abs(a[Side::kLhs] - b[Side::kLhs]) + - std::abs(a[Side::kRhs] - b[Side::kRhs]); -} - -void GetBlockByIndexSquareTest(int num_blocks_base_log2, - BlockMapTraversalOrder traversal_order) { - // Arbitrary, does not affect this test. 3 is just a typical value. - constexpr int kKernelSizeLog2 = 3; - - const int size_log2 = num_blocks_base_log2 + kKernelSizeLog2; - BlockMap block_map; - block_map.thread_count = 1; - block_map.traversal_order = traversal_order; - block_map.num_blocks_base_log2 = num_blocks_base_log2; - for (Side side : {Side::kLhs, Side::kRhs}) { - block_map.dims[side] = 1 << size_log2; - block_map.rectangularness_log2[side] = 0; - block_map.kernel_dims[side] = 1 << kKernelSizeLog2; - block_map.small_block_dims[side] = block_map.kernel_dims[side]; - block_map.large_blocks[side] = 0; - } - - const int num_blocks_per_side = 1 << num_blocks_base_log2; - const int num_blocks = num_blocks_per_side * num_blocks_per_side; - EXPECT_EQ(num_blocks, NumBlocks(block_map)); - - // Perform a full traversal of all blocks, as if computing a whole matrix - // multiplication. - // - // Used to record how many times each block was hit by the traversal. - std::vector block_hit_counts(num_blocks); - // Here we guard an assumption that all traversal orders start at (0, 0). - SidePair previous_block_coords(0, 0); - // Sum of L1 norm of the coordinate change at every step of the traversal. - std::int64_t total_l1_distance = 0; - // Number of jumps i.e. traversal steps with a L1 norm greater than 1. - int discontinuity_count = 0; - for (int block_index = 0; block_index < num_blocks; block_index++) { - SidePair block_coords; - GetBlockByIndex(block_map, block_index, &block_coords); - ++block_hit_counts[block_coords[Side::kLhs] + - num_blocks_per_side * block_coords[Side::kRhs]]; - int distance = L1Distance(block_coords, previous_block_coords); - total_l1_distance += distance; - discontinuity_count += (distance > 1); - previous_block_coords = block_coords; - } - - // Verify that each block was traversed exactly once. 
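(The kLinear expectations asserted a little further down can be derived by hand: on an N x N grid with N = num_blocks_per_side, the linear order walks each line of N blocks in N - 1 unit steps, then jumps to the start of the next line, a step of L1 length (N - 1) + 1 = N. There are N - 1 such jumps, each one a discontinuity, giving discontinuity_count = N - 1 and total_l1_distance = N(N - 1) + N(N - 1) = 2N(N - 1), which is exactly the expression in the kLinear case of the switch below.)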
- for (int l = 0; l < num_blocks_per_side; l++) { - for (int r = 0; r < num_blocks_per_side; r++) { - EXPECT_EQ(block_hit_counts[l + num_blocks_per_side * r], 1); - } - } - - // Verify that the discontinuity_count and total_l1_distance are as expected - // for the given traversal_order. - switch (traversal_order) { - case BlockMapTraversalOrder::kFractalHilbert: - // No discontinuity at all with this space-filling continuous curve! - EXPECT_EQ(discontinuity_count, 0); - // Therefore, total_l1_distance has to be the number of blocks minus one. - EXPECT_EQ(total_l1_distance, num_blocks - 1); - break; - case BlockMapTraversalOrder::kLinear: - EXPECT_EQ(discontinuity_count, num_blocks_per_side - 1); - EXPECT_EQ(total_l1_distance, - 2 * num_blocks_per_side * (num_blocks_per_side - 1)); - break; - case BlockMapTraversalOrder::kFractalZ: - EXPECT_EQ(discontinuity_count, num_blocks > 1 ? (num_blocks / 2 - 1) : 0); - EXPECT_EQ(total_l1_distance, - 2 * num_blocks_per_side * (num_blocks_per_side - 1)); - break; - case BlockMapTraversalOrder::kFractalU: { - if (num_blocks_base_log2 == 0) { - EXPECT_EQ(discontinuity_count, 0); - EXPECT_EQ(total_l1_distance, 0); - } else { - int expected_discontinuity_count = 0; - int expected_total_l1_distance = 3; - for (int i = 2; i <= num_blocks_base_log2; i++) { - expected_discontinuity_count = 4 * expected_discontinuity_count + 2; - expected_total_l1_distance = - 4 * expected_total_l1_distance + (1 << (i + 1)) - 1; - } - EXPECT_EQ(discontinuity_count, expected_discontinuity_count); - EXPECT_EQ(total_l1_distance, expected_total_l1_distance); - } - break; - } - default: - abort(); - } -} - -TEST(BlockMapTest, GetBlockByIndexSquare) { - for (int num_blocks_base_log2 = 0; num_blocks_base_log2 <= 10; - num_blocks_base_log2++) { - for (BlockMapTraversalOrder traversal_order : - {BlockMapTraversalOrder::kLinear, BlockMapTraversalOrder::kFractalZ, - BlockMapTraversalOrder::kFractalU, - BlockMapTraversalOrder::kFractalHilbert}) { - GetBlockByIndexSquareTest(num_blocks_base_log2, traversal_order); - } - } -} - -} // namespace -} // namespace ruy - -int main(int argc, char **argv) { - ::testing::InitGoogleTest(&argc, argv); - return RUN_ALL_TESTS(); -} diff --git a/tensorflow/lite/experimental/ruy/ruy/blocking_counter.cc b/tensorflow/lite/experimental/ruy/ruy/blocking_counter.cc deleted file mode 100644 index d313ffce51b..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/blocking_counter.cc +++ /dev/null @@ -1,49 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/lite/experimental/ruy/ruy/blocking_counter.h" - -#include "tensorflow/lite/experimental/ruy/ruy/check_macros.h" -#include "tensorflow/lite/experimental/ruy/ruy/wait.h" - -namespace ruy { - -void BlockingCounter::Reset(int initial_count) { - int old_count_value = count_.load(std::memory_order_relaxed); - RUY_DCHECK_EQ(old_count_value, 0); - (void)old_count_value; - count_.store(initial_count, std::memory_order_release); -} - -bool BlockingCounter::DecrementCount() { - int old_count_value = count_.fetch_sub(1, std::memory_order_acq_rel); - RUY_DCHECK_GT(old_count_value, 0); - int count_value = old_count_value - 1; - bool hit_zero = (count_value == 0); - if (hit_zero) { - std::lock_guard lock(count_mutex_); - count_cond_.notify_all(); - } - return hit_zero; -} - -void BlockingCounter::Wait() { - const auto& condition = [this]() { - return count_.load(std::memory_order_acquire) == 0; - }; - ruy::Wait(condition, &count_cond_, &count_mutex_); -} - -} // namespace ruy diff --git a/tensorflow/lite/experimental/ruy/ruy/blocking_counter.h b/tensorflow/lite/experimental/ruy/ruy/blocking_counter.h deleted file mode 100644 index 878f0e7219e..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/blocking_counter.h +++ /dev/null @@ -1,62 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_BLOCKING_COUNTER_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_BLOCKING_COUNTER_H_ - -#include -#include // NOLINT(build/c++11) // IWYU pragma: keep -#include // NOLINT(build/c++11) // IWYU pragma: keep - -namespace ruy { - -// A BlockingCounter lets one thread to wait for N events to occur. -// This is how the master thread waits for all the worker threads -// to have finished working. -// The waiting is done using a naive spinlock waiting for the atomic -// count_ to hit the value 0. This is acceptable because in our usage -// pattern, BlockingCounter is used only to synchronize threads after -// short-lived tasks (performing parts of the same GEMM). It is not used -// for synchronizing longer waits (resuming work on the next GEMM). -class BlockingCounter { - public: - BlockingCounter() : count_(0) {} - - // Sets/resets the counter; initial_count is the number of - // decrementing events that the Wait() call will be waiting for. - void Reset(int initial_count); - - // Decrements the counter; if the counter hits zero, signals - // the threads that were waiting for that, and returns true. - // Otherwise (if the decremented count is still nonzero), - // returns false. - bool DecrementCount(); - - // Waits for the N other threads (N having been set by Reset()) - // to hit the BlockingCounter. 
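A minimal usage sketch of this class, outside of ruy's thread pool (which is how it is actually driven): the main thread Reset()s the counter to the number of pending tasks, each worker calls DecrementCount() when it finishes, and Wait() returns once the count reaches zero.

#include <thread>
#include <vector>

#include "tensorflow/lite/experimental/ruy/ruy/blocking_counter.h"

void RunTasksAndWait(int num_tasks) {
  ruy::BlockingCounter counter;
  counter.Reset(num_tasks);  // number of DecrementCount() calls Wait() expects
  std::vector<std::thread> workers;
  for (int i = 0; i < num_tasks; ++i) {
    workers.emplace_back([&counter] {
      // ... one task's worth of work goes here ...
      counter.DecrementCount();
    });
  }
  counter.Wait();  // returns only after all num_tasks decrements
  for (std::thread& worker : workers) worker.join();
}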
- void Wait(); - - private: - std::atomic count_; - - // The condition variable and mutex allowing to passively wait for count_ - // to reach the value zero, in the case of longer waits. - std::condition_variable count_cond_; - std::mutex count_mutex_; -}; - -} // namespace ruy - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_BLOCKING_COUNTER_H_ diff --git a/tensorflow/lite/experimental/ruy/ruy/build_defs.bzl b/tensorflow/lite/experimental/ruy/ruy/build_defs.bzl deleted file mode 100644 index 9bccccf6316..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/build_defs.bzl +++ /dev/null @@ -1,40 +0,0 @@ -"""Build definitions for Ruy.""" - -# 1. Enable -mfpu=neon unconditionally on ARM32. If it turns out that we need to support -# ARM32 without NEON then we'll implement runtime detection and dispatch at that point. -# 2. Explicitly pass -O3 on optimization configs where just "-c opt" means "optimize for code size". - -def ruy_copts_base(): - return select({ - ":armeabi-v7a": [ - "-mfpu=neon", - ], - "//conditions:default": [], - }) + select({ - ":optimized": ["-O3"], - "//conditions:default": [], - }) - -# Used for targets that are compiled with extra features that are skipped at runtime if unavailable. -def ruy_copts_skylake(): - return [] - -# Used for targets that are compiled with extra features that are skipped at runtime if unavailable. -def ruy_copts_avx2(): - return [] - -# TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. -# Optimization is not finished. In particular the dimensions of the kernel -# blocks can be changed as desired. -# -# Used for targets that are compiled with extra features that are skipped at runtime if unavailable. -def ruy_copts_sse42(): - return [] - -# TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. -# Optimization is not finished. In particular the dimensions of the kernel -# blocks can be changed as desired. -# -# Used for targets that are compiled with extra features that are skipped at runtime if unavailable. -def ruy_copts_avxvnni(): - return [] diff --git a/tensorflow/lite/experimental/ruy/ruy/check_macros.h b/tensorflow/lite/experimental/ruy/ruy/check_macros.h deleted file mode 100644 index 773f37d99f2..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/check_macros.h +++ /dev/null @@ -1,138 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_CHECK_MACROS_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_CHECK_MACROS_H_ - -#include -#include -#include - -namespace ruy { -namespace check_macros { - -constexpr int kValueBufSize = 32; - -template -struct ToString { - static void Run(const T& value, char* buf) { - snprintf(buf, kValueBufSize, "(?)"); - } -}; - -template <> -struct ToString { - static void Run(float value, char* buf) { - snprintf(buf, kValueBufSize, "%.9g", static_cast(value)); - } -}; - -template <> -struct ToString { - static void Run(double value, char* buf) { - snprintf(buf, kValueBufSize, "%.16g", value); - } -}; - -template -struct ToString::value>::type> { - static void Run(const T& value, char* buf) { - snprintf(buf, kValueBufSize, "%lld", static_cast(value)); - } -}; - -template -struct ToString { - static void Run(T* value, char* buf) { - snprintf(buf, kValueBufSize, "%p", value); - } -}; - -template -struct ToString::value>::type> { - static void Run(const T& value, char* buf) { - snprintf(buf, kValueBufSize, "(enum value %d)", static_cast(value)); - } -}; - -inline void Failure(const char* file, int line, const char* macro, - const char* condition) { - fprintf(stderr, "%s:%d: %s condition not satisfied: %s\n", file, line, macro, - condition); - abort(); -} - -template -inline void Failure(const char* file, int line, const char* macro, - const char* lhs, const LhsType& lhs_value, const char* op, - const char* rhs, const RhsType& rhs_value) { - char lhs_value_buf[kValueBufSize]; - ToString::Run(lhs_value, lhs_value_buf); - char rhs_value_buf[kValueBufSize]; - ToString::Run(rhs_value, rhs_value_buf); - fprintf(stderr, - "%s:%d: %s condition not satisfied: [ %s %s %s ] with values [ " - "%s %s %s ].\n", - file, line, macro, lhs, op, rhs, lhs_value_buf, op, rhs_value_buf); - abort(); -} - -#define RUY_CHECK_IMPL(macro, condition) \ - do { \ - if (!(condition)) { \ - ruy::check_macros::Failure(__FILE__, __LINE__, #macro, #condition); \ - } \ - } while (false) - -#define RUY_CHECK_OP_IMPL(macro, lhs, op, rhs) \ - do { \ - const auto& lhs_value = (lhs); \ - const auto& rhs_value = (rhs); \ - if (!(lhs_value op rhs_value)) { \ - ruy::check_macros::Failure(__FILE__, __LINE__, #macro, #lhs, lhs_value, \ - #op, #rhs, rhs_value); \ - } \ - } while (false) - -#define RUY_CHECK(condition) RUY_CHECK_IMPL(RUY_CHECK, condition) -#define RUY_CHECK_EQ(x, y) RUY_CHECK_OP_IMPL(RUY_CHECK_EQ, x, ==, y) -#define RUY_CHECK_NE(x, y) RUY_CHECK_OP_IMPL(RUY_CHECK_NE, x, !=, y) -#define RUY_CHECK_GE(x, y) RUY_CHECK_OP_IMPL(RUY_CHECK_GE, x, >=, y) -#define RUY_CHECK_GT(x, y) RUY_CHECK_OP_IMPL(RUY_CHECK_GT, x, >, y) -#define RUY_CHECK_LE(x, y) RUY_CHECK_OP_IMPL(RUY_CHECK_LE, x, <=, y) -#define RUY_CHECK_LT(x, y) RUY_CHECK_OP_IMPL(RUY_CHECK_LT, x, <, y) - -#ifdef NDEBUG -#define RUY_DCHECK(condition) -#define RUY_DCHECK_EQ(x, y) -#define RUY_DCHECK_NE(x, y) -#define RUY_DCHECK_GE(x, y) -#define RUY_DCHECK_GT(x, y) -#define RUY_DCHECK_LE(x, y) -#define RUY_DCHECK_LT(x, y) -#else -#define RUY_DCHECK(condition) RUY_CHECK(condition) -#define RUY_DCHECK_EQ(x, y) RUY_CHECK_EQ(x, y) -#define RUY_DCHECK_NE(x, y) RUY_CHECK_NE(x, y) -#define RUY_DCHECK_GE(x, y) RUY_CHECK_GE(x, y) -#define RUY_DCHECK_GT(x, y) RUY_CHECK_GT(x, y) -#define RUY_DCHECK_LE(x, y) RUY_CHECK_LE(x, y) -#define RUY_DCHECK_LT(x, y) RUY_CHECK_LT(x, y) -#endif - -} // end namespace check_macros -} // end namespace ruy - -#endif // 
TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_CHECK_MACROS_H_ diff --git a/tensorflow/lite/experimental/ruy/ruy/check_macros_test.cc b/tensorflow/lite/experimental/ruy/ruy/check_macros_test.cc deleted file mode 100644 index 1a2a5a238f2..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/check_macros_test.cc +++ /dev/null @@ -1,153 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/lite/experimental/ruy/ruy/check_macros.h" - -#include - -namespace { - -#define TEST_CONDITION_FOR_FAMILY(family, vacuously_succeeds, condition) \ - do { \ - if (vacuously_succeeds || (condition)) { \ - RUY_##family(condition); \ - } \ - } while (false) - -#define TEST_COMPARISON_FOR_FAMILY(family, vacuously_succeeds, op_name, x, op, \ - y) \ - do { \ - if (vacuously_succeeds || ((x)op(y))) { \ - RUY_##family##_##op_name(x, y); \ - } \ - } while (false) - -#ifdef NDEBUG -#define TEST_CONDITION(condition) \ - do { \ - TEST_CONDITION_FOR_FAMILY(CHECK, false, condition); \ - } while (false) -#define TEST_COMPARISON(op_name, x, op, y) \ - do { \ - TEST_COMPARISON_FOR_FAMILY(CHECK, false, op_name, x, op, y); \ - } while (false) -#else -#define TEST_CONDITION(condition) \ - do { \ - TEST_CONDITION_FOR_FAMILY(CHECK, false, condition); \ - TEST_CONDITION_FOR_FAMILY(DCHECK, false, condition); \ - } while (false) -#define TEST_COMPARISON(op_name, x, op, y) \ - do { \ - TEST_COMPARISON_FOR_FAMILY(CHECK, false, op_name, x, op, y); \ - TEST_COMPARISON_FOR_FAMILY(DCHECK, false, op_name, x, op, y); \ - } while (false) - -#endif - -template -void TestEqualityComparisons(const LhsType& lhs, const RhsType& rhs) { - RUY_CHECK_EQ(lhs, lhs); - TEST_COMPARISON(EQ, lhs, ==, lhs); - RUY_CHECK_EQ(lhs, lhs); - RUY_CHECK_EQ(lhs, lhs); - if (lhs == rhs) { - RUY_CHECK_EQ(lhs, rhs); - } - if (lhs != rhs) { - RUY_CHECK_NE(lhs, rhs); - } -} - -template -void TestComparisons(const LhsType& lhs, const RhsType& rhs) { - TestEqualityComparisons(lhs, rhs); - if (lhs > rhs) { - RUY_CHECK_GT(lhs, rhs); - } - if (lhs >= rhs) { - RUY_CHECK_GE(lhs, rhs); - } - if (lhs < rhs) { - RUY_CHECK_LT(lhs, rhs); - } - if (lhs <= rhs) { - RUY_CHECK_LE(lhs, rhs); - } -} - -TEST(CheckMacrosTest, IntInt) { - TestComparisons(0, 0); - TestComparisons(0, 1); - TestComparisons(1, -1); - TestComparisons(-1, 0); - TestComparisons(123, -456); - TestComparisons(std::numeric_limits::min(), - std::numeric_limits::max()); - TestComparisons(123, std::numeric_limits::max()); - TestComparisons(123, std::numeric_limits::min()); -} - -TEST(CheckMacrosTest, Uint8Uint8) { - TestComparisons(0, 0); - TestComparisons(255, 0); - TestComparisons(0, 255); - TestComparisons(12, 34); -} - -TEST(CheckMacrosTest, Uint8Int) { - TestComparisons(0, std::numeric_limits::min()); - TestComparisons(255, std::numeric_limits::min()); - TestComparisons(0, std::numeric_limits::max()); - TestComparisons(255, std::numeric_limits::max()); -} - -TEST(CheckMacrosTest, FloatFloat) 
{ - TestComparisons(0.f, 0.f); - TestComparisons(0.f, 1.f); - TestComparisons(1.f, -1.f); - TestComparisons(-1.f, 0.f); - TestComparisons(123.f, -456.f); - TestComparisons(std::numeric_limits::lowest(), - std::numeric_limits::max()); - TestComparisons(123.f, std::numeric_limits::max()); - TestComparisons(123.f, std::numeric_limits::lowest()); -} - -TEST(CheckMacrosTest, IntFloat) { - TestComparisons(0, 0.f); - TestComparisons(0, 1.f); - TestComparisons(1, -1.f); - TestComparisons(-1, 0.f); - TestComparisons(123, -456.f); - TestComparisons(std::numeric_limits::lowest(), - std::numeric_limits::max()); - TestComparisons(123, std::numeric_limits::max()); - TestComparisons(123, std::numeric_limits::lowest()); -} - -TEST(CheckMacrosTest, EnumClass) { - enum class SomeEnumClass { kA, kB, kC }; - TestEqualityComparisons(SomeEnumClass::kA, SomeEnumClass::kA); - TestEqualityComparisons(SomeEnumClass::kA, SomeEnumClass::kB); - TestEqualityComparisons(SomeEnumClass::kC, SomeEnumClass::kB); -} - -} // namespace - -int main(int argc, char** argv) { - ::testing::InitGoogleTest(&argc, argv); - return RUN_ALL_TESTS(); -} diff --git a/tensorflow/lite/experimental/ruy/ruy/common.h b/tensorflow/lite/experimental/ruy/ruy/common.h deleted file mode 100644 index e52a6ba6976..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/common.h +++ /dev/null @@ -1,73 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Miscellaneous helpers internal library. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_COMMON_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_COMMON_H_ - -#include -#include - -#include "tensorflow/lite/experimental/ruy/ruy/check_macros.h" -#include "tensorflow/lite/experimental/ruy/ruy/matrix.h" -#include "tensorflow/lite/experimental/ruy/ruy/opt_set.h" -#include "tensorflow/lite/experimental/ruy/ruy/path.h" -#include "tensorflow/lite/experimental/ruy/ruy/platform.h" - -#if RUY_OPT_ENABLED(RUY_OPT_PREFETCH_LOAD) -#define RUY_PREFETCH_LOAD(X) X -#else -#define RUY_PREFETCH_LOAD(X) -#endif - -#if RUY_OPT_ENABLED(RUY_OPT_PREFETCH_STORE) -#define RUY_PREFETCH_STORE(X) X -#else -#define RUY_PREFETCH_STORE(X) -#endif - -#define RUY_STR(s) RUY_STR_UNEXPANDED(s) -#define RUY_STR_UNEXPANDED(s) #s - -namespace ruy { - -// Helper for type-erasing a pointer. -// -// Often inside Ruy, a template parameter holds type information statically, but -// we would like to have a function signature that doesn't depend on the -// template parameters, so that we can dispatch indirectly across multiple -// implementations. This helper is at the core of such type-erasure. -// -// The opposite of this operation is just `static_cast(void_ptr)`. 
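A short round trip through the two helpers defined just below, as a sketch of how they are meant to be used (the concrete values are implied by the code, not separately documented):

#include <cassert>
#include <cstdint>

#include "tensorflow/lite/experimental/ruy/ruy/common.h"

void CommonHelpersSketch() {
  // Type-erase a (possibly const) typed pointer across an indirect dispatch
  // boundary, then recover it with the inverse static_cast.
  const float params = 0.5f;
  void* erased = ruy::ToVoidPtr(&params);
  const float* recovered = static_cast<const float*>(erased);
  assert(*recovered == 0.5f);
  // SymmetricZeroPoint: 0 for floating-point and signed types, the midpoint
  // for unsigned types, e.g. 255 / 2 + 1 = 128 for std::uint8_t.
  assert(ruy::SymmetricZeroPoint<std::uint8_t>() == 128);
  assert(ruy::SymmetricZeroPoint<std::int8_t>() == 0);
}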
-template -void* ToVoidPtr(T* p) { - return const_cast(static_cast(p)); -} - -template -Scalar SymmetricZeroPoint() { - if (std::is_floating_point::value) { - return 0; - } - if (std::is_signed::value) { - return 0; - } - return std::numeric_limits::max() / 2 + 1; -} - -} // namespace ruy - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_COMMON_H_ diff --git a/tensorflow/lite/experimental/ruy/ruy/context.cc b/tensorflow/lite/experimental/ruy/ruy/context.cc deleted file mode 100644 index e0d4701645f..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/context.cc +++ /dev/null @@ -1,109 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/lite/experimental/ruy/ruy/context.h" - -#include "tensorflow/lite/experimental/ruy/ruy/check_macros.h" -#include "tensorflow/lite/experimental/ruy/ruy/detect_arm.h" -#include "tensorflow/lite/experimental/ruy/ruy/detect_x86.h" -#include "tensorflow/lite/experimental/ruy/ruy/have_built_path_for.h" -#include "tensorflow/lite/experimental/ruy/ruy/platform.h" - -namespace ruy { - -void Context::SetRuntimeEnabledPaths(Path paths) { - runtime_enabled_paths_ = paths; -} - -Path Context::GetRuntimeEnabledPaths() { - // This function should always return the same value on a given machine. - // When runtime_enabled_paths_ has its initial value kNone, it performs - // some platform detection to resolve it to specific Path values. - - // Fast path: already resolved. - if (runtime_enabled_paths_ != Path::kNone) { - return runtime_enabled_paths_; - } - - // Need to resolve now. Start by considering all paths enabled. - runtime_enabled_paths_ = kAllPaths; - - // This mechanism is intended to be used for testing and benchmarking. For - // example, one can set RUY_FORCE_DISABLE_PATHS to Path::kAvx512 in order to - // evaluate AVX2 performance on an AVX-512 machine. -#ifdef RUY_FORCE_DISABLE_PATHS - runtime_enabled_paths_ = runtime_enabled_paths_ & ~(RUY_FORCE_DISABLE_PATHS); -#endif - -#if RUY_PLATFORM(ARM) - // Now selectively disable paths that aren't supported on this machine. - if ((runtime_enabled_paths_ & Path::kNeonDotprod) != Path::kNone) { - if (!DetectDotprod()) { - runtime_enabled_paths_ = runtime_enabled_paths_ & ~Path::kNeonDotprod; - // Sanity check. - RUY_DCHECK((runtime_enabled_paths_ & Path::kNeonDotprod) == Path::kNone); - } - } -#endif // RUY_PLATFORM(ARM) - -#if RUY_PLATFORM(X86) - // TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / - // placeholder. Optimization is not finished. In particular the dimensions of - // the kernel blocks can be changed as desired. - // - if ((runtime_enabled_paths_ & Path::kSse42) != Path::kNone) { - if (!(HaveBuiltPathForSse42() && DetectCpuSse42())) { - runtime_enabled_paths_ = runtime_enabled_paths_ & ~Path::kSse42; - // Sanity check. 
- RUY_DCHECK((runtime_enabled_paths_ & Path::kSse42) == Path::kNone); - } - } - - if ((runtime_enabled_paths_ & Path::kAvx2) != Path::kNone) { - if (!(HaveBuiltPathForAvx2() && DetectCpuAvx2())) { - runtime_enabled_paths_ = runtime_enabled_paths_ & ~Path::kAvx2; - // Sanity check. - RUY_DCHECK((runtime_enabled_paths_ & Path::kAvx2) == Path::kNone); - } - } - - if ((runtime_enabled_paths_ & Path::kAvx512) != Path::kNone) { - if (!(HaveBuiltPathForAvx512() && DetectCpuAvx512())) { - runtime_enabled_paths_ = runtime_enabled_paths_ & ~Path::kAvx512; - // Sanity check. - RUY_DCHECK((runtime_enabled_paths_ & Path::kAvx512) == Path::kNone); - } - } - - // TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / - // placeholder. Optimization is not finished. In particular the dimensions of - // the kernel blocks can be changed as desired. - // - if ((runtime_enabled_paths_ & Path::kAvxVnni) != Path::kNone) { - if (!(HaveBuiltPathForAvxVnni() && DetectCpuAvxVnni())) { - runtime_enabled_paths_ = runtime_enabled_paths_ & ~Path::kAvxVnni; - // Sanity check. - RUY_DCHECK((runtime_enabled_paths_ & Path::kAvxVnni) == Path::kNone); - } - } -#endif // RUY_PLATFORM(X86) - - // Sanity check. We can't possibly have disabled all paths, as some paths - // are universally available (kReference, kStandardCpp). - RUY_DCHECK_NE(runtime_enabled_paths_, Path::kNone); - return runtime_enabled_paths_; -} - -} // namespace ruy diff --git a/tensorflow/lite/experimental/ruy/ruy/context.h b/tensorflow/lite/experimental/ruy/ruy/context.h deleted file mode 100644 index a2d05a9ba5c..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/context.h +++ /dev/null @@ -1,109 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_CONTEXT_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_CONTEXT_H_ - -#include -#include -#include - -#include "tensorflow/lite/experimental/ruy/ruy/allocator.h" -#include "tensorflow/lite/experimental/ruy/ruy/path.h" -#include "tensorflow/lite/experimental/ruy/ruy/prepacked_cache.h" -#include "tensorflow/lite/experimental/ruy/ruy/thread_pool.h" -#include "tensorflow/lite/experimental/ruy/ruy/trace.h" -#include "tensorflow/lite/experimental/ruy/ruy/tune.h" - -namespace ruy { - -// The state private to each Ruy thread. -struct PerThreadState { - // Each thread may be running on a different microarchitecture. For example, - // some threads may be on big cores, while others are on little cores. Thus, - // it's best for the tuning to be per-thread. - TuningResolver tuning_resolver; - // Each thread has its own local allocator. - Allocator allocator; -}; - -// A Context holds runtime information used by Ruy. 
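The resolution logic above has two modes worth spelling out, sketched here against the Context declared just below (illustrative; it mirrors what context_test.cc exercises): the first GetRuntimeEnabledPaths() call on an untouched Context runs CPU detection once and caches the result (Path::kNone is the "not yet resolved" sentinel), while an explicit SetRuntimeEnabledPaths() bypasses detection entirely from then on.

#include "tensorflow/lite/experimental/ruy/ruy/context.h"
#include "tensorflow/lite/experimental/ruy/ruy/path.h"

void PathResolutionSketch() {
  // Automatic resolution: detection runs on the first call, result is cached.
  ruy::Context detected;
  const ruy::Path auto_paths = detected.GetRuntimeEnabledPaths();
  // Per the general test below, auto_paths always includes kReference and
  // kStandardCpp, plus whatever SIMD paths were both built and detected.
  (void)auto_paths;

  // Explicit restriction: no detection, only the given paths are considered.
  ruy::Context forced;
  forced.SetRuntimeEnabledPaths(ruy::Path::kNeon | ruy::Path::kNeonDotprod);
}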
It holds runtime resources -// such as the workers thread pool and the allocator (which holds buffers for -// temporary data), as well as runtime options controlling which Paths are -// enabled (typically based on which instruction sets are detected) and how -// many threads to use. -struct Context final { - Path last_taken_path = Path::kNone; - Tuning explicit_tuning = Tuning::kAuto; - // TODO(benoitjacob) rename that thread_pool. Current name is gemmlowp legacy. - ThreadPool workers_pool; - int max_num_threads = 1; - // State for each thread in the thread pool. Entry 0 is the main thread. - std::vector> per_thread_states; - TracingContext tracing; - CachePolicy cache_policy = CachePolicy::kNoCache; - - Allocator* GetMainAllocator() { - if (!main_allocator_) { - main_allocator_.reset(new Allocator); - } - return main_allocator_.get(); - } - - PrepackedCache* GetPrepackedCache() { - if (!prepacked_cache_) { - prepacked_cache_.reset(new PrepackedCache); - } - return prepacked_cache_.get(); - } - - void ClearPrepackedCache() { prepacked_cache_ = nullptr; } - - void EnsureNPerThreadStates(int thread_count) { - while (per_thread_states.size() < static_cast(thread_count)) { - per_thread_states.emplace_back(new PerThreadState); - } - } - - Tuning GetMainThreadTuning() { - EnsureNPerThreadStates(1); - TuningResolver* tuning_resolver = &per_thread_states[0]->tuning_resolver; - tuning_resolver->SetTuning(explicit_tuning); - return tuning_resolver->Resolve(); - } - - template - Path GetPathToTake() { - last_taken_path = - GetMostSignificantPath(CompiledPaths & GetRuntimeEnabledPaths()); - return last_taken_path; - } - - void SetRuntimeEnabledPaths(Path paths); - Path GetRuntimeEnabledPaths(); - - private: - // Allocator for main thread work before invoking the threadpool. - // Our simple Allocator does not allow reserving/allocating more blocks - // while it's already in committed state, so the main thread needs both - // this allocator, and its per-thread allocator. - std::unique_ptr main_allocator_; - std::unique_ptr prepacked_cache_; - Path runtime_enabled_paths_ = Path::kNone; -}; - -} // end namespace ruy - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_CONTEXT_H_ diff --git a/tensorflow/lite/experimental/ruy/ruy/context_test.cc b/tensorflow/lite/experimental/ruy/ruy/context_test.cc deleted file mode 100644 index bddbfcf8c55..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/context_test.cc +++ /dev/null @@ -1,63 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/lite/experimental/ruy/ruy/context.h" - -#include -#include "tensorflow/lite/experimental/ruy/ruy/path.h" -#include "tensorflow/lite/experimental/ruy/ruy/platform.h" - -namespace ruy { -namespace { - -TEST(ContextTest, EnabledPathsGeneral) { - ruy::Context ruy_context; - const auto ruy_paths = ruy_context.GetRuntimeEnabledPaths(); - const auto ruy_paths_repeat = ruy_context.GetRuntimeEnabledPaths(); - ASSERT_EQ(ruy_paths, ruy_paths_repeat); - EXPECT_NE(ruy_paths, Path::kNone); - EXPECT_EQ(ruy_paths & Path::kReference, Path::kReference); - EXPECT_EQ(ruy_paths & Path::kStandardCpp, Path::kStandardCpp); -} - -#if RUY_PLATFORM(X86) -TEST(ContextTest, EnabledPathsX86) { - ruy::Context ruy_context; - ruy_context.SetRuntimeEnabledPaths(Path::kSse42 | Path::kAvx2 | - Path::kAvx512 | Path::kAvxVnni); - const auto ruy_paths = ruy_context.GetRuntimeEnabledPaths(); - EXPECT_EQ(ruy_paths & Path::kReference, Path::kNone); - EXPECT_EQ(ruy_paths & Path::kStandardCpp, Path::kNone); -} -#endif // RUY_PLATFORM(X86) - -#if RUY_PLATFORM(ARM) -TEST(ContextTest, EnabledPathsArm) { - ruy::Context ruy_context; - ruy_context.SetRuntimeEnabledPaths(Path::kNeon | Path::kNeonDotprod); - const auto ruy_paths = ruy_context.GetRuntimeEnabledPaths(); - EXPECT_EQ(ruy_paths & Path::kReference, Path::kNone); - EXPECT_EQ(ruy_paths & Path::kStandardCpp, Path::kNone); - EXPECT_EQ(ruy_paths & Path::kNeon, Path::kNeon); -} -#endif // RUY_PLATFORM(ARM) - -} // namespace -} // namespace ruy - -int main(int argc, char** argv) { - ::testing::InitGoogleTest(&argc, argv); - return RUN_ALL_TESTS(); -} diff --git a/tensorflow/lite/experimental/ruy/ruy/cpu_cache_size.h b/tensorflow/lite/experimental/ruy/ruy/cpu_cache_size.h deleted file mode 100644 index 95ed35ec097..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/cpu_cache_size.h +++ /dev/null @@ -1,81 +0,0 @@ -/* Copyright 2020 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_CPU_CACHE_SIZE_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_CPU_CACHE_SIZE_H_ - -#include "tensorflow/lite/experimental/ruy/ruy/path.h" -#include "tensorflow/lite/experimental/ruy/ruy/platform.h" - -namespace ruy { - -// LocalDataCacheSize returns a sane default size for each CPU core's local -// data cache, i.e. the largest data cache that is local to that CPU core, not -// shared with other cores. That allows coarse tuning of code that aims for -// most of its memory accesses to hit such a typically fast data cache. -// -// SharedDataCacheSize returns a sane default size of the total data cache -// accessible to each CPU, including any shared cache. 
-// -// For example, if we design tune this code for a ARM Cortex-A55 with a local L1 -// cache of 32k, a local L2 cache of 128k and a shared L3 cache of 1M, -// LocalDataCacheSize should return 128k and SharedDataCacheSize -// should return 1M. -// -// Ideally these values would be queried at runtime, and we should probably -// do that on x86, but that is hard to do on ARM. -#if RUY_PLATFORM(ARM_64) -inline int LocalDataCacheSize() { return 1 << 15; } -inline int SharedDataCacheSize() { return 1 << 19; } -#elif RUY_PLATFORM(ARM_32) -inline int LocalDataCacheSize() { return 1 << 14; } -inline int SharedDataCacheSize() { return 1 << 18; } -#elif RUY_PLATFORM(X86) -inline int LocalDataCacheSize() { return 1 << 17; } -inline int SharedDataCacheSize() { return 1 << 21; } -#else -inline int LocalDataCacheSize() { return 1 << 14; } -inline int SharedDataCacheSize() { return 1 << 18; } -#endif -// Variants taking a Path argument which acts -// as a hint telling whether we're targeting more or less recent/powerful CPUs. -inline int LocalDataCacheSize(Path path) { -#if RUY_PLATFORM(ARM_64) - if (path == Path::kNeonDotprod) { - // At the moment, the smallest CPU with dotprod is probably Cortex-A55 with - // 128k L2 local cache. - return 1 << 17; - } -#else - (void)path; -#endif - return LocalDataCacheSize(); -} -inline int SharedDataCacheSize(Path path) { -#if RUY_PLATFORM(ARM_64) - if (path == Path::kNeonDotprod) { - // At the moment, the smallest CPU with dotprod is probably Cortex-A55 with - // 1M L3 shared cache. - return 1 << 20; - } -#else - (void)path; -#endif - return SharedDataCacheSize(); -} - -} // namespace ruy - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_CPU_CACHE_SIZE_H_ diff --git a/tensorflow/lite/experimental/ruy/ruy/detect_arm.cc b/tensorflow/lite/experimental/ruy/ruy/detect_arm.cc deleted file mode 100644 index 8f6d2c9f9fe..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/detect_arm.cc +++ /dev/null @@ -1,73 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -/* Detection of dotprod instructions on ARM. - * The current Linux-specific code relies on sufficiently new Linux kernels: - * At least Linux 4.15 in general; on Android, at least Linux 4.14.111 thanks to - * a late backport. This was backported just before the Android 10 release, so - * this is leaving out pre-release Android 10 builds as well as earlier Android - * versions. - * - * It is possible to detect instructions in other ways that don't rely on - * an OS-provided feature identification mechanism: - * - * (A) We used to have a SIGILL-handler-based method that worked at least - * on Linux. 
Its downsides were (1) crashes on a few devices where - * signal handler installation didn't work as intended; (2) additional - * complexity to generalize to other Unix-ish operating systems including - * iOS; (3) source code complexity and fragility of anything installing - * and restoring signal handlers; (4) confusing behavior under a debugger. - * - * (B) We also experimented with a fork-ing approach where a subprocess - * tries the instruction. Compared to (A), this is much simpler and more - * reliable and portable, but also much higher latency on Android where - * an uncaught signal typically causes a 100 ms latency. - * - * Should there be interest in either technique again in the future, - * code implementing both (A) and (B) can be found in earlier revisions of this - * file - in actual code for (A) and in a comment for (B). - */ - -#include "tensorflow/lite/experimental/ruy/ruy/detect_arm.h" - -#if defined __linux__ && defined __aarch64__ -#include -#endif - -namespace ruy { - -namespace { - -#if defined __linux__ && defined __aarch64__ -bool DetectDotprodByLinuxAuxvMethod() { - // This is the value of HWCAP_ASIMDDP in sufficiently recent Linux headers, - // however we need to support building against older headers for the time - // being. - const int kLocalHwcapAsimddp = 1 << 20; - return getauxval(AT_HWCAP) & kLocalHwcapAsimddp; -} -#endif - -} // namespace - -bool DetectDotprod() { -#if defined __linux__ && defined __aarch64__ - return DetectDotprodByLinuxAuxvMethod(); -#endif - - return false; -} - -} // namespace ruy diff --git a/tensorflow/lite/experimental/ruy/ruy/detect_arm.h b/tensorflow/lite/experimental/ruy/ruy/detect_arm.h deleted file mode 100644 index 9a1542d3cce..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/detect_arm.h +++ /dev/null @@ -1,29 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Temporary dotprod-detection code until we can rely on getauxval. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_DETECT_ARM_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_DETECT_ARM_H_ - -namespace ruy { - -// On A64, returns true if the dotprod extension is present. -// On other architectures, returns false unconditionally. -bool DetectDotprod(); - -} // namespace ruy - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_DETECT_ARM_H_ diff --git a/tensorflow/lite/experimental/ruy/ruy/detect_x86.cc b/tensorflow/lite/experimental/ruy/ruy/detect_x86.cc deleted file mode 100644 index 113a73c09e3..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/detect_x86.cc +++ /dev/null @@ -1,101 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/lite/experimental/ruy/ruy/detect_x86.h" - -#include - -#if RUY_PLATFORM(X86) && RUY_PLATFORM(X86_ENHANCEMENTS) -#include // IWYU pragma: keep - -#endif - -namespace ruy { -#if RUY_PLATFORM(X86) && RUY_PLATFORM(X86_ENHANCEMENTS) - -namespace { - -// See Intel docs, such as http://goo.gl/c6IkGX. -inline void RunCpuid(std::uint32_t eax, std::uint32_t ecx, - std::uint32_t abcd[4]) { - std::uint32_t ebx, edx; -#if defined(__i386__) && defined(__PIC__) - /* in case of PIC under 32-bit EBX cannot be clobbered */ - asm volatile("movl %%ebx, %%edi \n\t cpuid \n\t xchgl %%ebx, %%edi" - : "=D"(ebx), -#else - asm volatile("cpuid" - : "+b"(ebx), -#endif - "+a"(eax), "+c"(ecx), "=d"(edx)); - abcd[0] = eax; - abcd[1] = ebx; - abcd[2] = ecx; - abcd[3] = edx; -} - -} // namespace - -bool DetectCpuSse42() { - std::uint32_t abcd[4]; - - constexpr std::uint32_t kEcxSse42 = 1u << 20; - RunCpuid(1, 0, abcd); - const bool has_sse4_2_base = (abcd[2] & kEcxSse42) == kEcxSse42; - -#ifdef RUY_ENABLE_AMD_CPUID_CHECKS - constexpr std::uint32_t kEcxAbm = 1u << 5; - RunCpuid(0x80000001, 0, abcd); - const bool has_extras = (abcd[2] & kEcxAbm) == kEcxAbm; -#else - constexpr std::uint32_t kEcxPopcnt = 1u << 23; - RunCpuid(1, 0, abcd); - const bool has_extras = (abcd[2] & kEcxPopcnt) == kEcxPopcnt; -#endif - - return has_sse4_2_base && has_extras; -} - -bool DetectCpuAvx2() { - constexpr std::uint32_t kEbxAvx2 = 1u << 5; - constexpr std::uint32_t kEcxFma = 1u << 12; - - std::uint32_t abcd[4]; - - RunCpuid(7, 0, abcd); - const bool has_avx2 = (abcd[1] & kEbxAvx2) == kEbxAvx2; - RunCpuid(1, 0, abcd); - const bool has_fma = (abcd[2] & kEcxFma) == kEcxFma; - - return has_avx2 && has_fma; -} - -bool DetectCpuAvx512() { - constexpr std::uint32_t kEbxAvx512F = 1u << 16; - constexpr std::uint32_t kEbxAvx512Dq = 1u << 17; - constexpr std::uint32_t kEbxAvx512Cd = 1u << 28; - constexpr std::uint32_t kEbxAvx512Bw = 1u << 30; - constexpr std::uint32_t kEbxAvx512Vl = 1u << 31; - - constexpr std::uint32_t kEbxAvx512Mask = - kEbxAvx512F | kEbxAvx512Dq | kEbxAvx512Cd | kEbxAvx512Bw | kEbxAvx512Vl; - std::uint32_t abcd[4]; - RunCpuid(7, 0, abcd); - - return (abcd[1] & kEbxAvx512Mask) == kEbxAvx512Mask; -} - -#endif -} // namespace ruy diff --git a/tensorflow/lite/experimental/ruy/ruy/detect_x86.h b/tensorflow/lite/experimental/ruy/ruy/detect_x86.h deleted file mode 100644 index 185dabe06a5..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/detect_x86.h +++ /dev/null @@ -1,49 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_DETECT_X86_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_DETECT_X86_H_ - -#include "tensorflow/lite/experimental/ruy/ruy/platform.h" - -namespace ruy { - -#if RUY_PLATFORM(X86) -#if RUY_PLATFORM(X86_ENHANCEMENTS) - -// This also checks ABM support, which implies LZCNT and POPCNT. -bool DetectCpuSse42(); -bool DetectCpuAvx2(); -bool DetectCpuAvx512(); -// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. -// Optimization is not finished. In particular the dimensions of the kernel -// blocks can be changed as desired. -// -// TODO(b/146646451): Introduce and activate. -inline bool DetectCpuAvxVnni() { return false; } - -#else // RUY_PLATFORM(X86_ENHANCEMENTS) - -inline bool DetectCpuSse42() { return false; } -inline bool DetectCpuAvx2() { return false; } -inline bool DetectCpuAvx512() { return false; } -inline bool DetectCpuAvxVnni() { return false; } - -#endif // !RUY_PLATFORM(X86_ENHANCEMENTS) -#endif // RUY_PLATFORM(X86) - -} // namespace ruy - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_DETECT_X86_H_ diff --git a/tensorflow/lite/experimental/ruy/ruy/dispatch.h b/tensorflow/lite/experimental/ruy/ruy/dispatch.h deleted file mode 100644 index d1e97e29b9c..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/dispatch.h +++ /dev/null @@ -1,482 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// This file implements the translation between Ruy's entry point (ruy::Mul) and -// the internal implementation of matrix multiplication. -// -// The primary elements of this dispatch are: -// - pick suitable gemm kernel and packing routines for the user-specified -// CompiledPaths based on the current CPU. -// - decide on the structure of the packed matrices needed by the internal -// implementation (see pack.h for more information on packing). -// - translate the Mul operation into TrMul (see trmul.h for why that is -// useful). This is done by changing the matrix Layout -- no matrix data is -// actually moved. -// -// This file is also factored to serve as a building block for the advanced API -// as well. -// -// This file also performs some checking of invariants to catch user errors. 
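A minimal sketch of the Layout-only translation described in the comment above, assuming the Matrix/MakeSimpleLayout declarations from matrix.h and the Transpose(Matrix*) helper from internal_matrix.h; the 3x4 shape and the buffer are illustrative only:

    // Sketch: Transpose() rewrites the Layout; the data pointer is untouched.
    float buf[12] = {};                          // 3x4 matrix, row-major
    ruy::Matrix<float> lhs;
    ruy::MakeSimpleLayout(3, 4, ruy::Order::kRowMajor, &lhs.layout);
    lhs.data = buf;
    ruy::Transpose(&lhs);
    // Now lhs.layout.rows == 4, lhs.layout.cols == 3 and
    // lhs.layout.order == ruy::Order::kColMajor, while lhs.data still points
    // at buf: the same bytes are reinterpreted as the transposed matrix, so
    // Mul(lhs, rhs) can be forwarded to TrMul(transpose(lhs), rhs) at no cost.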
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_DISPATCH_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_DISPATCH_H_ - -#include -#include -#include // IWYU pragma: keep -#include - -#include "tensorflow/lite/experimental/ruy/ruy/check_macros.h" -#include "tensorflow/lite/experimental/ruy/ruy/common.h" -#include "tensorflow/lite/experimental/ruy/ruy/context.h" -#include "tensorflow/lite/experimental/ruy/ruy/internal_matrix.h" -#include "tensorflow/lite/experimental/ruy/ruy/kernel.h" -#include "tensorflow/lite/experimental/ruy/ruy/kernel_common.h" -#include "tensorflow/lite/experimental/ruy/ruy/matrix.h" -#include "tensorflow/lite/experimental/ruy/ruy/opt_set.h" -#include "tensorflow/lite/experimental/ruy/ruy/pack.h" -#include "tensorflow/lite/experimental/ruy/ruy/pack_common.h" -#include "tensorflow/lite/experimental/ruy/ruy/path.h" -#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" -#include "tensorflow/lite/experimental/ruy/ruy/side_pair.h" -#include "tensorflow/lite/experimental/ruy/ruy/size_util.h" -#include "tensorflow/lite/experimental/ruy/ruy/spec.h" -#include "tensorflow/lite/experimental/ruy/ruy/trmul.h" -#include "tensorflow/lite/experimental/ruy/ruy/trmul_params.h" - -namespace ruy { - -// If the Spec's LayoutSupport covers only some special cases, -// this function enforces that the matrix multiplication at hand falls into -// that special case. -template -void EnforceLayoutSupport(const Layout& lhs_layout, const Layout& rhs_layout, - const Layout& dst_layout) { - if (Spec::kLayoutSupport == LayoutSupport::kRCC) { - RUY_DCHECK(IsRowMajor(lhs_layout)); - RUY_DCHECK(IsColMajor(rhs_layout)); - RUY_DCHECK(IsColMajor(dst_layout)); - } -} - -template -bool IsSymmetricZeroPoint(Scalar zero_point) { - return zero_point == SymmetricZeroPoint(); -} - -template -void CheckZeroPoint(Scalar zero_point) { - if (std::is_floating_point::value || - Spec::kZeroPointSupport == ZeroPointSupport::kSymmetric) { - RUY_DCHECK(IsSymmetricZeroPoint(zero_point)); - } -} - -template -void EnforceZeroPointSupport(LhsScalar lhs_zero_point, RhsScalar rhs_zero_point, - DstScalar dst_zero_point) { - // If the Spec's ZeroPointSupport covers only some special cases, - // this function enforces that the matrix multiplication at hand falls into - // that special case. - CheckZeroPoint(lhs_zero_point); - CheckZeroPoint(rhs_zero_point); - CheckZeroPoint(dst_zero_point); - - // Guard against the case when both LHS and RHS zero_point's are equal to - // the minimum representable value. In that case, padding with zero_point - // values will generate the bad case for fast int8 kernels on NEON - // (pre-dotprod) which attempt to multiply-accumulate two pairs of int8 - // into a int16: this is safe except in the bad case -128*-128 + -128*-128. - // See b/131609283. This only affects the kNeon path but we ban this for all - // paths in order for ruy to have the same supported parameter space - // on all paths. - RUY_DCHECK(lhs_zero_point != std::numeric_limits::lowest() || - rhs_zero_point != std::numeric_limits::lowest()); -} - -template -void EnforceDstSpecSupport(const Spec& spec, DstScalar dst_zero_point) { - static_assert(std::is_same::value, ""); - if (!std::is_same::value) return; - - // If user is looking for the raw accumulator, zero_point and all the other - // dequantize fields don't make sense and should not be set. 
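// (Illustrative note: for a raw-accumulator destination, e.g.
// DstScalar = std::int32_t with a BasicSpec<std::int32_t, std::int32_t>, the
// caller must leave zero_point at 0, the multipliers at 0 / nullptr, and the
// clamp bounds at their std::numeric_limits defaults -- which is exactly what
// the checks below enforce.)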
- RUY_DCHECK_EQ(dst_zero_point, 0); - RUY_DCHECK_EQ(spec.clamp_max, std::numeric_limits::max()); - RUY_DCHECK_EQ(spec.clamp_min, std::numeric_limits::min()); - RUY_DCHECK_EQ(spec.multiplier_fixedpoint, 0); - RUY_DCHECK_EQ(spec.multiplier_exponent, 0); - RUY_DCHECK_EQ(spec.multiplier_fixedpoint_perchannel, nullptr); - RUY_DCHECK_EQ(spec.multiplier_exponent_perchannel, nullptr); -} - -inline bool IsColMajorTrMul(const TrMulParams& params) { - return IsColMajor(params.src[Side::kLhs].layout) && - IsColMajor(params.src[Side::kRhs].layout) && - IsColMajor(params.dst.layout); -} - -inline void CreatePackedLayout(const Layout& src, const Type& scalar, - const KernelLayout& kernel_layout, - PackedLayout* packed) { - packed->order = Order::kColMajor; - packed->rows = round_up_pot(src.rows, kernel_layout.rows); - packed->cols = round_up_pot(src.cols, kernel_layout.cols); - packed->kernel = kernel_layout; - int inner_size = packed->rows; - if (RUY_OPT_ENABLED(RUY_OPT_AVOID_ALIASING)) { - packed->stride = - (inner_size * scalar.size) % 1024 ? inner_size : inner_size + 64; - } else { - packed->stride = inner_size; - } -} - -template -void CreatePackedMatrix(Side side, const KernelLayout& kernel_layout, - TrMulParams* params) { - // Ruy always uses 32-bit signed accumulators for quantized - // matrix multiplication, so we would like to always use std::int32_t - // unconditionally for SumsType. - // However, for floating point types, we still need a reasonable type here to - // avoid tripping assertions elsewhere in the code. - using SumsType = - typename std::conditional::value, Scalar, - std::int32_t>::type; - - const DMatrix& src = params->src[side]; - PMatrix* packed = ¶ms->packed[side]; - packed->data_type = Type::Create(); - packed->sums_type = Type::Create(); - CreatePackedLayout(src.layout, packed->data_type, kernel_layout, - &packed->layout); - packed->zero_point = Pack(src.zero_point); -} - -template -void PopulateTrMulParams(TrMulParams* params) { - static_assert((ThePath & Path::kReference) == Path::kNone, - "Path::kReference should not do TrMul"); - // The optimized code paths don't handle the full generality of Ruy's API. - // Fall back to Path::kStandardCpp if necessary. - bool fallback_to_standard_cpp = false; - if (ThePath != Path::kStandardCpp) { - // The optimized code paths currently only handle the case of all matrices - // being column major. - if (!IsColMajorTrMul(*params)) { - fallback_to_standard_cpp = true; - } - } - - if (fallback_to_standard_cpp) { - PopulateTrMulParams(params); - return; - } - - using PackedLhsScalar = PackedType; - using PackedRhsScalar = PackedType; - using Kernel = - Kernel; - using LhsKernelLayout = typename Kernel::LhsLayout; - using RhsKernelLayout = typename Kernel::RhsLayout; - - params->path = ThePath; - - params->local_data_cache_size = Spec::local_data_cache_size(); - params->shared_data_cache_size = Spec::shared_data_cache_size(); - - CreatePackedMatrix( - Side::kLhs, ToKernelLayout(), params); - CreatePackedMatrix( - Side::kRhs, ToKernelLayout(), params); - params->run_pack[Side::kLhs] = - &RunPack; - params->run_pack[Side::kRhs] = - &RunPack; - params->run_kernel = - &RunKernel; - - return; -} - -// PopulateTrMulParamsAllCompiledPaths calls into one of multiple -// instantiations of PopulateTrMulParams. For each bit that is set in -// CompiledPaths, it statically instantiates PopulateTrMulParams with a Path -// corresponding to that single bit. 
The call to PopulateTrMulParams is -// guarded by a runtime check that it is in fact the dynamically selected path. -// -// PopulateTrMulParamsAllCompiledPaths is implemented with template -// metaprogramming by mutual recursion between PathSearchCountdown and -// PathSearchCompiledPaths. -// -// PopulateTrMulParamsAllCompiledPaths is logically implementing the following -// computation: -// -// template -// void PopulateTrMulParamsAllCompiledPaths(Path the_path, -// TrMulParams* params) { -// for (int bit = 8 * sizeof(Path) - 1; bit != -1; bit--) { // [1] -// Path current_path = static_cast(1 << bit); -// if ((CompiledPaths & current_path) != Path::kNone) { // [2] -// if (current_path == the_path) { // [3] -// PopulateTrMulParams(the_path, params); -// return; -// } -// } -// } -// } -// -// -// -// [1] - Done by the main definition of PathSearchCountdown. The `bit--` is -// done in the recursion of PathSearchOnlyCompiledPaths. -// [2] - Done by PathSearchOnlyCompiledPaths's partial template -// specialization on InCompiledPaths. This is the check which necessitates -// doing the whole computation at C++ compile time. -// [3] - Done by the `if` in the main definition of -// PathSearchOnlyCompiledPaths. -// -// The template metaprogramming is necessary because: -// - In `PopulateTrMulParams`, current_path must be a C++ -// compile-time constant. -// - PopulateTrMulParamsAllCompiledPaths must not instantiate -// inner loops for paths that are not in CompiledPaths, since that can result in -// bogus instantiations which cause a compile time failure. -template -struct PathSearchCountdown; - -template -struct PathSearchOnlyCompiledPaths { - static constexpr Path kCurrentPath = static_cast(1 << BitNumber); - static void Search(Path the_path, TrMulParams* params) { - if (kCurrentPath == the_path) { - PopulateTrMulParams( - params); - return; - } - PathSearchCountdown::Search(the_path, params); - } -}; - -// Skip this iteration if CompiledPaths doesn't contain the specified path. -template -struct PathSearchOnlyCompiledPaths { - static void Search(Path the_path, TrMulParams* params) { - PathSearchCountdown::Search(the_path, params); - } -}; - -template -struct PathSearchCountdown { - static constexpr Path kCurrentPath = static_cast(1 << BitNumber); - static void Search(Path the_path, TrMulParams* params) { - PathSearchOnlyCompiledPaths< - CompiledPaths, (CompiledPaths & kCurrentPath) != Path::kNone, BitNumber, - LhsScalar, RhsScalar, DstScalar, Spec>::Search(the_path, params); - } -}; - -// Termination of the countdown. If the counter reaches -1, then we haven't -// found the specified path. -template -struct PathSearchCountdown { - static void Search(Path the_path, TrMulParams* params) { RUY_DCHECK(false); } -}; - -template -void PopulateTrMulParamsAllCompiledPaths(Path the_path, TrMulParams* params) { - return PathSearchCountdown::Search(the_path, - params); -} - -template -void CreateTrMulParams(const Matrix& lhs, - const Matrix& rhs, const Spec& spec, - Context* context, Matrix* dst, Path the_path, - TrMulParams* params) { - // Fill in the fields we already know. - params->src[Side::kLhs] = ToDMatrix(lhs); - params->src[Side::kRhs] = ToDMatrix(rhs); - params->dst = ToDMatrix(*dst); - params->spec = ToVoidPtr(&spec); - - // Create inner loops and packed matrices based on the Path. 
- PopulateTrMulParamsAllCompiledPaths(the_path, params); -} - -template -void ReferenceMul(const Matrix& lhs, const Matrix& rhs, - const Spec& spec, Matrix* dst) { - profiler::ScopeLabel label("ReferenceMul"); - for (int i = 0; i < lhs.layout.rows; i++) { - for (int j = 0; j < rhs.layout.cols; j++) { - using AccumScalar = typename Spec::AccumScalar; - AccumScalar accum = 0; - for (int k = 0; k < lhs.layout.cols; k++) { - AccumScalar lhs_val = Element(lhs, i, k); - AccumScalar rhs_val = Element(rhs, k, j); - accum += (lhs_val - lhs.zero_point) * (rhs_val - rhs.zero_point); - } - if (spec.bias) { - accum += spec.bias[i]; - } - ApplyMultiplier(spec, i, &accum); - accum += dst->zero_point; - accum = std::min(accum, spec.clamp_max); - accum = std::max(accum, spec.clamp_min); - *ElementPtr(dst, i, j) = static_cast(accum); - } - } -} - -// Compile-time dispatch to ReferenceMul. This allows us to statically ensure -// that there is no call to ReferenceMul in the user's binary. -template -struct CompileTimeEnabledReferenceMul { - template - static void Run(const Matrix& lhs, const Matrix& rhs, - const Spec& spec, Matrix* dst) { - ReferenceMul(lhs, rhs, spec, dst); - } -}; - -// When this partial specialization is chosen, it ensures that ReferenceMul -// is never compiled. -template <> -struct CompileTimeEnabledReferenceMul { - template - static void Run(const Matrix& lhs, const Matrix& rhs, - const Spec& spec, Matrix* dst) { - RUY_DCHECK(false); - } -}; - -inline void HandlePrepackedCaching(TrMulParams* params, - const SidePair& cacheable, - Context* context) { - if (context->cache_policy == CachePolicy::kNoCache) { - return; - } - - if (context->cache_policy == CachePolicy::kCacheLHSOnNarrowMul) { - // TODO(b/149304278) Cache on dst.cols <= selected kernel width. - if (!cacheable[Side::kLhs] || params->dst.layout.cols > 4) { - return; - } - PrepackedCache* prepacked_cache = context->GetPrepackedCache(); - auto cache_key = std::make_pair(reinterpret_cast(params->run_kernel), - params->src[Side::kLhs].data); - auto it = prepacked_cache->FindAndUpdate(cache_key); - if (it != prepacked_cache->cend()) { - params->packed[Side::kLhs].data = it->second.first.data; - params->packed[Side::kLhs].sums = it->second.first.sums; - params->is_prepacked[Side::kLhs] = true; - return; - } - - // Allocate the prepacked matrix. 
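// (Descriptive note: this is the cache-miss path -- the packed LHS buffers are
// allocated from the PrepackedCache, packed once on the main thread via
// RunPack, and then Insert()ed under a key of (kernel function pointer, LHS
// data pointer), so that later narrow Muls reusing the same LHS hit the
// FindAndUpdate lookup above.)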
- PrepackedMatrix prepacked_lhs; - prepacked_lhs.data_size = DataSize(params->packed[Side::kLhs]); - prepacked_lhs.sums_size = SumsSize(params->packed[Side::kLhs]); - prepacked_cache->AllocatePrepackedMatrix(&prepacked_lhs); - params->packed[Side::kLhs].data = prepacked_lhs.data; - params->packed[Side::kLhs].sums = prepacked_lhs.sums; - params->is_prepacked[Side::kLhs] = true; - Tuning tuning = context->GetMainThreadTuning(); - params->RunPack(Side::kLhs, tuning, 0, - params->packed[Side::kLhs].layout.cols); - prepacked_cache->Insert(cache_key, prepacked_lhs); - return; - } -} - -template -void DispatchMul(const Matrix& lhs, const Matrix& rhs, - const Spec& spec, Context* context, Matrix* dst) { - static_assert(CompiledPaths != Path::kNone, "Must compile at least one Path"); - static_assert((CompiledPaths & ~kAllPaths) == Path::kNone, - "CompiledPaths must be a subset of ruy::kAllPaths"); - - profiler::ScopeLabel mul_label("Mul"); - profiler::ScopeLabel shape_specific_label("matmul shape: %dx%dx%d", - lhs.layout.rows, lhs.layout.cols, - rhs.layout.cols); - - EnforceLayoutSupport(lhs.layout, rhs.layout, dst->layout); - EnforceZeroPointSupport(lhs.zero_point, rhs.zero_point, - dst->zero_point); - EnforceDstSpecSupport(spec, dst->zero_point); - - // This should be a constant, for a given machine and CompiledPaths. - // There is a back door to override it for testing, but in production it will - // always be the "best" Path. I.e. the one with the newest SIMD instructions - // available on the present machine, and avoiding Path::kReference unless - // no other path is compiled. - // - // Unfortunately, it is not a *static* constant, since it depends on runtime - // detection of the available SIMD instructions. - Path the_path = context->GetPathToTake(); - - // Production code should probably never execute Path::kReference. - // Path::kReference implements a Mul, not a TrMul like the rest of Ruy, so if - // that's what we need to do, then get it out of the way before going down the - // TrMul path. - if (the_path == Path::kReference) { - constexpr bool ReferenceMulIsEnabled = - (CompiledPaths & Path::kReference) != Path::kNone; - CompileTimeEnabledReferenceMul::Run(lhs, rhs, spec, - dst); - return; - } - - // As described in the comment at the top of this file, Ruy internally - // converts Mul into TrMul. We handle that here. - // - // This is Ruy's main code path. - constexpr Path TrMulCompiledPaths = CompiledPaths & ~Path::kReference; - Matrix transposed_lhs(lhs); - Transpose(&transposed_lhs); - TrMulParams params; - CreateTrMulParams(transposed_lhs, rhs, spec, context, dst, - the_path, ¶ms); - SidePair cacheable(lhs.cacheable, rhs.cacheable); - HandlePrepackedCaching(¶ms, cacheable, context); - TrMul(¶ms, context); -} - -} // namespace ruy - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_DISPATCH_H_ diff --git a/tensorflow/lite/experimental/ruy/ruy/example.cc b/tensorflow/lite/experimental/ruy/ruy/example.cc deleted file mode 100644 index 5d31d6c2e3e..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/example.cc +++ /dev/null @@ -1,136 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include - -#include "tensorflow/lite/experimental/ruy/ruy/ruy.h" - -void ExampleMulFloat(ruy::Context *context) { - const float lhs_data[] = {1, 2, 3, 4}; - const float rhs_data[] = {1, 2, 3, 4}; - float dst_data[4]; - - ruy::Matrix lhs; - ruy::MakeSimpleLayout(2, 2, ruy::Order::kRowMajor, &lhs.layout); - lhs.data = lhs_data; - ruy::Matrix rhs; - ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, &rhs.layout); - rhs.data = rhs_data; - ruy::Matrix dst; - ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, &dst.layout); - dst.data = dst_data; - - ruy::BasicSpec spec; - ruy::Mul(lhs, rhs, spec, context, &dst); - - std::cout << "Example Mul, float:\n"; - std::cout << "LHS:\n" << lhs; - std::cout << "RHS:\n" << rhs; - std::cout << "Result:\n" << dst << "\n"; -} - -void ExampleMulFloatWithBiasAddAndClamp(ruy::Context *context) { - const float lhs_data[] = {1, 2, 3, 4}; - const float rhs_data[] = {1, 2, 3, 4}; - const float bias_data[] = {1, 0}; - float dst_data[4]; - - ruy::Matrix lhs; - ruy::MakeSimpleLayout(2, 2, ruy::Order::kRowMajor, &lhs.layout); - lhs.data = lhs_data; - ruy::Matrix rhs; - ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, &rhs.layout); - rhs.data = rhs_data; - ruy::Matrix dst; - ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, &dst.layout); - dst.data = dst_data; - - ruy::BasicSpec spec; - spec.bias = bias_data; - spec.clamp_min = 0; - spec.clamp_max = 15; - ruy::Mul(lhs, rhs, spec, context, &dst); - - std::cout << "Example Mul, float with bias addition and clamp:\n"; - std::cout << "LHS:\n" << lhs; - std::cout << "RHS:\n" << rhs; - std::cout << "Result:\n" << dst << "\n"; -} - -void ExampleMulUint8AsymmetricQuantized(ruy::Context *context) { - const std::uint8_t lhs_data[] = {124, 125, 126, 127}; - const std::uint8_t rhs_data[] = {129, 130, 131, 132}; - std::uint8_t dst_data[4]; - - ruy::Matrix lhs; - ruy::MakeSimpleLayout(2, 2, ruy::Order::kRowMajor, &lhs.layout); - lhs.data = lhs_data; - lhs.zero_point = 125; - ruy::Matrix rhs; - ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, &rhs.layout); - rhs.data = rhs_data; - rhs.zero_point = 132; - ruy::Matrix dst; - ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, &dst.layout); - dst.data = dst_data; - dst.zero_point = 129; - - ruy::BasicSpec spec; - spec.multiplier_fixedpoint = 1 << 30; - - spec.multiplier_exponent = 0; - ruy::Mul(lhs, rhs, spec, context, &dst); - - std::cout << "Example Mul, uint8 quantized with asymmetric zero points:\n"; - std::cout << "LHS:\n" << lhs; - std::cout << "RHS:\n" << rhs; - std::cout << "Result:\n" << dst << "\n"; -} -void ExampleMulInt8PerChannelQuantized(ruy::Context *context) { - const std::int8_t lhs_data[] = {1, 2, 3, 4}; - const std::int8_t rhs_data[] = {1, 2, 3, 4}; - const std::int32_t multiplier_data[] = {3 << 28, 5 << 28}; - const int exponent_data[] = {1, -2}; - std::int8_t dst_data[4]; - - ruy::Matrix lhs; - ruy::MakeSimpleLayout(2, 2, ruy::Order::kRowMajor, &lhs.layout); - lhs.data = lhs_data; - ruy::Matrix rhs; - ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, &rhs.layout); - rhs.data = 
rhs_data; - ruy::Matrix dst; - ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, &dst.layout); - dst.data = dst_data; - - ruy::BasicSpec spec; - spec.multiplier_fixedpoint_perchannel = multiplier_data; - spec.multiplier_exponent_perchannel = exponent_data; - ruy::Mul(lhs, rhs, spec, context, &dst); - - std::cout << "Example Mul, int8 quantized with per-channel multipliers\n"; - std::cout << "LHS:\n" << lhs; - std::cout << "RHS:\n" << rhs; - std::cout << "Result:\n" << dst << "\n"; -} - -int main() { - ruy::Context context; - ExampleMulFloat(&context); - ExampleMulFloatWithBiasAddAndClamp(&context); - ExampleMulUint8AsymmetricQuantized(&context); - ExampleMulInt8PerChannelQuantized(&context); -} diff --git a/tensorflow/lite/experimental/ruy/ruy/example_advanced.cc b/tensorflow/lite/experimental/ruy/ruy/example_advanced.cc deleted file mode 100644 index 9e1dd17f86d..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/example_advanced.cc +++ /dev/null @@ -1,83 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include -#include -#include - -#include "tensorflow/lite/experimental/ruy/ruy/ruy_advanced.h" - -// Simple allocator for allocating pre-packed matrices. -class SimpleAllocator { - public: - void* AllocateBytes(std::size_t num_bytes) { - char* p = new char[num_bytes]; - buffers_.emplace_back(p); - return static_cast(p); - } - - private: - std::vector> buffers_; -}; - -void ExamplePrepack(ruy::Context* context) { - const float lhs_data[] = {1, 2, 3, 4}; - const float rhs_data[] = {1, 2, 3, 4}; - float dst_data[4]; - - // Set up the matrix layouts and spec. - ruy::Matrix lhs; - ruy::MakeSimpleLayout(2, 2, ruy::Order::kRowMajor, &lhs.layout); - ruy::Matrix rhs; - ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, &rhs.layout); - ruy::Matrix dst; - ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, &dst.layout); - ruy::BasicSpec spec; - - SimpleAllocator allocator; - auto alloc_fn = [&allocator](std::size_t num_bytes) -> void* { - return allocator.AllocateBytes(num_bytes); - }; - - // In this example, we pre-pack only the RHS, but either will work. - // Note that we only need to set the data pointer for the matrix we are - // pre-packing. - ruy::PrepackedMatrix prepacked_rhs; - rhs.data = rhs_data; - ruy::PrePackForMul(lhs, rhs, spec, context, &dst, - /*prepacked_lhs=*/nullptr, &prepacked_rhs, - alloc_fn); - - // No data will be read from the RHS input matrix when using a pre-packed RHS. - rhs.data = nullptr; - lhs.data = lhs_data; - dst.data = dst_data; - ruy::MulWithPrepacked(lhs, rhs, spec, context, &dst, - /*prepacked_lhs=*/nullptr, - &prepacked_rhs); - rhs.data = rhs_data; - - // Print out the results. 
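// (Descriptive note: rhs.data was restored just above only so the RHS values
// can be printed below; the multiplication itself read the pre-packed copy,
// not rhs_data.)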
- std::cout << "Example Mul with pre-packing RHS, float:\n"; - std::cout << "LHS:\n" << lhs; - std::cout << "RHS:\n" << rhs; - std::cout << "Result:\n" << dst << "\n"; -} - -int main() { - ruy::Context context; - ExamplePrepack(&context); -} diff --git a/tensorflow/lite/experimental/ruy/ruy/have_built_path_for.h b/tensorflow/lite/experimental/ruy/ruy/have_built_path_for.h deleted file mode 100644 index 08651facb7e..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/have_built_path_for.h +++ /dev/null @@ -1,32 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_HAVE_BUILT_PATH_FOR_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_HAVE_BUILT_PATH_FOR_H_ - -#include "tensorflow/lite/experimental/ruy/ruy/platform.h" - -namespace ruy { - -#if RUY_PLATFORM(X86) -bool HaveBuiltPathForSse42(); -bool HaveBuiltPathForAvx2(); -bool HaveBuiltPathForAvx512(); -bool HaveBuiltPathForAvxVnni(); -#endif // RUY_PLATFORM(X86) - -} // namespace ruy - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_HAVE_BUILT_PATH_FOR_H_ diff --git a/tensorflow/lite/experimental/ruy/ruy/have_built_path_for_avx2.cc b/tensorflow/lite/experimental/ruy/ruy/have_built_path_for_avx2.cc deleted file mode 100644 index a9bcfbbbcfb..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/have_built_path_for_avx2.cc +++ /dev/null @@ -1,35 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/lite/experimental/ruy/ruy/have_built_path_for.h" -#include "tensorflow/lite/experimental/ruy/ruy/opt_set.h" - -namespace ruy { - -#if RUY_PLATFORM(X86) -// IMPORTANT: -// These patterns must match those in the pack and kernel cc files. 
-#if !(RUY_PLATFORM(AVX2) && RUY_OPT_ENABLED(RUY_OPT_ASM)) - -bool HaveBuiltPathForAvx2() { return false; } - -#else // RUY_PLATFORM(AVX2) && RUY_OPT_ENABLED(RUY_OPT_ASM) - -bool HaveBuiltPathForAvx2() { return true; } - -#endif // RUY_PLATFORM(AVX2) && RUY_OPT_ENABLED(RUY_OPT_ASM) -#endif // RUY_PLATFORM(X86) - -} // namespace ruy diff --git a/tensorflow/lite/experimental/ruy/ruy/have_built_path_for_avx512.cc b/tensorflow/lite/experimental/ruy/ruy/have_built_path_for_avx512.cc deleted file mode 100644 index 2b42cba26c9..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/have_built_path_for_avx512.cc +++ /dev/null @@ -1,35 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/lite/experimental/ruy/ruy/have_built_path_for.h" -#include "tensorflow/lite/experimental/ruy/ruy/opt_set.h" - -namespace ruy { - -#if RUY_PLATFORM(X86) -// IMPORTANT: -// These patterns must match those in the pack and kernel cc files. -#if !(RUY_PLATFORM(AVX512) && RUY_OPT_ENABLED(RUY_OPT_ASM)) - -bool HaveBuiltPathForAvx512() { return false; } - -#else // RUY_PLATFORM(AVX512) && RUY_OPT_ENABLED(RUY_OPT_ASM) - -bool HaveBuiltPathForAvx512() { return true; } - -#endif // RUY_PLATFORM(AVX512) && RUY_OPT_ENABLED(RUY_OPT_ASM) -#endif // RUY_PLATFORM(X86) - -} // namespace ruy diff --git a/tensorflow/lite/experimental/ruy/ruy/have_built_path_for_avxvnni.cc b/tensorflow/lite/experimental/ruy/ruy/have_built_path_for_avxvnni.cc deleted file mode 100644 index 42f9cb668df..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/have_built_path_for_avxvnni.cc +++ /dev/null @@ -1,39 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/lite/experimental/ruy/ruy/have_built_path_for.h" -#include "tensorflow/lite/experimental/ruy/ruy/opt_set.h" - -namespace ruy { - -#if RUY_PLATFORM(X86) -// IMPORTANT: -// These patterns must match those in the pack and kernel cc files. -#if !(RUY_PLATFORM(AVX_VNNI) && RUY_OPT_ENABLED(RUY_OPT_ASM)) - -bool HaveBuiltPathForAvxVnni() { return false; } - -#else // RUY_PLATFORM(AVX_VNNI) && RUY_OPT_ENABLED(RUY_OPT_ASM) - -// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. -// Optimization is not finished. In particular the dimensions of the kernel -// blocks can be changed as desired. 
-// -bool HaveBuiltPathForAvxVnni() { return true; } - -#endif // RUY_PLATFORM(AVX_VNNI) && RUY_OPT_ENABLED(RUY_OPT_ASM) -#endif // RUY_PLATFORM(X86) - -} // namespace ruy diff --git a/tensorflow/lite/experimental/ruy/ruy/have_built_path_for_sse42.cc b/tensorflow/lite/experimental/ruy/ruy/have_built_path_for_sse42.cc deleted file mode 100644 index e7470f54520..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/have_built_path_for_sse42.cc +++ /dev/null @@ -1,39 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/lite/experimental/ruy/ruy/have_built_path_for.h" -#include "tensorflow/lite/experimental/ruy/ruy/opt_set.h" - -namespace ruy { - -#if RUY_PLATFORM(X86) -// IMPORTANT: -// These patterns must match those in the pack and kernel cc files. -#if !(RUY_PLATFORM(SSE42) && RUY_OPT_ENABLED(RUY_OPT_ASM)) - -bool HaveBuiltPathForSse42() { return false; } - -#else // RUY_PLATFORM(SSE42) && RUY_OPT_ENABLED(RUY_OPT_ASM) - -// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. -// Optimization is not finished. In particular the dimensions of the kernel -// blocks can be changed as desired. -// -bool HaveBuiltPathForSse42() { return true; } - -#endif // RUY_PLATFORM(SSE42) && RUY_OPT_ENABLED(RUY_OPT_ASM) -#endif // RUY_PLATFORM(X86) - -} // namespace ruy diff --git a/tensorflow/lite/experimental/ruy/ruy/internal_matrix.h b/tensorflow/lite/experimental/ruy/ruy/internal_matrix.h deleted file mode 100644 index cf10adf084d..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/internal_matrix.h +++ /dev/null @@ -1,388 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Internal types and helpers for matrices. -// -// Ruy has a couple slightly different notions of matrices, besides the -// Matrix class that we expose to the user-facing API. -// -// TODO(silvasean): Put parts of this architecture description somewhere more -// prominent. -// -// The 4 main matrix types are: -// - Matrix: This is a user-facing type on Ruy's external API boundary. It is -// also used internally. -// - DMatrix: This is a type-erased version of Matrix. "D" = "dynamic". -// - PMatrix: This represents a packed matrix, which requires tracking kernel -// layout and row/column sums for quantization. It is type-erased. 
-// - PackedMatrix: This is a statically typed variant of PMatrix for -// convenience inside typed routines. -// -// Note that Matrix is *not* implemented in terms of the internal types. It -// is an independent, simple, and user-facing type. -// -// The use of type-erasure might seem surprising for a library like Ruy with a -// heavily-templated entry point, but it is motivated by the desire for most of -// Ruy's "middle-end" to be non-templated. Ruy can be thought of as having 3 -// main parts: -// - "front-end" (dispatch.h) - this is the highly templated ruy::Mul entry -// point, along with routines that select RunKernel and RunPack implementations -// statically based on those template parameters. -// - "back-end" (kernel.h, pack.h)- this consists of the implementations of -// RunKernel and RunPack, often in assembly code, which are the building blocks -// that Ruy calls to perform matrix multiplication. These are templated so that -// only the requested types/Path's are actually emitted by the compiler. -// - "middle-end" (trmul.h) - this is the part of Ruy that orchestrates the -// calls to the "back-end" optimized building blocks. This layer has to deal -// with issues like cache locality and low-overhead multi-threading. -// -// There is a desire for the "middle-end" to be non-templated in order to -// simplify the implementation and reduce code-size. We type-erase when going -// from the "front-end" to the "middle-end", and un-type-erase going from the -// "middle-end" to the "back-end". The un-type-erasure is possible because the -// "front-end" is responsible for instantiating the needed "back-end" templates, -// and thus the static type information is still present. -// -// Each layer of Ruy uses matrix types: -// - "front-end": Matrix -// - "middle-end": DMatrix, PMatrix -// - "back-end": Matrix, PackedMatrix -// -// The use of separate types for packed matrices is not essential, but makes it -// obvious at a glance whether a matrix is a packed matrix or not. We would -// reconsider this decision if there was significant duplication between packed -// and unpacked matrices, but that doesn't seem to be the case at the moment. -// -// Another goal is to keep the user-facing Matrix as simple and -// understandable as possible. Ideally, a user should be able to read the struct -// definition for Matrix and see a very simple definition with no internal -// details like sums and kernel block layout. -// -// To present another structured view of our various matrix types, here's a -// table: -// Plain matrices Packed matrices -// +---------------------------------- -// Templated | Matrix PackedMatrix -// Type-erased | DMatrix PMatrix -// -// -// There is 1 additional matrix type not mentioned above, due to its low -// importance: -// - PrepackedMatrix: This is a user-facing version of PMatrix. It has the bare -// minimum of fields needed for representing the raw data and sums buffers of a -// packed matrix for the "advanced" explicit pre-packing API. This type plays no -// role in Ruy's internals and can generally by ignored. The only reason it -// exists is so that PMatrix is not exposed to users -- we prefer to keep the -// internal matrix types hidden, even from "advanced" users. 
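A small sketch of the type-erasure round trip described above, using MakeSimpleLayout from matrix.h together with the ToDMatrix/ToMatrix helpers defined further down in this file; the buffer contents and shape are illustrative only:

    // Sketch: erase and un-erase a float matrix; no data is copied either way.
    float buffer[6] = {1, 2, 3, 4, 5, 6};
    ruy::Matrix<float> m;
    ruy::MakeSimpleLayout(2, 3, ruy::Order::kColMajor, &m.layout);
    m.data = buffer;

    ruy::DMatrix erased = ruy::ToDMatrix(m);   // "front-end" -> "middle-end"
    // erased.data_type.size == sizeof(float); erased.data still points at buffer.

    ruy::Matrix<float> back = ruy::ToMatrix<float>(erased);  // -> "back-end"
    // back.data.get() == buffer; ToMatrix<double>(erased) would fail the
    // AssertIs() debug checks carried by Type.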
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_INTERNAL_MATRIX_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_INTERNAL_MATRIX_H_ - -#include -#include -#include -#include - -#include "tensorflow/lite/experimental/ruy/ruy/check_macros.h" -#include "tensorflow/lite/experimental/ruy/ruy/common.h" -#include "tensorflow/lite/experimental/ruy/ruy/matrix.h" -#include "tensorflow/lite/experimental/ruy/ruy/size_util.h" - -namespace ruy { - -// KernelLayout describes small-scale block structure in a packed matrix layout. -// It's a runtime (as opposed to compile-time-constant) version of the -// FixedKernelLayout struct used to declare kernel layouts. -// -// This is is sometimes known as "tiling" in other contexts. -// -// For example, consider a packed matrix in column-major format with a -// column-major KernelLayout. The matrix logically has a shape of -// `[cols, rows]`. However, the matrix is laid out as though it were a 4D array -// of shape `[cols / kcols, rows / krows, kcols, krows]`. -// -// Note that in the case of kcols=1, krows=1, this degenerates to -// `[cols, rows, 1, 1]` which is equivalent to having no small-scale block -// structure. -struct KernelLayout { - Order order = Order::kColMajor; - std::uint8_t rows = 1; - std::uint8_t cols = 1; -}; - -// A packed matrix has a small-scale block structure that is not present in in -// the input matrices. This block structure is necessary for the kernels to -// process data efficiently. -// -// This struct is very similar to Layout, but has the extra KernelLayout field. -struct PackedLayout { - std::int32_t rows = 0; - std::int32_t cols = 0; - // Stride is the offset between two adjacent matrix elements - // in the non-contiguous direction. - std::int32_t stride = 0; - Order order = Order::kColMajor; - // Small scale layout shuffling, potentially departing from - // linear row-major or column-major storage. See KernelLayout. - KernelLayout kernel; -}; - -// Dynamic representation for a type. -// -// The most important field in this struct is the size, which Ruy uses to know -// how much memory to allocate without having to be templated on a type. -// Signed-ness and floating-point-ness are mainly present as debugging checks. -// -// Note: Ruy does not use this struct to to dynamically dispatch between -// different typed implementations. As described in the comment at the top of -// this file, Ruy's "front-end", which is templated, instantiates all the -// necessary "back-end" routines with complete static knowledge of all the -// types. -struct Type { - template - static Type Create() { - Type ret; - ret.is_signed = std::is_signed::value; - ret.is_floating_point = std::is_floating_point::value; - ret.size = sizeof(T); - return ret; - } - - template - void AssertIs() const { - RUY_DCHECK_EQ(is_signed, Create().is_signed); - RUY_DCHECK_EQ(is_floating_point, Create().is_floating_point); - RUY_DCHECK_EQ(size, Create().size); - } - - bool is_signed = false; - bool is_floating_point = false; - std::uint8_t size = 0; -}; - -// Type-erased matrix. -struct DMatrix { - Type data_type; - void* data = nullptr; - Layout layout; - std::int32_t zero_point = 0; -}; - -// Type-erased packed matrix. -struct PMatrix { - Type data_type; - void* data = nullptr; - Type sums_type; - void* sums = nullptr; - PackedLayout layout; - std::int32_t zero_point = 0; -}; - -// Convenient typed helper for packed matrices. 
-template -struct PackedMatrix { - // The row/column sums needed for quantized matrix multiplication when - // the opposite operand of the multiplication uses a non-symmetric zero - // point. - // This member is only relevant for packed matrices. - // Additionally, Ruy always uses 32-bit signed accumulators for quantized - // matrix multiplication. - // For floating point types, there is no quantization, so this pointer - // will always be null. We still need code referencing it to compile - // though, even if it is always branched around. Hence we use Scalar* - // itself as the type in that case. - using SumsType = - typename std::conditional::value, Scalar, - std::int32_t>::type; - - Scalar* data = nullptr; - SumsType* sums = nullptr; - PackedLayout layout; - std::int32_t zero_point = 0; -}; - -template -DMatrix ToDMatrix(const Matrix& matrix) { - DMatrix ret; - ret.data_type = Type::Create(); - ret.data = ToVoidPtr(matrix.data.get()); - ret.layout = matrix.layout; - ret.zero_point = matrix.zero_point; - return ret; -} - -template -Matrix ToMatrix(const DMatrix& dmatrix) { - dmatrix.data_type.AssertIs(); - Matrix ret; - ret.data = static_cast(dmatrix.data); - ret.layout = dmatrix.layout; - ret.zero_point = dmatrix.zero_point; - return ret; -} - -template -PackedMatrix ToPackedMatrix(const PMatrix& pmatrix) { - using SumsType = typename PackedMatrix::SumsType; - pmatrix.data_type.AssertIs(); - pmatrix.sums_type.AssertIs(); - PackedMatrix ret; - ret.data = static_cast(pmatrix.data); - ret.sums = static_cast(pmatrix.sums); - ret.layout = pmatrix.layout; - ret.zero_point = pmatrix.zero_point; - return ret; -} - -// Helpers for Layout / PackedLayout. - -inline bool IsPacked(const Layout& layout) { - if (layout.order == Order::kColMajor) { - return layout.stride == layout.rows; - } else { - return layout.stride == layout.cols; - } -} - -inline bool IsRowMajor(const Layout& layout) { - return layout.order == Order::kRowMajor; -} - -template -inline bool IsColMajor(const LayoutOrPackedLayout& layout) { - return layout.order == Order::kColMajor; -} - -template -inline int FlatSize(const LayoutOrPackedLayout& layout) { - const int outerdim = - layout.order == Order::kColMajor ? layout.cols : layout.rows; - return layout.stride * outerdim; -} - -// TODO(b/130417400) add a unit test -inline int Offset(const Layout& layout, int row, int col) { - // TODO(benoitjacob) - should check this but this make the _slow tests take - // 5x longer. Find a mitigation like in Eigen with an 'internal' variant - // bypassing the check? - // RUY_DCHECK_GE(row, 0); - // RUY_DCHECK_GE(col, 0); - // RUY_DCHECK_LT(row, layout.rows); - // RUY_DCHECK_LT(col, layout.cols); - int row_stride = layout.order == Order::kColMajor ? 1 : layout.stride; - int col_stride = layout.order == Order::kRowMajor ? 1 : layout.stride; - return row * row_stride + col * col_stride; -} - -// TODO(b/130417400) add a unit test -inline int Offset(const PackedLayout& layout, int row, int col) { - RUY_DCHECK(is_pot(layout.kernel.rows)); - RUY_DCHECK(is_pot(layout.kernel.cols)); - int row_outer = row & ~(layout.kernel.rows - 1); - int col_outer = col & ~(layout.kernel.cols - 1); - int row_stride_outer = - layout.order == Order::kColMajor ? layout.kernel.cols : layout.stride; - int col_stride_outer = - layout.order == Order::kRowMajor ? 
layout.kernel.rows : layout.stride; - int offset_outer = - row_outer * row_stride_outer + col_outer * col_stride_outer; - int row_inner = row - row_outer; - int col_inner = col - col_outer; - int row_stride_inner = - layout.kernel.order == Order::kColMajor ? 1 : layout.kernel.cols; - int col_stride_inner = - layout.kernel.order == Order::kRowMajor ? 1 : layout.kernel.rows; - int offset_inner = - row_inner * row_stride_inner + col_inner * col_stride_inner; - return offset_outer + offset_inner; -} - -// Helpers for Matrix. - -template -const Scalar* ElementPtr(const Matrix& mat, int row, int col) { - return mat.data.get() + Offset(mat.layout, row, col); -} - -template -Scalar* ElementPtr(Matrix* mat, int row, int col) { - return mat->data.get() + Offset(mat->layout, row, col); -} - -template -Scalar Element(const Matrix& mat, int row, int col) { - return *ElementPtr(mat, row, col); -} - -// Helpers for PackedMatrix. -// Duplicated from Matrix, but the duplication seems acceptable. - -template -const Scalar* ElementPtr(const PackedMatrix& mat, int row, int col) { - return mat.data + Offset(mat.layout, row, col); -} - -template -Scalar* ElementPtr(PackedMatrix* mat, int row, int col) { - return mat->data + Offset(mat->layout, row, col); -} - -template -Scalar Element(const PackedMatrix& mat, int row, int col) { - return *ElementPtr(mat, row, col); -} - -// Helpers for PMatrix. - -inline std::size_t DataSize(const PMatrix& packed) { - return FlatSize(packed.layout) * packed.data_type.size; -} - -inline std::size_t SumsSize(const PMatrix& packed) { - // Packed matrices are only relevant for Ruy's TrMul implementations. For - // TrMul, the number of sums is always equal to the number of columns. - return packed.layout.cols * packed.sums_type.size; -} - -// Transpose helpers. - -inline void Transpose(Order* order) { - *order = *order == Order::kColMajor ? Order::kRowMajor : Order::kColMajor; -} - -inline void Transpose(Layout* layout) { - Transpose(&layout->order); - std::swap(layout->rows, layout->cols); -} - -template -inline void Transpose(Matrix* matrix) { - Transpose(&matrix->layout); -} - -// Helpers for KernelLayout. - -template -KernelLayout ToKernelLayout() { - KernelLayout ret; - ret.order = FixedKernelLayout::kOrder; - ret.rows = FixedKernelLayout::kRows; - ret.cols = FixedKernelLayout::kCols; - return ret; -} - -} // namespace ruy - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_INTERNAL_MATRIX_H_ diff --git a/tensorflow/lite/experimental/ruy/ruy/kernel.h b/tensorflow/lite/experimental/ruy/ruy/kernel.h deleted file mode 100644 index dd9a60b8d09..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/kernel.h +++ /dev/null @@ -1,31 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_KERNEL_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_KERNEL_H_ - -#include "tensorflow/lite/experimental/ruy/ruy/platform.h" - -// IWYU pragma: begin_exports -#if RUY_PLATFORM(NEON) -#include "tensorflow/lite/experimental/ruy/ruy/kernel_arm.h" -#elif RUY_PLATFORM(X86) -#include "tensorflow/lite/experimental/ruy/ruy/kernel_x86.h" -#else -#include "tensorflow/lite/experimental/ruy/ruy/kernel_common.h" -#endif -// IWYU pragma: end_exports - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_KERNEL_H_ diff --git a/tensorflow/lite/experimental/ruy/ruy/kernel_arm.h b/tensorflow/lite/experimental/ruy/ruy/kernel_arm.h deleted file mode 100644 index 760f0f0b4b5..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/kernel_arm.h +++ /dev/null @@ -1,211 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_KERNEL_ARM_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_KERNEL_ARM_H_ - -#include -#include - -#include "tensorflow/lite/experimental/ruy/ruy/common.h" -#include "tensorflow/lite/experimental/ruy/ruy/internal_matrix.h" -#include "tensorflow/lite/experimental/ruy/ruy/kernel_common.h" -#include "tensorflow/lite/experimental/ruy/ruy/matrix.h" -#include "tensorflow/lite/experimental/ruy/ruy/opt_set.h" -#include "tensorflow/lite/experimental/ruy/ruy/path.h" -#include "tensorflow/lite/experimental/ruy/ruy/platform.h" -#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" -#include "tensorflow/lite/experimental/ruy/ruy/side_pair.h" -#include "tensorflow/lite/experimental/ruy/ruy/size_util.h" -#include "tensorflow/lite/experimental/ruy/ruy/spec.h" -#include "tensorflow/lite/experimental/ruy/ruy/tune.h" - -namespace ruy { - -#if RUY_PLATFORM(NEON) && RUY_OPT_ENABLED(RUY_OPT_ASM) - -#if RUY_PLATFORM(NEON_64) -void Kernel8bitNeonOutOfOrder(const KernelParams8bit<4, 4>& params); -void Kernel8bitNeonOutOfOrder1Col(const KernelParams8bit<4, 4>& params); -#elif RUY_PLATFORM(NEON_32) -void Kernel8bitNeonOutOfOrder(const KernelParams8bit<4, 2>& params); -void Kernel8bitNeonOutOfOrder1Col(const KernelParams8bit<4, 2>& params); -#endif -void Kernel8bitNeonInOrder(const KernelParams8bit<4, 4>& params); -void Kernel8bitNeonDotprodOutOfOrder(const KernelParams8bit<8, 8>& params); -void Kernel8bitNeonDotprodOutOfOrder1Col(const KernelParams8bit<8, 8>& params); -void Kernel8bitNeonDotprodInOrder(const KernelParams8bit<8, 8>& params); - -#if RUY_PLATFORM(NEON_64) -template -struct Kernel> { - using LhsLayout = FixedKernelLayout; - using RhsLayout = FixedKernelLayout; - Tuning tuning = Tuning::kAuto; - explicit Kernel(Tuning tuning_) : tuning(tuning_) {} - void Run(const PackedMatrix& lhs, - const PackedMatrix& rhs, - const BasicSpec& spec, int start_row, - int start_col, int end_row, int end_col, - Matrix* dst) 
const { - KernelParams8bit params; - MakeKernelParams8bit(lhs, rhs, spec, start_row, start_col, end_row, end_col, - dst, ¶ms); - if (dst->layout.cols == 1) { - Kernel8bitNeonOutOfOrder1Col(params); - return; - } - if (__builtin_expect(tuning == Tuning::kInOrder, true)) { - Kernel8bitNeonInOrder(params); - } else { - Kernel8bitNeonOutOfOrder(params); - } - } -}; -#endif - -#if RUY_PLATFORM(NEON_32) -template -struct Kernel> { - using LhsLayout = FixedKernelLayout; - using RhsLayout = FixedKernelLayout; - Tuning tuning = Tuning::kAuto; - explicit Kernel(Tuning tuning_) : tuning(tuning_) {} - void Run(const PackedMatrix& lhs, - const PackedMatrix& rhs, - const BasicSpec& spec, int start_row, - int start_col, int end_row, int end_col, - Matrix* dst) const { - KernelParams8bit params; - MakeKernelParams8bit(lhs, rhs, spec, start_row, start_col, end_row, end_col, - dst, ¶ms); - if (dst->layout.cols == 1) { - Kernel8bitNeonOutOfOrder1Col(params); - return; - } - Kernel8bitNeonOutOfOrder(params); - } -}; -#endif - -#if RUY_PLATFORM(NEON_64) -template -struct Kernel> { - Tuning tuning = Tuning::kAuto; - using LhsLayout = FixedKernelLayout; - using RhsLayout = FixedKernelLayout; - explicit Kernel(Tuning tuning_) : tuning(tuning_) {} - void Run(const PackedMatrix& lhs, - const PackedMatrix& rhs, - const BasicSpec& spec, int start_row, - int start_col, int end_row, int end_col, - Matrix* dst) const { - KernelParams8bit params; - MakeKernelParams8bit(lhs, rhs, spec, start_row, start_col, end_row, end_col, - dst, ¶ms); - if (dst->layout.cols == 1) { - Kernel8bitNeonDotprodOutOfOrder1Col(params); - } else if (__builtin_expect(tuning == Tuning::kInOrder, true)) { - Kernel8bitNeonDotprodInOrder(params); - } else { - Kernel8bitNeonDotprodOutOfOrder(params); - } - } -}; -#endif - -void KernelFloatNeonOutOfOrder(const KernelParamsFloat<8, 8>& params); -void KernelFloatNeonInOrder(const KernelParamsFloat<8, 8>& params); -void KernelFloat32NeonOutOfOrder(const KernelParamsFloat<8, 4>& params); -void KernelFloatNeonDotprodInOrder(const KernelParamsFloat<8, 8>& params); - -#if RUY_PLATFORM(NEON_64) -// A Float kernel for ARM64 Neon. -template <> -struct Kernel> { - Tuning tuning = Tuning::kAuto; - using LhsLayout = FixedKernelLayout; - using RhsLayout = FixedKernelLayout; - explicit Kernel(Tuning tuning_) : tuning(tuning_) {} - void Run(const PackedMatrix& lhs, const PackedMatrix& rhs, - const BasicSpec& spec, int start_row, int start_col, - int end_row, int end_col, Matrix* dst) const { - KernelParamsFloat params; - MakeKernelParamsFloat(lhs, rhs, spec, start_row, start_col, end_row, - end_col, dst, ¶ms); - if (__builtin_expect(tuning == Tuning::kInOrder, true)) { - KernelFloatNeonInOrder(params); - } else { - KernelFloatNeonOutOfOrder(params); - } - } -}; -#endif - -#if RUY_PLATFORM(NEON_32) -// A Float kernel for ARM32 Neon. 
-template <> -struct Kernel> { - Tuning tuning = Tuning::kAuto; - using LhsLayout = FixedKernelLayout; - using RhsLayout = FixedKernelLayout; - explicit Kernel(Tuning tuning_) : tuning(tuning_) {} - void Run(const PackedMatrix& lhs, const PackedMatrix& rhs, - const BasicSpec& spec, int start_row, int start_col, - int end_row, int end_col, Matrix* dst) const { - KernelParamsFloat<8, 4> params; - - MakeKernelParamsFloat(lhs, rhs, spec, start_row, start_col, end_row, - end_col, dst, ¶ms); - - KernelFloat32NeonOutOfOrder(params); - } -}; -#endif - -// While the dotprod NEON extension does not concern floating-point arithmetic, -// its presence allows us to distinguish, in the in-order tuning case, between -// A53 and A55r1. TODO: should this be folded into tuning? -template <> -struct Kernel> { - Tuning tuning = Tuning::kAuto; - using LhsLayout = FixedKernelLayout; - using RhsLayout = FixedKernelLayout; - using Base = - Kernel>; - explicit Kernel(Tuning tuning_) : tuning(tuning_) {} - void Run(const PackedMatrix& lhs, const PackedMatrix& rhs, - const BasicSpec& spec, int start_row, int start_col, - int end_row, int end_col, Matrix* dst) const { - KernelParamsFloat params; - MakeKernelParamsFloat(lhs, rhs, spec, start_row, start_col, end_row, - end_col, dst, ¶ms); - if (__builtin_expect(tuning == Tuning::kInOrder, true)) { - KernelFloatNeonDotprodInOrder(params); - } else { - KernelFloatNeonOutOfOrder(params); - } - } -}; - -#endif // RUY_PLATFORM(NEON) && RUY_OPT_ENABLED(RUY_OPT_ASM) - -} // namespace ruy - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_KERNEL_ARM_H_ diff --git a/tensorflow/lite/experimental/ruy/ruy/kernel_arm32.cc b/tensorflow/lite/experimental/ruy/ruy/kernel_arm32.cc deleted file mode 100644 index 673f2616f02..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/kernel_arm32.cc +++ /dev/null @@ -1,2499 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/lite/experimental/ruy/ruy/kernel.h" -#include "tensorflow/lite/experimental/ruy/ruy/opt_set.h" -#include "tensorflow/lite/experimental/ruy/ruy/platform.h" -#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" - -namespace ruy { - -#if RUY_PLATFORM(NEON_32) && RUY_OPT_ENABLED(RUY_OPT_ASM) - -#define RUY_ASM_LABEL_STORE_UINT8 91 -#define RUY_ASM_LABEL_STORE_INT8 92 -#define RUY_ASM_LABEL_STORE_INT16 93 -#define RUY_ASM_LABEL_STORE_INT32 94 -#define RUY_ASM_LABEL_AFTER_STORE 99 - -#define RUY_OFFSET_LHS_BASE_PTR 0 -#define RUY_OFFSET_RHS_BASE_PTR 4 -#define RUY_OFFSET_DST_BASE_PTR 8 -#define RUY_OFFSET_BIAS 12 -#define RUY_OFFSET_START_ROW 16 -#define RUY_OFFSET_START_COL 20 -#define RUY_OFFSET_LAST_ROW 24 -#define RUY_OFFSET_LAST_COL 28 -#define RUY_OFFSET_DST_ROWS 32 -#define RUY_OFFSET_DST_COLS 36 -#define RUY_OFFSET_LHS_STRIDE 40 -#define RUY_OFFSET_RHS_STRIDE 44 -#define RUY_OFFSET_DST_STRIDE 48 -#define RUY_OFFSET_DEPTH 52 -#define RUY_OFFSET_CLAMP_MIN 56 -#define RUY_OFFSET_CLAMP_MAX 60 -#define RUY_OFFSET_FLAGS 64 - -#define RUY_STACK_OFFSET_SIZE 96 -#define RUY_STACK_OFFSET_DST_COL_PTR 0 -#define RUY_STACK_OFFSET_DST_PTR 16 -#define RUY_STACK_OFFSET_ROW 32 -#define RUY_STACK_OFFSET_COL 48 -#define RUY_STACK_OFFSET_LHS_COL_PTR 64 -#define RUY_STACK_OFFSET_RHS_COL_PTR 80 - -template -void CheckOffsetsInKernelParamsFloat32(const Params&) { - static_assert(offsetof(Params, lhs_base_ptr) == RUY_OFFSET_LHS_BASE_PTR, ""); - static_assert(offsetof(Params, rhs_base_ptr) == RUY_OFFSET_RHS_BASE_PTR, ""); - static_assert(offsetof(Params, dst_base_ptr) == RUY_OFFSET_DST_BASE_PTR, ""); - static_assert(offsetof(Params, bias) == RUY_OFFSET_BIAS, ""); - static_assert(offsetof(Params, start_row) == RUY_OFFSET_START_ROW, ""); - static_assert(offsetof(Params, start_col) == RUY_OFFSET_START_COL, ""); - static_assert(offsetof(Params, last_row) == RUY_OFFSET_LAST_ROW, ""); - static_assert(offsetof(Params, last_col) == RUY_OFFSET_LAST_COL, ""); - static_assert(offsetof(Params, dst_rows) == RUY_OFFSET_DST_ROWS, ""); - static_assert(offsetof(Params, lhs_stride) == RUY_OFFSET_LHS_STRIDE, ""); - static_assert(offsetof(Params, rhs_stride) == RUY_OFFSET_RHS_STRIDE, ""); - static_assert(offsetof(Params, dst_stride) == RUY_OFFSET_DST_STRIDE, ""); - static_assert(offsetof(Params, depth) == RUY_OFFSET_DEPTH, ""); - static_assert(offsetof(Params, clamp_min) == RUY_OFFSET_CLAMP_MIN, ""); - static_assert(offsetof(Params, clamp_max) == RUY_OFFSET_CLAMP_MAX, ""); - static_assert(offsetof(Params, flags) == RUY_OFFSET_FLAGS, ""); -} - -// Float kernel for ARM32 out-of-order cores. -// Just like Float 64 version, except accumulate in to 8x4 block to only -// use 16 128-bit NEON registers. This is a "first pass" kernel and not -// tuned. It is meant to run on out-of-order CPUs like the Krait 400 or A9. -void KernelFloat32NeonOutOfOrder(const KernelParamsFloat<8, 4>& params) { - CheckOffsetsInKernelParamsFloat32(params); - profiler::ScopeLabel label( - "Kernel (kNeon, optimized for out-of-order cores)"); - - const float* lhs_ptr = params.lhs_base_ptr; - const float* rhs_ptr = params.rhs_base_ptr; - // In ARM32 NEON, there are 16 128-bit "q" registers. These registers are - // each composed of two 64-bit "d" registers. The asm kernel below has the - // following NEON register allocation: - // Registers q3 -- q10 are accumulators. 
During accumulation, - // q0 -- q2 (d0 -- d5) are used to load data from LHS and RHS. q0 and q1 - // are used to load a 8x1 block of LHS, and q2 is used to load a 1x4 block - // of RHS, like this: - - // Register layout in "q" registers: - // RHS 1x4 block - // /--------------------------\ - // |q2.s[0] ... q2.s[3] | - // \--------------------------/ - // LHS 8x1 block - // /---------------------\ /--------------------- \ - // | q0.s[0] | | q3.s[0] ... q9.s[0] | - // | ... | | ... ... | - // | q0.s[3] | | q3.s[3] q9.s[3] | - // | q1.s[0] | | q4.s[0] q10.s[0] | - // | ... | | ... ... ... | - // | q1.s[3] | | q4.s[3] .. q10.s[3] | - // \---------------------/ \--------------------------/ - // accumulators 8x4 block - // q11, q14, q15 currently unused. q12 and q13 are used to load - // parameters used for the post-accumulation part of the kernel. - // For completeness, here is the register layout in "d" registers: - // RHS 1x4 block - // /--------------------------\ - // |d4[0] ... d5[1] | - // \--------------------------/ - // LHS 8x1 block - // /---------------------\ /--------------------------\ - // | d0[0] | | d6[0] ... d18[0] | - // | ... | | ... ... | - // | d1[1] | | d7[1] d19[1] | - // | d2[0] | | d8[0] d20[0] | - // | ... | | ... ... ... | - // | d3[1] | | d9[1] ... d21[1] | - // \---------------------/ \--------------------------/ - // accumulators 8x4 block - asm volatile( -#define RUY_MAKE_ZERO(reg) "vmov.f32 " #reg ", #0.0\n" - - // clang-format off - - // Load the first 32 bytes of LHS and RHS data. - // Load q0, q1 - "vld1.32 {d0, d1}, [%[lhs_ptr]]!\n" - "vld1.32 {d2, d3}, [%[lhs_ptr]]!\n" - RUY_PREFETCH_LOAD("pld [%[lhs_ptr]]\n") - // Load q2 - "vld1.32 {d4, d5}, [%[rhs_ptr]]!\n" - RUY_PREFETCH_LOAD("pld [%[rhs_ptr]]\n") - - "sub sp, sp, #" RUY_STR(RUY_STACK_OFFSET_SIZE) "\n" - - "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_BASE_PTR) "]\n" - "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n" - - "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_BASE_PTR) "]\n" - "str r2, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - - "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" - "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" - - "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_START_COL) "]\n" - "str r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" - - "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" - "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" - - "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_RHS_BASE_PTR) "]\n" - "str r2, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n" - // Clear accumulators. - RUY_MAKE_ZERO(q3) - RUY_MAKE_ZERO(q4) - RUY_MAKE_ZERO(q5) - RUY_MAKE_ZERO(q6) - RUY_MAKE_ZERO(q7) - RUY_MAKE_ZERO(q8) - RUY_MAKE_ZERO(q9) - RUY_MAKE_ZERO(q10) - - // r1 is the number of levels of depth that we have already loaded - // LHS and RHS data for. Corresponding to the initial ld1 instructions - // above, this is currently 1. - "mov r1, #1\n" - - // Main loop of the whole GEMM, over rows and columns of the - // destination matrix. 
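// Editor's sketch (not part of the original source): a scalar model of one
// depth step of the 8x4 float micro-kernel implemented by the "2:" loop
// below. Each step reads an 8x1 LHS column and a 1x4 RHS row and performs a
// rank-1 update of the 8x4 accumulator block held in q3..q10; the vmla.f32
// instructions below are the vectorized form of this double loop.
void FloatDepthStepModel(const float lhs_col[8], const float rhs_row[4],
                         float acc[8][4]) {
  for (int r = 0; r < 8; ++r) {
    for (int c = 0; c < 4; ++c) {
      acc[r][c] += lhs_col[r] * rhs_row[c];
    }
  }
}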
- "1:\n" - - // Accumulation loop - "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n" - "cmp r1, r2\n" - "beq 79f\n" - - "2:\n" - - "vmla.f32 q3, q0, d4[0]\n" - "vmla.f32 q5, q0, d4[1]\n" - "vmla.f32 q7, q0, d5[0]\n" - "vmla.f32 q9, q0, d5[1]\n" - "vld1.32 {d0, d1}, [%[lhs_ptr]]!\n" // Reload LHS - - "vmla.f32 q4, q1, d4[0]\n" - "vmla.f32 q6, q1, d4[1]\n" - "vmla.f32 q8, q1, d5[0]\n" - "vmla.f32 q10, q1, d5[1]\n" - "vld1.32 {d2, d3}, [%[lhs_ptr]]!\n" // Reload LHS - RUY_PREFETCH_LOAD("pld [%[lhs_ptr]]\n") - "vld1.32 {d4, d5}, [%[rhs_ptr]]!\n" // Reload RHS - RUY_PREFETCH_LOAD("pld [%[rhs_ptr]]\n") - - "add r1, r1, #1\n" - "cmp r1, r2\n" - - "blt 2b\n" - - "79:\n" - - // End of the inner loop on depth. Now perform the remaining - // multiply-adds of the last level of depth, for which the LHS - // and RHS data is already loaded. - - "vmla.f32 q3, q0, d4[0]\n" - "vmla.f32 q5, q0, d4[1]\n" - "vmla.f32 q7, q0, d5[0]\n" - "vmla.f32 q9, q0, d5[1]\n" - - "vmla.f32 q4, q1, d4[0]\n" - "vmla.f32 q6, q1, d4[1]\n" - "vmla.f32 q8, q1, d5[0]\n" - "vmla.f32 q10, q1, d5[1]\n" - - // End of accumulation. The registers q3 -- q10 contain the final - // float32 accumulator values of the current 8x8 destination block. - // We now have to compute the final values from these accumulators - // and advance to the next 8x8 block. We intertwine - // these two aspects whenever possible for optimal pipelining, both - // at the data flow level (prefetch data for next block as early as - // possible) and instruction pipelining level (some of the next-block - // work can dual-issue with some of the final work on the current - // block). - - // Logic to advance to the next block in preparation for the next - // iteration of the main loop. For now, we only want to compute - // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are - // not yet ready to update the values of row and col, as we still need - // the current values for the rest of the work on the current block. - - "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" - "ldr r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" - "cmp r1, r3\n" // Have we finished the last row? - - "bge 4f\n" // If finished last row, go to 4 - // Not finished last row: then advance to next row. - "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n" - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" - "add r4, r4, r1, lsl #3\n" - "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" - "b 5f\n" - "4:\n" // Finished last row... - "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" - // Go back to first row - "str r5, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" - // Now we need to advance to the next column. If we already - // finished the last column, then in principle we are done, however - // we can't just return here, as we need to allow the end work of the - // current block to complete. The good news is that at this point it - // doesn't matter what data we load for the next column, since - // we will exit from the main loop below before actually storing - // anything computed from that data. - "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n" - "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" - "cmp r8, r4\n" // Have we finished the last column? - "bge 5f\n" // If yes, just carry on without updating the column pointer. - // Not finished last column: then advance to next column. 
- "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n" - "ldr r10, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n" - "add r10, r10, r1, lsl #2\n" - "str r10, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n" - "5:\n" - - // Set the LHS and RHS data pointers to the start of the columns just - // computed. - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" - "mov %[lhs_ptr], r4\n" - "ldr r5, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n" - "mov %[rhs_ptr], r5\n" - - // Load some parameters needed for the end work on current block. - "ldrb r4, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" - "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n" - - // Offset these base pointers as needed given the current row, col. - "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" - "add r5, r1, r8, lsl #2\n" - - "tst r4, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n" - "it ne\n" - "movne r1, r5\n" - - // Load 8 bias values. - "vld1.32 {d24, d25, d26, d27}, [r1]\n" - - // Now that we know what LHS and RHS data the next iteration of the - // main loop will need to load, we start loading the first 32 bytes of - // each of LHS and RHS, into q0 -- q2, as we don't need q0 -- q2 anymore - // in the rest of the work on the current block. - // Load q0, q1 - "vld1.32 {d0, d1, d2, d3}, [%[lhs_ptr]]!\n" - RUY_PREFETCH_LOAD("pld [%[lhs_ptr]]\n") - // Load q2 - "vld1.32 {d4, d5}, [%[rhs_ptr]]!\n" - RUY_PREFETCH_LOAD("pld [%[rhs_ptr]]\n") - - // Perform the bias-addition (per the above, we have just folded into - // the bias the (depth * lhs_zero_point * rhs_zero_point) term.) - "vadd.f32 q3, q3, q12\n" - "vadd.f32 q4, q4, q13\n" - "vadd.f32 q5, q5, q12\n" - "vadd.f32 q6, q6, q13\n" - "vadd.f32 q7, q7, q12\n" - "vadd.f32 q8, q8, q13\n" - "vadd.f32 q9, q9, q12\n" - "vadd.f32 q10, q10, q13\n" - - // Load the clamp_min, clamp_max bounds - "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" - "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" - "vdup.32 q12, r2\n" // clamp_min - "vdup.32 q13, r3\n" // clamp_max - - // Apply the clamp_min bound - "vmax.f32 q3, q3, q12\n" - "vmax.f32 q4, q4, q12\n" - "vmax.f32 q5, q5, q12\n" - "vmax.f32 q6, q6, q12\n" - "vmax.f32 q7, q7, q12\n" - "vmax.f32 q8, q8, q12\n" - "vmax.f32 q9, q9, q12\n" - "vmax.f32 q10, q10, q12\n" - - // Apply the clamp_max bound - "vmin.f32 q3, q3, q13\n" - "vmin.f32 q4, q4, q13\n" - "vmin.f32 q5, q5, q13\n" - "vmin.f32 q6, q6, q13\n" - "vmin.f32 q7, q7, q13\n" - "vmin.f32 q8, q8, q13\n" - "vmin.f32 q9, q9, q13\n" - "vmin.f32 q10, q10, q13\n" - - // Compute how much of the 8x4 block of destination values that - // we have computed, fit in the destination matrix. Typically, all of - // it fits, but when the destination matrix shape is not a multiple - // of 8x4, there are some 8x8 blocks along the boundaries that do - // not fit entirely. - "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n" - "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" - "sub r1, r1, r8\n" - - "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n" - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" - "sub r2, r2, r4\n" - "mov r3, #8\n" - "mov r5, #4\n" - "cmp r1, #8\n" - // Compute r1 = how many rows of the 8x4 block fit - "it gt\n" - "movgt r1, r3\n" - "cmp r2, #4\n" - // Compute r2 = how many cols of the 8x4 block fit - "it gt\n" - "movgt r2, r5\n" - - // Test if r1==8 && r2 == 4, i.e. if all of the 8x4 block fits. - "cmp r1, r3\n" - "it eq\n" - "cmpeq r2, r5\n" - // Yes, all of the 8x4 block fits, go to fast path. 
- "beq 30f\n" - // Not all of the 8x4 block fits. - // Set (r3 address, r4 stride) to write to dst_tmp_buf - "mov r3, %[dst_tmp_buf]\n" - "mov r4, #32\n" - "b 31f\n" - "30:\n" - // Yes, all of the 8x4 block fits. - // Set (r3 address, r4 stride) to write directly to destination matrix. - "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" - "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - "mov r4, r5\n" - "31:\n" - - // Write our float values to the destination described by - // (r3 address, r4 stride) - "vst1.32 {d6, d7, d8, d9}, [r3]\n" - "add r3, r3, r4\n" - RUY_MAKE_ZERO(q3) - RUY_MAKE_ZERO(q4) - "vst1.32 {d10, d11, d12, d13}, [r3]\n" - "add r3, r3, r4\n" - RUY_MAKE_ZERO(q5) - RUY_MAKE_ZERO(q6) - "vst1.32 {d14, d15, d16, d17}, [r3]\n" - "add r3, r3, r4\n" - RUY_MAKE_ZERO(q7) - RUY_MAKE_ZERO(q8) - "vst1.32 {d18, d19, d20, d21}, [r3]\n" - "add r3, r3, r4\n" - RUY_MAKE_ZERO(q9) - RUY_MAKE_ZERO(q10) - - // If all of the 8x4 block fits, we just finished writing it to the - // destination, so we skip the next part. - "beq 41f\n" - // Not all of the 8x8 block fits in the destination matrix. We just - // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over - // it to copy into the destination matrix the part that fits. - "ldr r8, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" - "mov r3, %[dst_tmp_buf]\n" - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - "mov r6, #0\n" - "50:\n" - "mov r5, #0\n" - "51:\n" - "ldr r10, [r3, r5, lsl #2]\n" - "str r10, [r4, r5, lsl #2]\n" - "add r5, r5, #1\n" - "cmp r5, r1\n" - "blt 51b\n" - "add r6, r6, #1\n" - "add r3, r3, #32\n" - "add r4, r4, r8\n" - // r2 = how many cols of the 8x4 block fit - "cmp r6, r2\n" - "blt 50b\n" - "41:\n" - // Load dst_ptr, increment, and write back. - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - "add r4, r4, #32\n" - "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - // At this point we have completely finished writing values to the - // destination matrix for the current block. - - // Reload some params --- we had used r3, r5, r10 for a few other things - // since the last time we had loaded them. - "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" - "ldr r6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" - "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" - - // Move to the next block of the destination matrix, for the next iter - // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already - // been updated earlier. - // Have we reached the end row? - "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" - "cmp r8, r3\n" - - "beq 20f\n" // yes, end row. - // Not end row. Move to the next row. - "add r8, r8, #8\n" - // Store new value of row - "str r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" - - "b 21f\n" - "20:\n" - // Was already at end row. - // Move back to first row. - "str r6, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" - // Move to the next column. - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" - "add r4, r4, #4\n" - "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" - - "ldr r8, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" - "ldr r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n" - // Increment dst_col_ptr by 4 * dst_stride (i.e. 
4 columns) - "add r1, r1, r8, lsl #2\n" - // Store dst_col_ptr - "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n" - // Store dst_ptr - "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - "21:\n" - - // Main loop exit condition: have we hit the end column? - "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n" - "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" - "cmp r8, r4\n" - - // r1 is the number of levels of depth that we have already loaded - // LHS and RHS data for. Corresponding to the initial ld1 instructions - // above, this is currently 1. - "mov r1, #1\n" - - "ble 1b\n" - - // Restore stack pointer. - "add sp, sp, #" RUY_STR(RUY_STACK_OFFSET_SIZE) "\n" - - // clang-format on - : [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr) - : [ params ] "r"(¶ms), [dst_tmp_buf] "r"(params.dst_tmp_buf) - // Clobber list must specify q registers (and not their constituent - // d registers). There is a (currently unexplained) slowdown if - // d registers are listed in the clobbers list. - : "r0", "r1", "r2", "r3", "r4", "r5", "r6", "r8", "r10", "cc", - "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", - "q9", "q10", "q12", "q13"); -} - -#undef RUY_MAKE_ZERO -#undef RUY_STACK_OFFSET_SIZE -#undef RUY_STACK_OFFSET_DST_COL_PTR -#undef RUY_STACK_OFFSET_DST_PTR -#undef RUY_STACK_OFFSET_ROW -#undef RUY_STACK_OFFSET_COL -#undef RUY_STACK_OFFSET_LHS_COL_PTR -#undef RUY_STACK_OFFSET_RHS_COL_PTR - -#undef RUY_OFFSET_LHS_BASE_PTR -#undef RUY_OFFSET_RHS_BASE_PTR -#undef RUY_OFFSET_DST_BASE_PTR -#undef RUY_OFFSET_BIAS -#undef RUY_OFFSET_START_ROW -#undef RUY_OFFSET_START_COL -#undef RUY_OFFSET_LAST_ROW -#undef RUY_OFFSET_LAST_COL -#undef RUY_OFFSET_DST_ROWS -#undef RUY_OFFSET_DST_COLS -#undef RUY_OFFSET_LHS_STRIDE -#undef RUY_OFFSET_RHS_STRIDE -#undef RUY_OFFSET_DST_STRIDE -#undef RUY_OFFSET_DEPTH -#undef RUY_OFFSET_CLAMP_MIN -#undef RUY_OFFSET_CLAMP_MAX -#undef RUY_OFFSET_FLAGS - -#define RUY_OFFSET_BIAS 0 -#define RUY_OFFSET_LHS_SUMS 4 -#define RUY_OFFSET_RHS_SUMS 8 -#define RUY_OFFSET_LHS_BASE_PTR 12 -#define RUY_OFFSET_MULTIPLIER_FIXEDPOINT 16 -#define RUY_OFFSET_MULTIPLIER_EXPONENT 20 -#define RUY_OFFSET_RHS_BASE_PTR 24 -#define RUY_OFFSET_DST_BASE_PTR 28 -#define RUY_OFFSET_LHS_ZERO_POINT 32 -#define RUY_OFFSET_RHS_ZERO_POINT 36 -#define RUY_OFFSET_DST_ZERO_POINT 40 -#define RUY_OFFSET_PROD_ZP_DEPTH 44 -#define RUY_OFFSET_START_ROW 48 -#define RUY_OFFSET_START_COL 52 -#define RUY_OFFSET_LAST_ROW 56 -#define RUY_OFFSET_LAST_COL 60 -#define RUY_OFFSET_DST_ROWS 64 -#define RUY_OFFSET_DST_COLS 68 -#define RUY_OFFSET_LHS_STRIDE 72 -#define RUY_OFFSET_RHS_STRIDE 76 -#define RUY_OFFSET_DST_STRIDE 80 -#define RUY_OFFSET_DEPTH 84 -#define RUY_OFFSET_CLAMP_MIN 88 -#define RUY_OFFSET_CLAMP_MAX 92 -#define RUY_OFFSET_FLAGS 96 -#define RUY_OFFSET_DST_TYPE_ID 97 - -#define RUY_STACK_OFFSET_SIZE 96 -#define RUY_STACK_OFFSET_DST_COL_PTR 0 -#define RUY_STACK_OFFSET_DST_PTR 16 -#define RUY_STACK_OFFSET_ROW 32 -#define RUY_STACK_OFFSET_COL 48 -#define RUY_STACK_OFFSET_LHS_COL_PTR 64 -#define RUY_STACK_OFFSET_RHS_COL_PTR 80 - -template -void CheckOffsetsInKernelParams8bit(const Params&) { - static_assert(offsetof(Params, lhs_zero_point) == RUY_OFFSET_LHS_ZERO_POINT, - ""); - static_assert(offsetof(Params, rhs_zero_point) == RUY_OFFSET_RHS_ZERO_POINT, - ""); - static_assert(offsetof(Params, dst_zero_point) == RUY_OFFSET_DST_ZERO_POINT, - ""); - static_assert(offsetof(Params, prod_zp_depth) == RUY_OFFSET_PROD_ZP_DEPTH, - ""); - static_assert(offsetof(Params, multiplier_fixedpoint) 
== - RUY_OFFSET_MULTIPLIER_FIXEDPOINT, - ""); - static_assert( - offsetof(Params, multiplier_exponent) == RUY_OFFSET_MULTIPLIER_EXPONENT, - ""); - static_assert(offsetof(Params, clamp_min) == RUY_OFFSET_CLAMP_MIN, ""); - static_assert(offsetof(Params, clamp_max) == RUY_OFFSET_CLAMP_MAX, ""); - static_assert(offsetof(Params, bias) == RUY_OFFSET_BIAS, ""); - static_assert(offsetof(Params, lhs_sums) == RUY_OFFSET_LHS_SUMS, ""); - static_assert(offsetof(Params, rhs_sums) == RUY_OFFSET_RHS_SUMS, ""); - static_assert(offsetof(Params, flags) == RUY_OFFSET_FLAGS, ""); - static_assert(offsetof(Params, lhs_base_ptr) == RUY_OFFSET_LHS_BASE_PTR, ""); - static_assert(offsetof(Params, start_row) == RUY_OFFSET_START_ROW, ""); - static_assert(offsetof(Params, last_row) == RUY_OFFSET_LAST_ROW, ""); - static_assert(offsetof(Params, last_col) == RUY_OFFSET_LAST_COL, ""); - static_assert(offsetof(Params, lhs_stride) == RUY_OFFSET_LHS_STRIDE, ""); - static_assert(offsetof(Params, rhs_stride) == RUY_OFFSET_RHS_STRIDE, ""); - static_assert(offsetof(Params, dst_stride) == RUY_OFFSET_DST_STRIDE, ""); - static_assert(offsetof(Params, depth) == RUY_OFFSET_DEPTH, ""); -} - -// Fast-int8 kernel, ported from ARM 64 version. -// Relevant target CPUs for this kernel include Krait 400 and A9, -// since these are 32-bit, out-of-order CPUs. -void Kernel8bitNeonOutOfOrder(const KernelParams8bit<4, 2>& params) { - profiler::ScopeLabel label( - "Kernel (kNeon, optimized for out-of-order cores)"); - - CheckOffsetsInKernelParams8bit(params); - - const std::int8_t* lhs_col_ptr = params.lhs_base_ptr; - const std::int8_t* rhs_col_ptr = params.rhs_base_ptr; - const std::int8_t* lhs_ptr = lhs_col_ptr; - const std::int8_t* rhs_ptr = rhs_col_ptr; - - // The asm kernel below has the following NEON register allocation: - // - // q6 - q13 are 128-bit (4x32b) accumulators. - // During accumulation, d0 -- d7 are used to load int8 data from LHS and - // d8 -- d11 from RHS: - // int8 RHS 16x2 block - // /-----------------------------\ - // |d8.b[0-7] ..... d10.b[0-7]| - // | ... ... | - // |d9.b[0-7] ..... d11.b[0-7]| - // \-----------------------------/ - // int8 LHS 4x16 block - // /------------------------\ /-----------------------------\ - // |d0.b[0-7] ... d1.b[0-7] | | q6 ..... q10 | - // |d2.b[0-7] ... d3.b[0-7] | | q7 ..... q11 | - // (Reload d0, d1, d2, d3) - // |d0.b[0-7] ... d1.b[0-7] | | q8 ..... q12 | - // |d2.b[0-7] ... d3.b[0-7] | | q9 ..... q13 | - // \------------------------/ \-----------------------------/ - // 128-bit accumulators 4x2 block - // - // No attempt had been made so far at implementing the RUY_OPT_MAX_STREAMING - // optimization for this kernel. - asm volatile( -#define RUY_MAKE_ZERO(reg) "vmov.i32 " #reg ", #0x00000000\n" - - // clang-format off - - // Load the first 64 bytes of LHS and RHS data. - "vld1.8 {d0, d1}, [%[lhs_ptr]]!\n" - // Clear accumulators. 
- RUY_MAKE_ZERO(q6) - "vld1.8 {d2, d3}, [%[lhs_ptr]]!\n" - RUY_MAKE_ZERO(q8) - RUY_MAKE_ZERO(q9) - RUY_MAKE_ZERO(q10) - "vld1.8 {d8, d9}, [%[rhs_ptr]]!\n" - RUY_MAKE_ZERO(q11) - "vld1.8 {d10, d11}, [%[rhs_ptr]]!\n" - - "sub sp, sp, #" RUY_STR(RUY_STACK_OFFSET_SIZE) "\n" - - "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_BASE_PTR) "]\n" - RUY_MAKE_ZERO(q12) - "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n" - - "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_BASE_PTR) "]\n" - RUY_MAKE_ZERO(q13) - "str r2, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - - "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" - RUY_MAKE_ZERO(q14) - "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" - - "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_START_COL) "]\n" - RUY_MAKE_ZERO(q15) - "str r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" - - "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" - "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" - - "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_RHS_BASE_PTR) "]\n" - "str r2, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n" - - - // r1 is the number of levels of depth that we have already loaded - // LHS and RHS data for. Corresponding to the initial ld1 instructions - // above, this is currently 16. - "mov r1, #16\n" - - // Main loop of the whole GEMM, over rows and columns of the - // destination matrix. - "1:\n" - - // r1 is how many levels of depth we have already loaded - // data for, r10 is the total depth. - "ldr r10, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n" - "cmp r1, r10\n" - "beq 79f\n" - - "2:\n" - - // Mult, mult-acc in to q14, q15, q2, q3 - "vmull.s8 q14, d0, d8\n" - "vmull.s8 q2, d0, d10\n" - - "vmull.s8 q15, d2, d8\n" - "vmull.s8 q3, d2, d10\n" - - "vmlal.s8 q14, d1, d9\n" - "vmlal.s8 q2, d1, d11\n" - "vmlal.s8 q15, d3, d9\n" - "vmlal.s8 q3, d3, d11\n" - "vld1.8 {d0, d1, d2, d3}, [%[lhs_ptr]]!\n" // Reload LHS - - // Then pairwise accumulate in to q6, q7, q10, q11 - "vpadal.s16 q6, q14\n" - "vpadal.s16 q7, q15\n" - "vpadal.s16 q10, q2\n" - "vpadal.s16 q11, q3\n" - - // Mult, mult-acc in to q14, q15, q2, q3 - "vmull.s8 q14, d0, d8\n" - "vmull.s8 q2, d0, d10\n" - - "vmull.s8 q15, d2, d8\n" - "vmull.s8 q3, d2, d10\n" - - "vmlal.s8 q14, d1, d9\n" - "vmlal.s8 q2, d1, d11\n" - "vmlal.s8 q15, d3, d9\n" - "vmlal.s8 q3, d3, d11\n" - "vld1.8 {d0, d1, d2, d3}, [%[lhs_ptr]]!\n" // Reload LHS - - // Then pairwise accumulate in to q8, q9, q12, q13 - "vpadal.s16 q8, q14\n" - "vld1.8 {d8, d9, d10, d11}, [%[rhs_ptr]]!\n" - "vpadal.s16 q9, q15\n" - "vpadal.s16 q12, q2\n" - "vpadal.s16 q13, q3\n" - - // Prefetch the next 64 bytes of LHS and RHS data. - RUY_PREFETCH_LOAD("pld [%[lhs_ptr]]\n") - RUY_PREFETCH_LOAD("pld [%[rhs_ptr]]\n") - - // Each iteration of this loop advances by 16 levels of depth. 
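// Editor's sketch (not part of the original source): what one 16-deep step
// of the inner loop above computes, ignoring the intermediate int16
// widening done by vmull.s8/vmlal.s8 before vpadal.s16. The LHS block is
// 4x16, the RHS block 16x2, and the result accumulates into the 4x2 int32
// block held in q6..q13.
#include <cstdint>

void Int8DepthStepModel(const std::int8_t lhs[4][16],
                        const std::int8_t rhs[16][2],
                        std::int32_t acc[4][2]) {
  for (int r = 0; r < 4; ++r) {
    for (int c = 0; c < 2; ++c) {
      for (int k = 0; k < 16; ++k) {
        acc[r][c] += static_cast<std::int32_t>(lhs[r][k]) *
                     static_cast<std::int32_t>(rhs[k][c]);
      }
    }
  }
}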
- "add r1, r1, #16\n" - - // Loop termination condition - "cmp r1, r10\n" - - "blt 2b\n" - - "79:\n" - - // Mult, mult-acc in to q14, q15, q2, q3 - "vmull.s8 q14, d0, d8\n" - "vmull.s8 q2, d0, d10\n" - - "vmull.s8 q15, d2, d8\n" - "vmull.s8 q3, d2, d10\n" - - "vmlal.s8 q14, d1, d9\n" - "vmlal.s8 q2, d1, d11\n" - "vmlal.s8 q15, d3, d9\n" - "vmlal.s8 q3, d3, d11\n" - "vld1.8 {d0, d1, d2, d3}, [%[lhs_ptr]]!\n" // Reload LHS - - // Then pairwise accumulate in to q6, q7, q10, q11 - "vpadal.s16 q6, q14\n" - "vpadal.s16 q7, q15\n" - "vpadal.s16 q10, q2\n" - "vpadal.s16 q11, q3\n" - - // Mult, mult-acc in to q14, q15, q2, q3 - "vmull.s8 q14, d0, d8\n" - "vmull.s8 q2, d0, d10\n" - - "vmull.s8 q15, d2, d8\n" - "vmull.s8 q3, d2, d10\n" - - "vmlal.s8 q14, d1, d9\n" - "vmlal.s8 q2, d1, d11\n" - "vmlal.s8 q15, d3, d9\n" - "vmlal.s8 q3, d3, d11\n" - - // Then pairwise accumulate in to q8, q9, q12, q13 - "vpadal.s16 q8, q14\n" - "vpadal.s16 q9, q15\n" - "vpadal.s16 q12, q2\n" - "vpadal.s16 q13, q3\n" - - - // All accumulation over depth done. q6 - q13 contain the 4x32b - // accumulators for the 4x2 final matrix. - // We now have to compute the final 8-bit values from these int32 - // accumulators, and advance to the next 4x2 block. We intertwine - // these two aspects whenever possible for optimal pipelining, both - // at the data flow level (prefetch data for next block as early as - // possible) and instruction pipelining level (some of the next-block - // work can dual-issue with some of the final work on the current - // block). - - // q6-q13 now contain 4 x 32b - "vpadd.i32 d0, d12, d13\n" - "vpadd.i32 d1, d14, d15\n" - "vpadd.i32 d2, d16, d17\n" - "vpadd.i32 d3, d18, d19\n" - "vpadd.i32 d4, d20, d21\n" - "vpadd.i32 d5, d22, d23\n" - "vpadd.i32 d6, d24, d25\n" - "vpadd.i32 d7, d26, d27\n" - - // d0-d7 each contain 2 x 32b accumulators. - // Need to add pairwise to get 1 x 32b for each of the 4x2 entries - // of destination, (Four 'd' registers total) - "vpadd.i32 d28, d0, d1\n" - "vpadd.i32 d29, d2, d3\n" - "vpadd.i32 d30, d4, d5\n" - "vpadd.i32 d31, d6, d7\n" - - //Now d28 - d31 have the 1 x 32b accumulators for the 4x2 entries - - // Logic to advance to the next block in preparation for the next - // iteration of the main loop. For now, we only want to compute - // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are - // not yet ready to update the values of row and col, as we still need - // the current values for the rest of the work on the current block. - - "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" - "ldr r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" - "cmp r1, r3\n" // Have we finished the last row? - - "bge 4f\n" // If finished last row, go to 4 - // Not finished last row: then advance to next row. - "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n" - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" - "add r4, r4, r1, lsl #2\n" - "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" - "b 5f\n" - "4:\n" // Finished last row... - "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" - // Go back to first row - "str r5, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" - - // Now we need to advance to the next column. If we already - // finished the last column, then in principle we are done, however - // we can't just return here, as we need to allow the end work of the - // current block to complete. 
The good news is that at this point it - // doesn't matter what data we load for the next column, since - // we will exit from the main loop below before actually storing - // anything computed from that data. - - "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n" - "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" - "cmp r8, r4\n" // Have we finished the last column? - "bge 5f\n" // If yes, just carry on without updating the column pointer. - // Not finished last column: then advance to next column. - "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n" - "ldr r10, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n" - "add r10, r10, r1, lsl #1\n" - "str r10, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n" - "5:\n" - - // Set the LHS and RHS data pointers to the start of the columns just - // computed. - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" - "mov %[lhs_ptr], r4\n" - "ldr r5, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n" - "mov %[rhs_ptr], r5\n" - - // Now we load: bias data, LHS sums data, RHS sums data. - - // First, load the base pointers from the params. - "ldrb r4, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" - "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n" - - // Offset these base pointers as needed given the current row, col. - "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" - "add r5, r1, r8, lsl #2\n" - - "tst r4, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n" - "it ne\n" - "movne r1, r5\n" - - // Load 4 bias values. - "vld1.32 {d24, d25}, [r1]\n" - - // Now that we know what LHS and RHS data the next iteration of the - // main loop will need to load, we start loading the first 32 bytes of - // each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore - // in the rest of the work on the current block. - "vld1.8 {d0, d1, d2, d3}, [%[lhs_ptr]]!\n" - RUY_PREFETCH_LOAD("pld [%[lhs_ptr]]\n") - "vld1.8 {d8, d9, d10, d11}, [%[rhs_ptr]]!\n" - RUY_PREFETCH_LOAD("pld [%[rhs_ptr]]\n") - - // Add to the bias values the product - // (depth * lhs_zero_point * rhs_zero_point), - // See the term NZ1Z2 in equation (7) in - // https://arxiv.org/pdf/1712.05877.pdf - "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_PROD_ZP_DEPTH) "]\n" - "vdup.32 q9, r3\n" - "vadd.i32 q12, q12, q9\n" - - // Perform the bias-addition (per the above, we have just folded into - // the bias the (depth * lhs_zero_point * rhs_zero_point) term.) 
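// Editor's sketch (not part of the original source): the zero-point
// corrections that the following instructions apply, written as scalar
// arithmetic. With LHS/RHS zero points, per-row bias, per-row LHS sums and
// per-column RHS sums, this is equation (7) of
// https://arxiv.org/pdf/1712.05877.pdf, as cited in the comments here.
#include <cstdint>

std::int32_t CorrectAccumulator(std::int32_t acc, std::int32_t bias,
                                std::int32_t depth,
                                std::int32_t lhs_zero_point,
                                std::int32_t rhs_zero_point,
                                std::int32_t lhs_sum_for_row,
                                std::int32_t rhs_sum_for_col) {
  return acc + bias + depth * lhs_zero_point * rhs_zero_point -
         lhs_zero_point * rhs_sum_for_col -
         rhs_zero_point * lhs_sum_for_row;
}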
- "vadd.i32 q14, q14, q12\n" - "vadd.i32 q15, q15, q12\n" - - // LHS/RHS zero points - // Has RHS sums - "ldrb r6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" - "tst r6, #" RUY_STR(RUY_ASM_FLAG_HAS_RHS_SUMS) "\n" - "beq 401f\n" - "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_RHS_SUMS) "]\n" - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" - // Offset by current col * number of bytes per value - "add r3, r3, r4, lsl #2\n" - "vld1.32 { d12 }, [r3]\n" - "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_ZERO_POINT) "]\n" - "vdup.32 q10, r5\n" // create lhs_zero_point_vec - // Subtract rhs_sums * lhs_zero_point, per - // equation (7) in https://arxiv.org/pdf/1712.05877.pdf - "vmls.i32 q14, q10, d12[0]\n" - "vmls.i32 q15, q10, d12[1]\n" - "401:\n" - - // Has LHS sums - "ldrb r6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" - "tst r6, #" RUY_STR(RUY_ASM_FLAG_HAS_LHS_SUMS) "\n" - "beq 402f\n" - "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_LHS_SUMS) "]\n" - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" - // Offset by current row * number of bytes per value - "add r2, r2, r4, lsl #2\n" - "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_RHS_ZERO_POINT) "]\n" - - // Load 4 lhs_sums values. - "vld1.32 {d22, d23}, [r2]\n" - "vdup.32 d13, r5\n" // rhs_zero_point - - // Compute lhs_sums * rhs_zero_point. - "vmul.i32 q11, q11, d13[1]\n" - // Subtract lhs_sums * rhs_zero_point, per - // equation (7) in https://arxiv.org/pdf/1712.05877.pdf - "vsub.s32 q14, q14, q11\n" - "vsub.s32 q15, q15, q11\n" - - // If the destination is int32, it means the user asks for the raw - // accumulators, no need for us to downquantize the value. - "ldrb r10, [%[params], #" RUY_STR(RUY_OFFSET_DST_TYPE_ID) "]\n" - "cmp r10, #" RUY_STR(RUY_ASM_TYPE_ID_INT32) "\n" - "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT32) "f\n" - - "402:\n" - - // At this point we have computed the final int32 values. Now we - // start down-quantizing them to obtain the final 8bit values from them. - - // As part of this down-quantization, our int32 values will be - // multiplied by a multiplier that has a fixed-point component and an - // exponent component. - - //Load the exponent part of the multiplier. - "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_EXPONENT) "]\n" - "tst r6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n" - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" - "add r5, r1, r4, lsl #2\n" - "it ne\n" - "movne r1, r5\n" - - "vld1.32 {q10}, [r1]\n" - - RUY_MAKE_ZERO(q8) - "vmax.s32 q12, q10, q8\n" - - "vshl.s32 q14, q14, q12\n" - "vshl.s32 q15, q15, q12\n" - - "vmin.s32 q12, q10, q8\n" - - // Load fixed point part of the multiplier - "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n" - // r6 has flags, r4 has row - "add r5, r1, r4, lsl #2\n" - "tst r6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n" - "it ne\n" - "movne r1, r5\n" - "vld1.32 {q10}, [r1]\n" // multiplier_fixedpoint - - // Apply the fixed-point part of the multiplier. - "vqrdmulh.s32 q14, q14, q10\n" - "vqrdmulh.s32 q15, q15, q10\n" - - // We have some rounding division-by-power-of-two to do. This should - // always use "round to nearest". We allow for some - // freedom in how ties are broken, to strike a good compromise of - // performance on given hardware vs. perfect agreement of results - // across hardware. - // - // When RUY_OPT_NATIVE_ROUNDING is enabled, we allow for implementation - // defined tie-breaks to help performance. On NEON, this means that we - // can just use the NEON rounding instructions, such as srshl. 
They - // happen to be breaking ties upward. - // - // When RUY_OPT_NATIVE_ROUNDING is disabled, we implement strict - // break-ties-away-from zero, as described in Appendix B of - // https://arxiv.org/pdf/1712.05877.pdf - // When we wrote that, we thought that that would be better unbiased - // than the NEON upwards tie-breaks, and we had observed some - // improvement on some model. However, that is only more unbiased for - // data centered at zero, which was likely the case in that model, - // but is not always the case. If we wanted something more consistently - // unbiased then we should try breaking ties toward-nearest-even. -#if !RUY_OPT_ENABLED(RUY_OPT_NATIVE_ROUNDING) - // Fix up values to be right-shifted, so that the (round to nearest, - // break ties upward) behavior of srshl applied to these fixed-up - // values, produces the same result as the desired (round to nearest, - // break ties away from zero) behavior on the original values. - "vand q8, q14, q12\n" - "vand q9, q15, q12\n" - "vshr.s32 q8, q8, #31\n" - "vshr.s32 q9, q9, #31\n" - "vqadd.s32 q14, q14, q8\n" - "vqadd.s34 q15, q15, q9\n" - -#endif - // At this point we have reduced the problem of correctly implementing - // rounding divide-by-power-of-two, to what the SRSHL instruction can - // do. - "vrshl.s32 q14, q14, q12\n" - "vrshl.s32 q15, q15, q12\n" - - "ldrb r10, [%[params], #" RUY_STR(RUY_OFFSET_DST_TYPE_ID) "]\n" - "cmp r10, #" RUY_STR(RUY_ASM_TYPE_ID_INT16) "\n" - "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT16) "f\n" - "cmp r10, #" RUY_STR(RUY_ASM_TYPE_ID_INT8) "\n" - "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT8) "f\n" - - // Store uint8 values: - RUY_STR(RUY_ASM_LABEL_STORE_UINT8) ":\n" - - // Cast-and-saturate from int32 to int16 - // After this, all values for output are in q14. - "vqmovn.s32 d28, q14\n" - "vqmovn.s32 d29, q15\n" - - // At this point, d12 -- d26, d30, d31 aren't used anymore for the - // current block, so we can start clearing these accumulators for the - // next block (next iteration of the main loop). - RUY_MAKE_ZERO(q6) - RUY_MAKE_ZERO(q7) - RUY_MAKE_ZERO(q8) - RUY_MAKE_ZERO(q9) - RUY_MAKE_ZERO(q10) - RUY_MAKE_ZERO(q11) - RUY_MAKE_ZERO(q12) - RUY_MAKE_ZERO(q13) - RUY_MAKE_ZERO(q15) - - // Load the destination zero point into each of the 8 16-bit slots - // in a q register. - "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n" - "vdup.16 q13, r4\n" // dst_zero_point - - // Add the destination zero point - "vadd.i16 q14, q14, q13\n" - - // Cast-and-saturate from int16 to uint8 - // Now all 8 1-byte values are in d30. - "vqmovun.s16 d30, q14\n" - - // Load the clamp_min, clamp_max bounds - "ldrb r2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" - "ldrb r3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" - "vdup.8 d28, r2\n" // clamp_min - "vdup.8 d29, r3\n" // clamp_max - - // Apply the clamp_min bound - "vmax.u8 d30, d30, d28\n" - // Apply the clamp_max bound - "vmin.u8 d30, d30, d29\n" - - // Compute how much of the 4x2 block of destination 8bit values that - // we have computed, fit in the destination matrix. Typically, all of - // it fits, but when the destination matrix shape is not a multiple - // of 4x2, there are some 4x2 blocks along the boundaries that do - // not fit entirely. 
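// Editor's note (illustrative, not in the original source), expanding on
// the rounding discussion a little earlier: a scalar model of NEON's
// rounding right shift (vrshl with a negative shift amount), which rounds
// to nearest and breaks ties upward. For x = -3 and a shift of 1,
// ties-upward yields -1 while break-ties-away-from-zero would yield -2;
// for x = 5 both rules yield 3. The vand/vshr/vqadd fix-up above converts
// the former behavior into the latter when RUY_OPT_NATIVE_ROUNDING is
// disabled.
#include <cstdint>

std::int32_t RoundingShiftRightTiesUpward(std::int32_t x, int shift) {
  if (shift <= 0) return x;  // only right shifts are modeled here
  return (x + (std::int32_t{1} << (shift - 1))) >> shift;
}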
- - "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n" - "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" - "sub r1, r1, r8\n" - - "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n" - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" - "sub r2, r2, r4\n" - "mov r3, #4\n" - "mov r5, #2\n" - "cmp r1, #4\n" - // Compute r1 = how many rows of the 4x2 block fit - "it gt\n" - "movgt r1, r3\n" - - "cmp r2, #2\n" - // Compute r2 = how many cols of the 4x2 block fit - "it gt\n" - "movgt r2, r5\n" - - // Test if r1==4 && r2 == 2, i.e. if all of the 4x2 block fits. - "cmp r1, r3\n" - "it eq\n" - "cmpeq r2, r5\n" - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" - // Yes, all of the 4x2 block fits, go to fast path. - "beq 30f\n" - // Not all of the 4x2 block fits. - // Store to dst_tmp_buf - // Set r3 address to write to dst_tmp_buf. - "mov r3, %[dst_tmp_buf]\n" - "vst1.8 {d30}, [r3]\n" - - // Slow loop copying from dst_tmp_buf to dst. - "mov r6, #0\n" - "50:\n" - "mov r8, #0\n" - "51:\n" - "ldrb r10, [r3, r8]\n" - "strb r10, [r4, r8]\n" - "add r8, r8, #1\n" - "cmp r8, r1\n" - "blt 51b\n" - "add r6, r6, #1\n" - "add r3, r3, #4\n" - "add r4, r4, r5\n" - "cmp r6, r2\n" - "blt 50b\n" - "b 31f\n" - "30:\n" - // Yes, all of the 4x2 block fits. - // r3 address, r5 stride - "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - "mov r4, r3\n" - "mov r6, #1\n" - - "vst1.32 {d30[0]}, [r3]\n" - "add r4, r4, r5\n" - "mov r3, r4\n" - "vst1.32 {d30[1]}, [r3]\n" - - "31:\n" - - // Load dst_ptr, increment, and write back. - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - "add r4, r4, #4\n" - "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - - RUY_MAKE_ZERO(q13) - RUY_MAKE_ZERO(q14) - RUY_MAKE_ZERO(q15) - - "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" - - // Store int8 values: - RUY_STR(RUY_ASM_LABEL_STORE_INT8) ":\n" - - // Cast-and-saturate from int32 to int16 - // After this, all values for output are in q14. - "vqmovn.s32 d28, q14\n" - "vqmovn.s32 d29, q15\n" - - // At this point, d12 -- d26, d30, d31 aren't used anymore for the - // current block, so we can start clearing these accumulators for the - // next block (next iteration of the main loop). - RUY_MAKE_ZERO(q6) - RUY_MAKE_ZERO(q7) - RUY_MAKE_ZERO(q8) - RUY_MAKE_ZERO(q9) - RUY_MAKE_ZERO(q10) - RUY_MAKE_ZERO(q11) - RUY_MAKE_ZERO(q12) - RUY_MAKE_ZERO(q13) - RUY_MAKE_ZERO(q15) - - // Load the destination zero point into each of the 8 16-bit slots - // in a q register. - "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n" - "vdup.16 q13, r4\n" // dst_zero_point - - // Add the destination zero point - "vadd.i16 q14, q14, q13\n" - - // Cast-and-saturate from int16 to int8 - // Now all 8 1-byte values are in d30. - "vqmovn.s16 d30, q14\n" - - // Load the clamp_min, clamp_max bounds - "ldrb r2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" - "ldrb r3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" - "vdup.8 d28, r2\n" // clamp_min - "vdup.8 d29, r3\n" // clamp_max - - // Apply the clamp_min bound - "vmax.s8 d30, d30, d28\n" - // Apply the clamp_max bound - "vmin.s8 d30, d30, d29\n" - - // Compute how much of the 4x2 block of destination 8bit values that - // we have computed, fit in the destination matrix. Typically, all of - // it fits, but when the destination matrix shape is not a multiple - // of 4x2, there are some 4x2 blocks along the boundaries that do - // not fit entirely. 
- - "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n" - "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" - "sub r1, r1, r8\n" - - "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n" - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" - "sub r2, r2, r4\n" - "mov r3, #4\n" - "mov r5, #2\n" - "cmp r1, #4\n" - // Compute r1 = how many rows of the 4x2 block fit - "it gt\n" - "movgt r1, r3\n" - - "cmp r2, #2\n" - // Compute r2 = how many cols of the 4x2 block fit - "it gt\n" - "movgt r2, r5\n" - - // Test if r1==4 && r2 == 2, i.e. if all of the 4x2 block fits. - "cmp r1, r3\n" - "it eq\n" - "cmpeq r2, r5\n" - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" - // Yes, all of the 4x2 block fits, go to fast path. - "beq 30f\n" - // Not all of the 4x2 block fits. - // Store to dst_tmp_buf - // Set r3 address to write to dst_tmp_buf. - "mov r3, %[dst_tmp_buf]\n" - "vst1.8 {d30}, [r3]\n" - - // Slow loop copying from dst_tmp_buf to dst. - "mov r6, #0\n" - "50:\n" - "mov r8, #0\n" - "51:\n" - "ldrb r10, [r3, r8]\n" - "strb r10, [r4, r8]\n" - "add r8, r8, #1\n" - "cmp r8, r1\n" - "blt 51b\n" - "add r6, r6, #1\n" - "add r3, r3, #4\n" - "add r4, r4, r5\n" - "cmp r6, r2\n" - "blt 50b\n" - "b 31f\n" - "30:\n" - // Yes, all of the 4x2 block fits. - // r3 address, r5 stride - "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - "mov r4, r3\n" - "mov r6, #1\n" - - "vst1.32 {d30[0]}, [r3]\n" - "add r4, r4, r5\n" - "mov r3, r4\n" - "vst1.32 {d30[1]}, [r3]\n" - - "31:\n" - - // Load dst_ptr, increment, and write back. - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - "add r4, r4, #4\n" - "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - - RUY_MAKE_ZERO(q13) - RUY_MAKE_ZERO(q14) - RUY_MAKE_ZERO(q15) - - "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" - - RUY_STR(RUY_ASM_LABEL_STORE_INT16) ":\n" - - // Load the destination zero point into each of the 4 32-bit slots - // in a q register. - "ldrsh r4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n" - "vdup.32 q13, r4\n" // dst_zero_point - // Add the destination zero point - "vadd.s32 q14, q14, q13\n" - "vadd.s32 q15, q15, q13\n" - - // Cast-and-saturate from int32 to int16 - // After this, all values for output are in q14. - "vqmovn.s32 d28, q14\n" - "vqmovn.s32 d29, q15\n" - - // At this point, v18 -- v31 aren't used anymore for the current block, - // so we can start clearing these accumulators for the next block - // (next iteration of the main loop). - RUY_MAKE_ZERO(q6) - RUY_MAKE_ZERO(q7) - RUY_MAKE_ZERO(q8) - RUY_MAKE_ZERO(q9) - RUY_MAKE_ZERO(q10) - RUY_MAKE_ZERO(q11) - RUY_MAKE_ZERO(q15) - - // Load the clamp_min, clamp_max bounds - "ldrh r2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" - "ldrh r3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" - "vdup.16 q12, r2\n" // clamp_min - "vdup.16 q13, r3\n" // clamp_max - - // Apply the clamp_min bound - "vmax.s16 q14, q14, q12\n" - // Apply the clamp_max bound - "vmin.s16 q14, q14, q13\n" - - RUY_MAKE_ZERO(q12) - RUY_MAKE_ZERO(q13) - - // Compute how much of the 4x2 block of destination 16-bit values that - // we have computed, fit in the destination matrix. Typically, all of - // it fits, but when the destination matrix shape is not a multiple - // of 4x2, there are some 4x2 blocks along the boundaries that do - // not fit entirely. 
- - "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n" - "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" - "sub r1, r1, r8\n" - - "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n" - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" - "sub r2, r2, r4\n" - "mov r3, #4\n" - "mov r5, #2\n" - "cmp r1, #4\n" - // Compute r1 = how many rows of the 4x2 block fit - "it gt\n" - "movgt r1, r3\n" - - "cmp r2, #2\n" - // Compute r2 = how many cols of the 4x2 block fit - "it gt\n" - "movgt r2, r5\n" - - // Test if r1==4 && r2 == 2, i.e. if all of the 4x2 block fits. - "cmp r1, r3\n" - "it eq\n" - "cmpeq r2, r5\n" - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" - // Yes, all of the 4x2 block fits, go to fast path. - "beq 30f\n" - // Not all of the 4x2 block fits. - // Store to dst_tmp_buf - // Set r3 address to write to dst_tmp_buf. - "mov r3, %[dst_tmp_buf]\n" - "vst1.16 {q14}, [r3]\n" - - // Slow loop copying from dst_tmp_buf to dst. - "mov r6, #0\n" - "50:\n" - "mov r8, #0\n" - "51:\n" - // Shift of offset register for half-word loads not allowed in A32, - // so we shift, load/store, then shift back r8. - "lsl r8, r8, #1\n" - "ldrh r10, [r3, r8]\n" - "strh r10, [r4, r8]\n" - "lsr r8, r8, #1\n" - "add r8, r8, #1\n" - "cmp r8, r1\n" - "blt 51b\n" - "add r6, r6, #1\n" - "add r3, r3, #8\n" - "add r4, r4, r5\n" - "cmp r6, r2\n" - "blt 50b\n" - "b 31f\n" - "30:\n" - // Yes, all of the 4x2 block fits. - // r3 address, r5 stride - "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - "mov r4, r3\n" - "mov r6, #2\n" - - "vst1.16 {d28[0]}, [r3], r6\n" - "add r4, r4, r5\n" - "vst1.16 {d28[1]}, [r3], r6\n" - "vst1.16 {d28[2]}, [r3], r6\n" - "vst1.16 {d28[3]}, [r3], r6\n" - "mov r3, r4\n" - "vst1.16 {d29[0]}, [r3], r6\n" - "vst1.16 {d29[1]}, [r3], r6\n" - "vst1.16 {d29[2]}, [r3], r6\n" - "vst1.16 {d29[3]}, [r3], r6\n" - "31:\n" - - // Load dst_ptr, increment, and write back. - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - "add r4, r4, #8\n" - "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - - RUY_MAKE_ZERO(q14) - - "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" - - RUY_STR(RUY_ASM_LABEL_STORE_INT32) ":\n" - - // Since the store type is the same as the accum type, no need for - // downcast. There's also no need for clamp by min/max. - - // At this point, v20 -- v31 aren't used anymore for the current block, - // so we can start clearing these accumulators for the next block - // (next iteration of the main loop). - // Clear accumulators. - RUY_MAKE_ZERO(q6) - RUY_MAKE_ZERO(q7) - RUY_MAKE_ZERO(q8) - RUY_MAKE_ZERO(q9) - RUY_MAKE_ZERO(q10) - RUY_MAKE_ZERO(q11) - RUY_MAKE_ZERO(q12) - RUY_MAKE_ZERO(q13) - - // Compute how much of the 4x2 block of destination 32 bit values that - // we have computed, fit in the destination matrix. Typically, all of - // it fits, but when the destination matrix shape is not a multiple - // of 4x2, there are some 4x4 blocks along the boundaries that do - // not fit entirely. 
- - "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n" - "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" - "sub r1, r1, r8\n" - - "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n" - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" - "sub r2, r2, r4\n" - "mov r3, #4\n" - "mov r5, #2\n" - "cmp r1, #4\n" - // Compute r1 = how many rows of the 4x2 block fit - "it gt\n" - "movgt r1, r3\n" - - "cmp r2, #2\n" - // Compute r2 = how many cols of the 4x2 block fit - "it gt\n" - "movgt r2, r5\n" - - // Test if r1==4 && r2 == 2, i.e. if all of the 4x2 block fits. - "cmp r1, r3\n" - "it eq\n" - "cmpeq r2, r5\n" - // Yes, all of the 4x2 block fits, go to fast path. - "beq 30f\n" - // Not all of the 4x2 block fits. - // Set (r3 address, r4 stride) to write to dst_tmp_buf - "mov r3, %[dst_tmp_buf]\n" - "mov r4, #16\n" - "b 31f\n" - - "30:\n" - // Yes, all of the 4x2 block fits. - "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" - // r3 address, r4 stride - "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - "mov r4, r5\n" - - "31:\n" - - "vst1.32 {d28, d29}, [r3]\n" - "add r3, r3, r4\n" - "vst1.32 {d30, d31}, [r3]\n" - - // If all of the 4x2 block fits, we just finished writing it to the - // destination, so we skip the next part. - "beq 41f\n" - // Not all of the 4x2 block fits in the destination matrix. We just - // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over - // it to copy into the destination matrix the part that fits. - "ldr r8, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" - "mov r3, %[dst_tmp_buf]\n" - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - "mov r6, #0\n" - "50:\n" - "mov r5, #0\n" - "51:\n" - "ldr r10, [r3, r5, lsl #2]\n" - "str r10, [r4, r5, lsl #2]\n" - "add r5, r5, #1\n" - "cmp r5, r1\n" - "blt 51b\n" - "add r6, r6, #1\n" - "add r3, r3, #16\n" - "add r4, r4, r8\n" - // r2 = how many cols of the 8x4 block fit - "cmp r6, r2\n" - "blt 50b\n" - - "41:\n" - // Load dst_ptr, increment, and write back. - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - "add r4, r4, #16\n" - "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - - RUY_MAKE_ZERO(q10) - RUY_MAKE_ZERO(q11) - - "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" - - RUY_STR(RUY_ASM_LABEL_AFTER_STORE) ":\n" - - // Reload some params --- we had used x5 -- x7 for a few other things - // since the last time we had loaded them. - "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" - "ldr r6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" - "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" - - // Move to the next block of the destination matrix, for the next iter - // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already - // been updated earlier. - // Have we reached the end row? - "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" - "cmp r8, r3\n" - - "beq 20f\n" // yes, end row. - // Not end row. Move to the next row. - "add r8, r8, #4\n" - // Store new value of row - "str r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" - - "b 21f\n" - "20:\n" - // Was already at end row. - // Move back to first row. - "str r6, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" - // Move to the next column. - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" - "add r4, r4, #2\n" - "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" - - "ldr r8, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" - "ldr r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n" - // Increment dst_col_ptr by 2 * dst_stride (i.e. 
2 columns)
- "add r1, r1, r8, lsl #1\n"
- // Store dst_col_ptr
- "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n"
- // Store dst_ptr
- "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
- "21:\n"
-
- // Main loop exit condition: have we hit the end column?
- "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n"
- "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
- "cmp r8, r4\n"
-
- // r1 is the number of levels of depth that we have already loaded
- // LHS and RHS data for. Corresponding to the initial vld1 instructions
- // above, this is currently 16.
- "mov r1, #16\n"
-
- "ble 1b\n"
-
- // Restore stack pointer.
- "add sp, sp, #" RUY_STR(RUY_STACK_OFFSET_SIZE) "\n"
-
- // clang-format on
-
- : [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr)
- : [ params ] "r"(&params), [dst_tmp_buf] "r"(params.dst_tmp_buf)
- : "r0", "r1", "r2", "r3", "r4", "r5", "r6", "r8", "r10", "cc",
- // Clobber list must specify q registers (and not their constituent
- // d registers). There is a (currently unexplained) slowdown if
- // d registers are listed in the clobbers list.
- "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8",
- "q9", "q10", "q12", "q13", "q14", "q15");
-}
-
-// Fast-int8 true "GEMV" kernel (RHS has 1 column). We assume the RHS
-// is still packed as if it has two columns.
-void Kernel8bitNeonOutOfOrder1Col(const KernelParams8bit<4, 2>& params) {
- profiler::ScopeLabel label(
- "Kernel (kNeon, optimized for out-of-order cores)");
-
- CheckOffsetsInKernelParams8bit(params);
-
- const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
- const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
- const std::int8_t* lhs_ptr = lhs_col_ptr;
- const std::int8_t* rhs_ptr = rhs_col_ptr;
-
- // The asm kernel below has the following NEON register allocation:
- //
- // q6 -- q9 are 128-bit (4x32b) accumulators.
- // During accumulation, d0 -- d7 are used to load int8 data from LHS and
- // d8 -- d9 from RHS:
- // int8 RHS 16x1 block
- // /------------\
- // | d8.b[0] |
- // | ... |
- // | d8.b[7] |
- // | d9.b[0] |
- // | ... |
- // | d9.b[7] |
- // \------------/
- // int8 LHS 4x16 block
- // /-----------------------------------------\ /------------\
- // |d0.b[0] ... d0.b[7] d1.b[0] ... d1.b[7] | | q6 |
- // |d2.b[0] ... d2.b[7] d3.b[0] ... d3.b[7] | | q7 |
- // |d4.b[0] ... d4.b[7] d5.b[0] ... d5.b[7] | | q8 |
- // |d6.b[0] ... d6.b[7] d7.b[0] ... d7.b[7] | | q9 |
- // \-----------------------------------------/ \------------/
- // 128-bit accumulators 4x1 block
- //
- // No attempt has been made so far at implementing the RUY_OPT_MAX_STREAMING
- // optimization for this kernel.
- asm volatile(
-#define RUY_MAKE_ZERO(reg) "vmov.i32 " #reg ", #0x00000000\n"
-
- // clang-format off
-
- // Load the first 64 bytes of LHS data and the first 16 bytes of RHS data.
- "vld1.8 {d0, d1}, [%[lhs_ptr]]!\n"
- "vld1.8 {d2, d3}, [%[lhs_ptr]]!\n"
- "vld1.8 {d4, d5}, [%[lhs_ptr]]!\n"
- "vld1.8 {d6, d7}, [%[lhs_ptr]]!\n"
- "vld1.8 {d8, d9}, [%[rhs_ptr]]!\n"
- // Skip the other column and advance the pointer.
- "add %[rhs_ptr], %[rhs_ptr], #16\n" - - "sub sp, sp, #" RUY_STR(RUY_STACK_OFFSET_SIZE) "\n" - - "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_BASE_PTR) "]\n" - "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n" - - "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_BASE_PTR) "]\n" - "str r2, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - - "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" - "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" - - "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_START_COL) "]\n" - "str r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" - - "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" - "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" - - "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_RHS_BASE_PTR) "]\n" - "str r2, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n" - - // Clear accumulators. - RUY_MAKE_ZERO(q6) - RUY_MAKE_ZERO(q7) - RUY_MAKE_ZERO(q8) - RUY_MAKE_ZERO(q9) - RUY_MAKE_ZERO(q10) - RUY_MAKE_ZERO(q11) - RUY_MAKE_ZERO(q12) - RUY_MAKE_ZERO(q13) - RUY_MAKE_ZERO(q14) - RUY_MAKE_ZERO(q15) - - // r1 is the number of levels of depth that we have already loaded - // LHS and RHS data for. Corresponding to the initial ld1 instructions - // above, this is currently 16. - "mov r1, #16\n" - - // Main loop of the whole GEMM, over rows and columns of the - // destination matrix. - "1:\n" - - // r1 is how many levels of depth we have already loaded - // data for, r10 is the total depth. - "ldr r10, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n" - "cmp r1, r10\n" - "beq 79f\n" - - "2:\n" - - // Mult, mult-acc in to q14, q15 - "vmull.s8 q14, d0, d8\n" - "vmull.s8 q15, d2, d8\n" - "vmlal.s8 q14, d1, d9\n" - "vmlal.s8 q15, d3, d9\n" - - // Then pairwise accumulate in to q6, q7 - "vpadal.s16 q6, q14\n" - "vpadal.s16 q7, q15\n" - - // Mult, mult-acc in to q14, q15 - "vmull.s8 q14, d4, d8\n" - "vmull.s8 q15, d6, d8\n" - "vmlal.s8 q14, d5, d9\n" - "vmlal.s8 q15, d7, d9\n" - - // Then pairwise accumulate in to q8, q9 - "vpadal.s16 q8, q14\n" - "vpadal.s16 q9, q15\n" - - - // Load the next 64 bytes of LHS and RHS data. - "vld1.8 {d0, d1}, [%[lhs_ptr]]!\n" - "vld1.8 {d2, d3}, [%[lhs_ptr]]!\n" - "vld1.8 {d4, d5}, [%[lhs_ptr]]!\n" - "vld1.8 {d6, d7}, [%[lhs_ptr]]!\n" - RUY_PREFETCH_LOAD("pld [%[lhs_ptr]]\n") - "vld1.8 {d8, d9}, [%[rhs_ptr]]!\n" - // Skip the other column and advance the pointer. - "add %[rhs_ptr], %[rhs_ptr], #16\n" - RUY_PREFETCH_LOAD("pld [%[rhs_ptr]]\n") - - // Each iteration of this loop advances by 16 levels of depth. - "add r1, r1, #16\n" - - // Loop termination condition - "cmp r1, r10\n" - - "blt 2b\n" - - "79:\n" - - // Mult, mult-acc in to q14, q15 - "vmull.s8 q14, d0, d8\n" - "vmull.s8 q15, d2, d8\n" - "vmlal.s8 q14, d1, d9\n" - "vmlal.s8 q15, d3, d9\n" - - // Then pairwise accumulate in to q6, q7 - "vpadal.s16 q6, q14\n" - "vpadal.s16 q7, q15\n" - - // Mult, mult-acc in to q14, q15 - "vmull.s8 q14, d4, d8\n" - "vmull.s8 q15, d6, d8\n" - "vmlal.s8 q14, d5, d9\n" - "vmlal.s8 q15, d7, d9\n" - - // Then pairwise accumulate in to q8, q9 - "vpadal.s16 q8, q14\n" - "vpadal.s16 q9, q15\n" - - // All accumulation over depth done. q6 - q9 contain the 4x32b - // accumulators for the 4x1 final matrix. - // We now have to compute the final 8-bit values from these int32 - // accumulators, and advance to the next 4x2 block. 
We intertwine - // these two aspects whenever possible for optimal pipelining, both - // at the data flow level (prefetch data for next block as early as - // possible) and instruction pipelining level (some of the next-block - // work can dual-issue with some of the final work on the current - // block). - - // q6-q9 now contain 4 x 32b - "vpadd.i32 d0, d12, d13\n" - "vpadd.i32 d1, d14, d15\n" - "vpadd.i32 d2, d16, d17\n" - "vpadd.i32 d3, d18, d19\n" - - // d0-d4 each contain 2 x 32b accumulators. - // Need to add pairwise to get 1 x 32b for each of the 4x1 entries - // of destination, (Four 'd' registers total) - "vpadd.i32 d28, d0, d1\n" - "vpadd.i32 d29, d2, d3\n" - - // Now d28,d29 have the 1 x 32b accumulators for the 4x1 entries. - - // Logic to advance to the next block in preparation for the next - // iteration of the main loop. For now, we only want to compute - // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are - // not yet ready to update the values of row and col, as we still need - // the current values for the rest of the work on the current block. - - "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" - "ldr r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" - "cmp r1, r3\n" // Have we finished the last row? - - "bge 4f\n" // If finished last row, go to 4 - // Not finished last row: then advance to next row. - "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n" - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" - "add r4, r4, r1, lsl #2\n" - "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" - "b 5f\n" - "4:\n" // Finished last row... - "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" - // Go back to first row - "str r5, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" - - // Now we need to advance to the next column. If we already - // finished the last column, then in principle we are done, however - // we can't just return here, as we need to allow the end work of the - // current block to complete. The good news is that at this point it - // doesn't matter what data we load for the next column, since - // we will exit from the main loop below before actually storing - // anything computed from that data. - - "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n" - "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" - "cmp r8, r4\n" // Have we finished the last column? - "bge 5f\n" // If yes, just carry on without updating the column pointer. - // Not finished last column: then advance to next column. - "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n" - "ldr r10, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n" - "add r10, r10, r1, lsl #1\n" - "str r10, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n" - "5:\n" - - // Set the LHS and RHS data pointers to the start of the columns just - // computed. - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" - "mov %[lhs_ptr], r4\n" - "ldr r5, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n" - "mov %[rhs_ptr], r5\n" - - // Now we load: bias data, LHS sums data, RHS sums data. - - // First, load the base pointers from the params. - "ldrb r4, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" - "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n" - - // Offset these base pointers as needed given the current row, col. - "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" - "add r5, r1, r8, lsl #2\n" - - "tst r4, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n" - "it ne\n" - "movne r1, r5\n" - - // Load 4 bias values. 
- "vld1.32 {d24, d25}, [r1]\n" - - // Now that we know what LHS and RHS data the next iteration of the - // main loop will need to load, we start loading the first 32 bytes of - // each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore - // in the rest of the work on the current block. - "vld1.8 {d0, d1}, [%[lhs_ptr]]!\n" - "vld1.8 {d2, d3}, [%[lhs_ptr]]!\n" - "vld1.8 {d4, d5}, [%[lhs_ptr]]!\n" - "vld1.8 {d6, d7}, [%[lhs_ptr]]!\n" - RUY_PREFETCH_LOAD("pld [%[lhs_ptr]]\n") - "vld1.8 {d8, d9}, [%[rhs_ptr]]!\n" - // Skip the other column and advance the pointer. - "add %[rhs_ptr], %[rhs_ptr], #16\n" - RUY_PREFETCH_LOAD("pld [%[rhs_ptr]]\n") - - // Add to the bias values the product - // (depth * lhs_zero_point * rhs_zero_point), - // See the term NZ1Z2 in equation (7) in - // https://arxiv.org/pdf/1712.05877.pdf - "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_PROD_ZP_DEPTH) "]\n" - "vdup.32 q9, r3\n" - "vadd.i32 q12, q12, q9\n" - - // Perform the bias-addition (per the above, we have just folded into - // the bias the (depth * lhs_zero_point * rhs_zero_point) term.) - "vadd.i32 q14, q14, q12\n" - - // LHS/RHS zero points - // Has RHS sums - "ldrb r6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" - "tst r6, #" RUY_STR(RUY_ASM_FLAG_HAS_RHS_SUMS) "\n" - "beq 401f\n" - "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_RHS_SUMS) "]\n" - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" - // Offset by current col * number of bytes per value - "add r3, r3, r4, lsl #2\n" - "vld1.32 { d12 }, [r3]\n" - "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_ZERO_POINT) "]\n" - "vdup.32 q10, r5\n" // create lhs_zero_point_vec - // Subtract rhs_sums * lhs_zero_point, per - // equation (7) in https://arxiv.org/pdf/1712.05877.pdf - "vmls.i32 q14, q10, d12[0]\n" - "401:\n" - - // Has LHS sums - "ldrb r6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" - "tst r6, #" RUY_STR(RUY_ASM_FLAG_HAS_LHS_SUMS) "\n" - "beq 402f\n" - "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_LHS_SUMS) "]\n" - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" - // Offset by current row * number of bytes per value - "add r2, r2, r4, lsl #2\n" - "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_RHS_ZERO_POINT) "]\n" - - // Load 4 lhs_sums values. - "vld1.32 {d22, d23}, [r2]\n" - "vdup.32 d13, r5\n" // rhs_zero_point - - // Compute lhs_sums * rhs_zero_point. - "vmul.i32 q11, q11, d13[1]\n" - // Subtract lhs_sums * rhs_zero_point, per - // equation (7) in https://arxiv.org/pdf/1712.05877.pdf - "vsub.s32 q14, q14, q11\n" - - // If the destination is int32, it means the user asks for the raw - // accumulators, no need for us to downquantize the value. - "ldrb r10, [%[params], #" RUY_STR(RUY_OFFSET_DST_TYPE_ID) "]\n" - "cmp r10, #" RUY_STR(RUY_ASM_TYPE_ID_INT32) "\n" - "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT32) "f\n" - - "402:\n" - - // At this point we have computed the final int32 values. Now we - // start down-quantizing them to obtain the final 8bit values from them. - - // As part of this down-quantization, our int32 values will be - // multiplied by a multiplier that has a fixed-point component and an - // exponent component. - - //Load the exponent part of the multiplier. 
- "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_EXPONENT) "]\n" - "tst r6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n" - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" - "add r5, r1, r4, lsl #2\n" - "it ne\n" - "movne r1, r5\n" - - "vld1.32 {q10}, [r1]\n" - - RUY_MAKE_ZERO(q8) - "vmax.s32 q12, q10, q8\n" - - "vshl.s32 q14, q14, q12\n" - - "vmin.s32 q12, q10, q8\n" - - // Load fixed point part of the multiplier - "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n" - // r6 has flags, r4 has row - "add r5, r1, r4, lsl #2\n" - "tst r6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n" - "it ne\n" - "movne r1, r5\n" - "vld1.32 {q10}, [r1]\n" // multiplier_fixedpoint - - // Apply the fixed-point part of the multiplier. - "vqrdmulh.s32 q14, q14, q10\n" - - // We have some rounding division-by-power-of-two to do. This should - // always use "round to nearest". We allow for some - // freedom in how ties are broken, to strike a good compromise of - // performance on given hardware vs. perfect agreement of results - // across hardware. - // - // When RUY_OPT_NATIVE_ROUNDING is enabled, we allow for implementation - // defined tie-breaks to help performance. On NEON, this means that we - // can just use the NEON rounding instructions, such as srshl. They - // happen to be breaking ties upward. - // - // When RUY_OPT_NATIVE_ROUNDING is disabled, we implement strict - // break-ties-away-from zero, as described in Appendix B of - // https://arxiv.org/pdf/1712.05877.pdf - // When we wrote that, we thought that that would be better unbiased - // than the NEON upwards tie-breaks, and we had observed some - // improvement on some model. However, that is only more unbiased for - // data centered at zero, which was likely the case in that model, - // but is not always the case. If we wanted something more consistently - // unbiased then we should try breaking ties toward-nearest-even. -#if !RUY_OPT_ENABLED(RUY_OPT_NATIVE_ROUNDING) - // Fix up values to be right-shifted, so that the (round to nearest, - // break ties upward) behavior of srshl applied to these fixed-up - // values, produces the same result as the desired (round to nearest, - // break ties away from zero) behavior on the original values. - "vand q8, q14, q12\n" - "vshr.s32 q8, q8, #31\n" - "vqadd.s32 q14, q14, q8\n" - -#endif - // At this point we have reduced the problem of correctly implementing - // rounding divide-by-power-of-two, to what the SRSHL instruction can - // do. - "vrshl.s32 q14, q14, q12\n" - - "ldrb r10, [%[params], #" RUY_STR(RUY_OFFSET_DST_TYPE_ID) "]\n" - "cmp r10, #" RUY_STR(RUY_ASM_TYPE_ID_INT16) "\n" - "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT16) "f\n" - "cmp r10, #" RUY_STR(RUY_ASM_TYPE_ID_INT8) "\n" - "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT8) "f\n" - - // Store uint8 values: - RUY_STR(RUY_ASM_LABEL_STORE_UINT8) ":\n" - - // Cast-and-saturate from int32 to int16 - // After this, all values for output are in d28. - "vqmovn.s32 d28, q14\n" - - // At this point, d12 -- d26, d29, d30, d31 aren't used anymore for the - // current block, so we can start clearing these accumulators for the - // next block (next iteration of the main loop). - RUY_MAKE_ZERO(q6) - RUY_MAKE_ZERO(q7) - RUY_MAKE_ZERO(q8) - RUY_MAKE_ZERO(q9) - RUY_MAKE_ZERO(q10) - RUY_MAKE_ZERO(q11) - RUY_MAKE_ZERO(q12) - RUY_MAKE_ZERO(q13) - RUY_MAKE_ZERO(q15) - - // Load the destination zero point into each of the 8 16-bit slots - // in a q register. 
- "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n" - "vdup.16 q13, r4\n" // dst_zero_point - - // Add the destination zero point - "vadd.i16 q14, q14, q13\n" - - // Cast-and-saturate from int16 to uint8 - "vqmovun.s16 d30, q14\n" - // At this point, we only need 4 8-bit values in the lower half - // of d30. - - - // Load the clamp_min, clamp_max bounds - "ldrb r2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" - "ldrb r3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" - "vdup.8 d28, r2\n" // clamp_min - "vdup.8 d29, r3\n" // clamp_max - - // Apply the clamp_min bound - "vmax.u8 d30, d30, d28\n" - // Apply the clamp_max bound - "vmin.u8 d30, d30, d29\n" - - // Compute how much of the 4x1 block of destination 8bit values that - // we have computed, fit in the destination matrix. Typically, all of - // it fits, but when the destination matrix shape is not a multiple - // of 4x1, there are some 4x1 blocks along the boundaries that do - // not fit entirely. - - "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n" - "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" - "sub r1, r1, r8\n" - - "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n" - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" - "sub r2, r2, r4\n" - "mov r3, #4\n" - "mov r5, #2\n" - "cmp r1, #4\n" - // Compute r1 = how many rows of the 4x1 block fit - "it gt\n" - "movgt r1, r3\n" - - // Test if r1==4, i.e. if all of the 4x1 block fits. - "cmp r1, r3\n" - - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" - // Yes, all of the 4x1 block fits, go to fast path. - "beq 30f\n" - // Not all of the 4x1 block fits. - // Store to dst_tmp_buf - // Set r3 address to write to dst_tmp_buf. - "mov r3, %[dst_tmp_buf]\n" - "vst1.8 {d30}, [r3]\n" - - // Slow loop copying from dst_tmp_buf to dst. - "50:\n" - "mov r8, #0\n" - "51:\n" - "ldrb r10, [r3, r8]\n" - "strb r10, [r4, r8]\n" - "add r8, r8, #1\n" - "cmp r8, r1\n" - "blt 51b\n" - "b 31f\n" - "30:\n" - // Yes, all of the 4x1 block fits. - // r3 address, r5 stride - "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - "mov r4, r3\n" - "mov r6, #1\n" - - "vst1.8 {d30[0]}, [r3], r6\n" - "vst1.8 {d30[1]}, [r3], r6\n" - "vst1.8 {d30[2]}, [r3], r6\n" - "vst1.8 {d30[3]}, [r3], r6\n" - "31:\n" - - // Load dst_ptr, increment, and write back. - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - "add r4, r4, #4\n" - "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - - RUY_MAKE_ZERO(q13) - RUY_MAKE_ZERO(q14) - RUY_MAKE_ZERO(q15) - - "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" - - // Store int8 values: - RUY_STR(RUY_ASM_LABEL_STORE_INT8) ":\n" - - // Cast-and-saturate from int32 to int16 - // After this, all values for output are in d28. - "vqmovn.s32 d28, q14\n" - - // At this point, d12 -- d26, d29, d30, d31 aren't used anymore for the - // current block, so we can start clearing these accumulators for the - // next block (next iteration of the main loop). - RUY_MAKE_ZERO(q6) - RUY_MAKE_ZERO(q7) - RUY_MAKE_ZERO(q8) - RUY_MAKE_ZERO(q9) - RUY_MAKE_ZERO(q10) - RUY_MAKE_ZERO(q11) - RUY_MAKE_ZERO(q12) - RUY_MAKE_ZERO(q13) - RUY_MAKE_ZERO(q15) - - // Load the destination zero point into each of the 8 16-bit slots - // in a q register. 
- "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n" - "vdup.16 q13, r4\n" // dst_zero_point - - // Add the destination zero point - "vadd.i16 q14, q14, q13\n" - - // Cast-and-saturate from int16 to int8 - "vqmovn.s16 d30, q14\n" - // At this point, we only need 4 8-bit values in the lower half - // of d30. - - // Load the clamp_min, clamp_max bounds - "ldrb r2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" - "ldrb r3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" - "vdup.8 d28, r2\n" // clamp_min - "vdup.8 d29, r3\n" // clamp_max - - // Apply the clamp_min bound - "vmax.s8 d30, d30, d28\n" - // Apply the clamp_max bound - "vmin.s8 d30, d30, d29\n" - - // Compute how much of the 4x1 block of destination 8bit values that - // we have computed, fit in the destination matrix. Typically, all of - // it fits, but when the destination matrix shape is not a multiple - // of 4x2, there are some 4x2 blocks along the boundaries that do - // not fit entirely. - - "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n" - "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" - "sub r1, r1, r8\n" - - "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n" - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" - "sub r2, r2, r4\n" - "mov r3, #4\n" - "mov r5, #2\n" - "cmp r1, #4\n" - // Compute r1 = how many rows of the 4x2 block fit - "it gt\n" - "movgt r1, r3\n" - - // Test if r1==4 i.e. if all of the 4x1 block fits. - "cmp r1, r3\n" - - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" - // Yes, all of the 4x2 block fits, go to fast path. - "beq 30f\n" - // Not all of the 4x2 block fits. - // Store to dst_tmp_buf - // Set r3 address to write to dst_tmp_buf. - "mov r3, %[dst_tmp_buf]\n" - "vst1.8 {d30}, [r3]\n" - - // Slow loop copying from dst_tmp_buf to dst. - "50:\n" - "mov r8, #0\n" - "51:\n" - "ldrb r10, [r3, r8]\n" - "strb r10, [r4, r8]\n" - "add r8, r8, #1\n" - "cmp r8, r1\n" - "blt 51b\n" - "b 31f\n" - "30:\n" - // Yes, all of the 4x1 block fits. - // r3 address, r5 stride - "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - "mov r4, r3\n" - "mov r6, #1\n" - - "vst1.8 {d30[0]}, [r3], r6\n" - "vst1.8 {d30[1]}, [r3], r6\n" - "vst1.8 {d30[2]}, [r3], r6\n" - "vst1.8 {d30[3]}, [r3], r6\n" - "31:\n" - - // Load dst_ptr, increment, and write back. - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - "add r4, r4, #4\n" - "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - - RUY_MAKE_ZERO(q13) - RUY_MAKE_ZERO(q14) - RUY_MAKE_ZERO(q15) - - "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" - - RUY_STR(RUY_ASM_LABEL_STORE_INT16) ":\n" - - // Load the destination zero point into each of the 4 32-bit slots - // in a q register. - "ldrsh r4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n" - "vdup.32 q13, r4\n" // dst_zero_point - // Add the destination zero point - "vadd.s32 q14, q14, q13\n" - //"vadd.s32 q15, q15, q13\n" - - // Cast-and-saturate from int32 to int16 - // After this, all values for output are in d28. - "vqmovn.s32 d28, q14\n" - - // At this point, d12 -- d26, d29, d30, d31 aren't used anymore for the - // so we can start clearing these accumulators for the next block - // (next iteration of the main loop). 
- RUY_MAKE_ZERO(q6) - RUY_MAKE_ZERO(q7) - RUY_MAKE_ZERO(q8) - RUY_MAKE_ZERO(q9) - RUY_MAKE_ZERO(q10) - RUY_MAKE_ZERO(q11) - RUY_MAKE_ZERO(q15) - - // Load the clamp_min, clamp_max bounds - "ldrh r2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" - "ldrh r3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" - "vdup.16 d24, r2\n" // clamp_min - "vdup.16 d26, r3\n" // clamp_max - - // Apply the clamp_min bound - "vmax.s16 d28, d28, d24\n" - // Apply the clamp_max bound - "vmin.s16 d28, d28, d26\n" - - RUY_MAKE_ZERO(q12) - RUY_MAKE_ZERO(q13) - - // Compute how much of the 4x1 block of destination 16-bit values that - // we have computed, fit in the destination matrix. Typically, all of - // it fits, but when the destination matrix shape is not a multiple - // of 4x1, there are some 4x1 blocks along the boundaries that do - // not fit entirely. - - "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n" - "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" - "sub r1, r1, r8\n" - - "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n" - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" - "sub r2, r2, r4\n" - "mov r3, #4\n" - "mov r5, #2\n" - "cmp r1, #4\n" - // Compute r1 = how many rows of the 4x1 block fit - "it gt\n" - "movgt r1, r3\n" - - // Test if r1==4, i.e. if all of the 4x1 block fits. - "cmp r1, r3\n" - - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" - // Yes, all of the 4x1 block fits, go to fast path. - "beq 30f\n" - // Not all of the 4x1 block fits. - // Store to dst_tmp_buf - // Set r3 address to write to dst_tmp_buf. - "mov r3, %[dst_tmp_buf]\n" - "vst1.16 {d28}, [r3]\n" - - // Slow loop copying from dst_tmp_buf to dst. - "50:\n" - "mov r8, #0\n" - "51:\n" - // Shift of offset register for half-word loads not allowed in A32, - // so we shift, load/store, then shift back r8. - "lsl r8, r8, #1\n" - "ldrh r10, [r3, r8]\n" - "strh r10, [r4, r8]\n" - "lsr r8, r8, #1\n" - "add r8, r8, #1\n" - "cmp r8, r1\n" - "blt 51b\n" - "b 31f\n" - "30:\n" - // Yes, all of the 4x1 block fits. - // r3 address, r5 stride - "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - "mov r4, r3\n" - "mov r6, #2\n" - - "vst1.16 {d28[0]}, [r3], r6\n" - "vst1.16 {d28[1]}, [r3], r6\n" - "vst1.16 {d28[2]}, [r3], r6\n" - "vst1.16 {d28[3]}, [r3], r6\n" - "31:\n" - - // Load dst_ptr, increment, and write back. - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - "add r4, r4, #8\n" - "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - - RUY_MAKE_ZERO(q14) - - "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" - - RUY_STR(RUY_ASM_LABEL_STORE_INT32) ":\n" - - // Since the store type is the same as the accum type, no need for - // downcast. There's also no need for clamp by min/max. - - // At this point, v20 -- v31 aren't used anymore for the current block, - // so we can start clearing these accumulators for the next block - // (next iteration of the main loop). - // Clear accumulators. - RUY_MAKE_ZERO(q6) - RUY_MAKE_ZERO(q7) - RUY_MAKE_ZERO(q8) - RUY_MAKE_ZERO(q9) - RUY_MAKE_ZERO(q10) - RUY_MAKE_ZERO(q11) - RUY_MAKE_ZERO(q12) - RUY_MAKE_ZERO(q13) - - // Compute how much of the 4x1 block of destination 32 bit values that - // we have computed, fit in the destination matrix. Typically, all of - // it fits, but when the destination matrix shape is not a multiple - // of 4x2, there are some 4x4 blocks along the boundaries that do - // not fit entirely. 
- - "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n" - "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" - "sub r1, r1, r8\n" - - "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n" - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" - "sub r2, r2, r4\n" - "mov r3, #4\n" - "mov r5, #2\n" - "cmp r1, #4\n" - // Compute r1 = how many rows of the 4x2 block fit - "it gt\n" - "movgt r1, r3\n" - - // Test if r1==4, i.e. if all of the 4x1 block fits. - "cmp r1, r3\n" - - // Yes, all of the 4x1 block fits, go to fast path. - "beq 30f\n" - // Not all of the 4x1 block fits. - // Set (r3 address, r4 stride) to write to dst_tmp_buf - "mov r3, %[dst_tmp_buf]\n" - "mov r4, #16\n" - "b 31f\n" - - "30:\n" - // Yes, all of the 4x1 block fits. - "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" - // r3 address, r4 stride - "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - "mov r4, r5\n" - - "31:\n" - - "vst1.32 {d28, d29}, [r3]\n" - - // If all of the 4x1 block fits, we just finished writing it to the - // destination, so we skip the next part. - "beq 41f\n" - // Not all of the 4x1 block fits in the destination matrix. We just - // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over - // it to copy into the destination matrix the part that fits. - "ldr r8, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" - "mov r3, %[dst_tmp_buf]\n" - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - "50:\n" - "mov r5, #0\n" - "51:\n" - "ldr r10, [r3, r5, lsl #2]\n" - "str r10, [r4, r5, lsl #2]\n" - "add r5, r5, #1\n" - "cmp r5, r1\n" - "blt 51b\n" - - "41:\n" - // Load dst_ptr, increment, and write back. - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - "add r4, r4, #16\n" - "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - - RUY_MAKE_ZERO(q10) - RUY_MAKE_ZERO(q11) - - "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" - - RUY_STR(RUY_ASM_LABEL_AFTER_STORE) ":\n" - - // Reload some params --- we had used x5 -- x7 for a few other things - // since the last time we had loaded them. - "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" - "ldr r6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" - "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" - - // Move to the next block of the destination matrix, for the next iter - // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already - // been updated earlier. - // Have we reached the end row? - "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" - "cmp r8, r3\n" - - "beq 20f\n" // yes, end row. - // Not end row. Move to the next row. - "add r8, r8, #4\n" - // Store new value of row - "str r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" - - "b 21f\n" - "20:\n" - // Was already at end row. - // Move back to first row. - "str r6, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" - // Move to the next column. - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" - "add r4, r4, #2\n" - "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" - - "ldr r8, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" - "ldr r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n" - // Increment dst_col_ptr by dst_stride (i.e. 1 column) - "add r1, r1, r8\n" - // Store dst_col_ptr - "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n" - // Store dst_ptr - "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - "21:\n" - - // Main loop exit condition: have we hit the end column? 
- "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n" - "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" - "cmp r8, r4\n" - - // w1 is the number of levels of depth that we have already loaded - // LHS and RHS data for. Corresponding to the initial ld1 instructions - // above, this is currently 16. - "mov r1, #16\n" - - "ble 1b\n" - - // Restore stack pointer. - "add sp, sp, #" RUY_STR(RUY_STACK_OFFSET_SIZE) "\n" - - // clang-format on - - : [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr) - : [ params ] "r"(¶ms), [dst_tmp_buf] "r"(params.dst_tmp_buf) - : "r0", "r1", "r2", "r3", "r4", "r5", "r6", "r8", "r10", "cc", - // Clobber list must specify q registers (and not their constituent - // d registers). There is a (currently unexplained) slowdown if - // d registers are listed in the clobbers list. - "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", - "q9", "q10", "q12", "q13", "q14", "q15"); -} - -#undef RUY_OFFSET_BIAS -#undef RUY_OFFSET_LHS_SUMS -#undef RUY_OFFSET_RHS_SUMS -#undef RUY_OFFSET_LHS_BASE_PTR -#undef RUY_OFFSET_MULTIPLIER_FIXEDPOINT -#undef RUY_OFFSET_MULTIPLIER_EXPONENT -#undef RUY_OFFSET_RHS_BASE_PTR -#undef RUY_OFFSET_DST_BASE_PTR -#undef RUY_OFFSET_LHS_ZERO_POINT -#undef RUY_OFFSET_RHS_ZERO_POINT -#undef RUY_OFFSET_DST_ZERO_POINT -#undef RUY_OFFSET_PROD_ZP_DEPTH -#undef RUY_OFFSET_START_ROW -#undef RUY_OFFSET_START_COL -#undef RUY_OFFSET_LAST_ROW -#undef RUY_OFFSET_LAST_COL -#undef RUY_OFFSET_DST_ROWS -#undef RUY_OFFSET_DST_COLS -#undef RUY_OFFSET_LHS_STRIDE -#undef RUY_OFFSET_RHS_STRIDE -#undef RUY_OFFSET_DST_STRIDE -#undef RUY_OFFSET_DEPTH -#undef RUY_OFFSET_CLAMP_MIN -#undef RUY_OFFSET_CLAMP_MAX -#undef RUY_OFFSET_FLAGS -#undef RUY_OFFSET_DST_TYPE_ID - -#undef RUY_STACK_OFFSET_SIZE -#undef RUY_STACK_OFFSET_DST_COL_PTR -#undef RUY_STACK_OFFSET_DST_PTR -#undef RUY_STACK_OFFSET_ROW -#undef RUY_STACK_OFFSET_COL -#undef RUY_STACK_OFFSET_LHS_COL_PTR -#undef RUY_STACK_OFFSET_RHS_COL_PTR - -#endif // RUY_PLATFORM(NEON_32) && (RUY_OPT_ENABLED(RUY_OPT_ASM) -} // namespace ruy diff --git a/tensorflow/lite/experimental/ruy/ruy/kernel_arm64.cc b/tensorflow/lite/experimental/ruy/ruy/kernel_arm64.cc deleted file mode 100644 index eff9d2c8a09..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/kernel_arm64.cc +++ /dev/null @@ -1,7835 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/
-
-#include <cstdint>
-
-#include "tensorflow/lite/experimental/ruy/ruy/common.h"
-#include "tensorflow/lite/experimental/ruy/ruy/kernel.h"
-#include "tensorflow/lite/experimental/ruy/ruy/opt_set.h"
-#include "tensorflow/lite/experimental/ruy/ruy/platform.h"
-#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h"
-
-namespace ruy {
-
-#if RUY_PLATFORM(NEON_64) && RUY_OPT_ENABLED(RUY_OPT_ASM)
-
-#define RUY_ASM_LABEL_STORE_UINT8 91
-#define RUY_ASM_LABEL_STORE_INT8 92
-#define RUY_ASM_LABEL_STORE_INT16 93
-#define RUY_ASM_LABEL_STORE_INT32 94
-#define RUY_ASM_LABEL_AFTER_STORE 99
-
-#define RUY_OFFSET_BIAS 0
-#define RUY_OFFSET_LHS_SUMS 8
-#define RUY_OFFSET_RHS_SUMS 16
-#define RUY_OFFSET_LHS_BASE_PTR 24
-#define RUY_OFFSET_MULTIPLIER_FIXEDPOINT 32
-#define RUY_OFFSET_MULTIPLIER_EXPONENT 40
-#define RUY_OFFSET_RHS_BASE_PTR 48
-#define RUY_OFFSET_DST_BASE_PTR 56
-#define RUY_OFFSET_LHS_ZERO_POINT 64
-#define RUY_OFFSET_RHS_ZERO_POINT 68
-#define RUY_OFFSET_DST_ZERO_POINT 72
-#define RUY_OFFSET_PROD_ZP_DEPTH 76
-#define RUY_OFFSET_START_ROW 80
-#define RUY_OFFSET_START_COL 84
-#define RUY_OFFSET_LAST_ROW 88
-#define RUY_OFFSET_LAST_COL 92
-#define RUY_OFFSET_DST_ROWS 96
-#define RUY_OFFSET_DST_COLS 100
-#define RUY_OFFSET_LHS_STRIDE 104
-#define RUY_OFFSET_RHS_STRIDE 108
-#define RUY_OFFSET_DST_STRIDE 112
-#define RUY_OFFSET_DEPTH 116
-#define RUY_OFFSET_CLAMP_MIN 120
-#define RUY_OFFSET_CLAMP_MAX 124
-#define RUY_OFFSET_FLAGS 128
-
-template <typename Params>
-void CheckOffsetsInKernelParams8bit(const Params&) {
- static_assert(offsetof(Params, lhs_zero_point) == RUY_OFFSET_LHS_ZERO_POINT,
- "");
- static_assert(offsetof(Params, rhs_zero_point) == RUY_OFFSET_RHS_ZERO_POINT,
- "");
- static_assert(offsetof(Params, dst_zero_point) == RUY_OFFSET_DST_ZERO_POINT,
- "");
- static_assert(offsetof(Params, prod_zp_depth) == RUY_OFFSET_PROD_ZP_DEPTH,
- "");
- static_assert(offsetof(Params, multiplier_fixedpoint) ==
- RUY_OFFSET_MULTIPLIER_FIXEDPOINT,
- "");
- static_assert(
- offsetof(Params, multiplier_exponent) == RUY_OFFSET_MULTIPLIER_EXPONENT,
- "");
- static_assert(offsetof(Params, clamp_min) == RUY_OFFSET_CLAMP_MIN, "");
- static_assert(offsetof(Params, clamp_max) == RUY_OFFSET_CLAMP_MAX, "");
- static_assert(offsetof(Params, bias) == RUY_OFFSET_BIAS, "");
- static_assert(offsetof(Params, lhs_sums) == RUY_OFFSET_LHS_SUMS, "");
- static_assert(offsetof(Params, rhs_sums) == RUY_OFFSET_RHS_SUMS, "");
- static_assert(offsetof(Params, flags) == RUY_OFFSET_FLAGS, "");
- static_assert(offsetof(Params, lhs_base_ptr) == RUY_OFFSET_LHS_BASE_PTR, "");
- static_assert(offsetof(Params, start_row) == RUY_OFFSET_START_ROW, "");
- static_assert(offsetof(Params, last_row) == RUY_OFFSET_LAST_ROW, "");
- static_assert(offsetof(Params, last_col) == RUY_OFFSET_LAST_COL, "");
- static_assert(offsetof(Params, lhs_stride) == RUY_OFFSET_LHS_STRIDE, "");
- static_assert(offsetof(Params, rhs_stride) == RUY_OFFSET_RHS_STRIDE, "");
- static_assert(offsetof(Params, dst_stride) == RUY_OFFSET_DST_STRIDE, "");
- static_assert(offsetof(Params, depth) == RUY_OFFSET_DEPTH, "");
-}
-
-// Fast-int8-trick kernel, similar to this production gemmlowp kernel:
-// NEON_64bit_GEMM_Int8Operands_AccumTwoWithin16Bits
-// https://github.com/google/gemmlowp/blob/36212ad3651871bc3e9a599f1a6d5324778aea25/standalone/neon-gemm-kernel-benchmark.cc#L2296
-//
-// Relevant target CPUs for this kernel include ARM Cortex-A73 and Cortex-A75,
-// since these are 64-bit,
out-of-order and without dotprod support. -void Kernel8bitNeonOutOfOrder(const KernelParams8bit<4, 4>& params) { - profiler::ScopeLabel label( - "Kernel (kNeon, optimized for out-of-order cores)"); - - CheckOffsetsInKernelParams8bit(params); - - const std::int8_t* lhs_col_ptr = params.lhs_base_ptr; - const std::int8_t* rhs_col_ptr = params.rhs_base_ptr; - const std::int8_t* lhs_ptr = lhs_col_ptr; - const std::int8_t* rhs_ptr = rhs_col_ptr; - void* dst_col_ptr = params.dst_base_ptr; - void* dst_ptr = dst_col_ptr; - int row = params.start_row; - int col = params.start_col; - - // The asm kernel below has the following NEON register allocation: - // - // v16 -- v31 are int32 accumulators. - // During accumulation, v0 -- v3 are used to load int8 data from LHS and - // v4 -- v7 from RHS: - // - // int8 RHS 16x4 block - // /-----------------------------------------\ - // |v4.b[0] ... v7.b[0] | - // | ... ... | - // |v4.b[15] ... v7.b[15] | - // \-----------------------------------------/ - // int8 LHS 4x16 block - // /---------------------\ /-----------------------------------------\ - // |v0.b[0] ... v0.b[15] | |v16.4s ... v28.4s | - // |v1.b[0] ... v1.b[15] | |v17.4s ... v29.4s | - // |v2.b[0] ... v2.b[15] | |v18.4s ... v30.4s | - // |v3.b[0] ... v3.b[15] | |v19.4s ... v31.4s | - // \---------------------/ \-----------------------------------------/ - // int32 accumulators 4x4 block - // - // No attempt had been made so far at implementing the RUY_OPT_MAX_STREAMING - // optimization for this kernel. - asm volatile( -#define RUY_MAKE_ZERO(reg) "dup " #reg ".4s, wzr\n" - - // clang-format off - - // Load some parameters into registers. - "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" - "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" - "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" - "ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n" - "ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n" - "ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n" - "ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" - "ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n" - - // Load the first 64 bytes of LHS and RHS data. - "ld1 {v0.16b}, [%[lhs_ptr]], #16\n" - "ld1 {v1.16b}, [%[lhs_ptr]], #16\n" - "ld1 {v2.16b}, [%[lhs_ptr]], #16\n" - "ld1 {v3.16b}, [%[lhs_ptr]], #16\n" - "ld1 {v4.16b}, [%[rhs_ptr]], #16\n" - "ld1 {v5.16b}, [%[rhs_ptr]], #16\n" - "ld1 {v6.16b}, [%[rhs_ptr]], #16\n" - "ld1 {v7.16b}, [%[rhs_ptr]], #16\n" - - // Clear accumulators. - RUY_MAKE_ZERO(v16) - RUY_MAKE_ZERO(v17) - RUY_MAKE_ZERO(v18) - RUY_MAKE_ZERO(v19) - RUY_MAKE_ZERO(v20) - RUY_MAKE_ZERO(v21) - RUY_MAKE_ZERO(v22) - RUY_MAKE_ZERO(v23) - RUY_MAKE_ZERO(v24) - RUY_MAKE_ZERO(v25) - RUY_MAKE_ZERO(v26) - RUY_MAKE_ZERO(v27) - RUY_MAKE_ZERO(v28) - RUY_MAKE_ZERO(v29) - RUY_MAKE_ZERO(v30) - RUY_MAKE_ZERO(v31) - - // w1 is the number of levels of depth that we have already loaded - // LHS and RHS data for. Corresponding to the initial ld1 instructions - // above, this is currently 16. - "mov w1, #16\n" - - // Perform the first few multiply-adds on the data that we have already - // loaded. - "smull v8.8h, v0.8b, v4.8b\n" - "smull v9.8h, v1.8b, v4.8b\n" - "smull v10.8h, v2.8b, v4.8b\n" - "smull v11.8h, v3.8b, v4.8b\n" - "smull v12.8h, v0.8b, v5.8b\n" - "smull v13.8h, v1.8b, v5.8b\n" - "smull v14.8h, v2.8b, v5.8b\n" - "smull v15.8h, v3.8b, v5.8b\n" - - // Multiply-accumulate second-half, again into the same - // 16bit local accumulator registers. 
This is where we - // take advantage of having int8 instead of uint8 and therefore - // being able to accumulate two products into int16. - "smlal2 v8.8h, v0.16b, v4.16b\n" - "smlal2 v9.8h, v1.16b, v4.16b\n" - "smlal2 v10.8h, v2.16b, v4.16b\n" - "smlal2 v11.8h, v3.16b, v4.16b\n" - "smlal2 v12.8h, v0.16b, v5.16b\n" - "smlal2 v13.8h, v1.16b, v5.16b\n" - "smlal2 v14.8h, v2.16b, v5.16b\n" - "smlal2 v15.8h, v3.16b, v5.16b\n" - - - // Main loop of the whole GEMM, over rows and columns of the - // destination matrix. - "1:\n" - - // Reminder - w1 is how many levels of depth we have already loaded - // data for, w12 is the total depth. - "cmp w1, w12\n" - "beq 79f\n" - - "2:\n" - - // Some multiplications and 16-bit accumulation were already done above, - // so we start right away in the middle. - "sadalp v16.4s, v8.8h\n" - "ld1 {v4.16b}, [%[rhs_ptr]], #16\n" - "smull v8.8h, v0.8b, v6.8b\n" - "sadalp v17.4s, v9.8h\n" - "ld1 {v5.16b}, [%[rhs_ptr]], #16\n" - "smull v9.8h, v1.8b, v6.8b\n" - "sadalp v18.4s, v10.8h\n" - "smull v10.8h, v2.8b, v6.8b\n" - "sadalp v19.4s, v11.8h\n" - "smull v11.8h, v3.8b, v6.8b\n" - "sadalp v20.4s, v12.8h\n" - "smull v12.8h, v0.8b, v7.8b\n" - "sadalp v21.4s, v13.8h\n" - "smull v13.8h, v1.8b, v7.8b\n" - "sadalp v22.4s, v14.8h\n" - "smull v14.8h, v2.8b, v7.8b\n" - "sadalp v23.4s, v15.8h\n" - "smull v15.8h, v3.8b, v7.8b\n" - - // Multiply-accumulate second-half, again into the same - // 16bit local accumulator registers. This is where we - // take advantage of having int8 instead of uint8 and therefore - // being able to accumulate two products into int16. - "smlal2 v8.8h, v0.16b, v6.16b\n" - "smlal2 v9.8h, v1.16b, v6.16b\n" - "smlal2 v10.8h, v2.16b, v6.16b\n" - "smlal2 v11.8h, v3.16b, v6.16b\n" - - "ld1 {v6.16b}, [%[rhs_ptr]], #16\n" - - "smlal2 v12.8h, v0.16b, v7.16b\n" - "ld1 {v0.16b}, [%[lhs_ptr]], #16\n" - "smlal2 v13.8h, v1.16b, v7.16b\n" - "ld1 {v1.16b}, [%[lhs_ptr]], #16\n" - "smlal2 v14.8h, v2.16b, v7.16b\n" - "ld1 {v2.16b}, [%[lhs_ptr]], #16\n" - "smlal2 v15.8h, v3.16b, v7.16b\n" - "ld1 {v3.16b}, [%[lhs_ptr]], #16\n" - - "sadalp v24.4s, v8.8h\n" - "smull v8.8h, v0.8b, v4.8b\n" - "sadalp v25.4s, v9.8h\n" - "ld1 {v7.16b}, [%[rhs_ptr]], #16\n" - "smull v9.8h, v1.8b, v4.8b\n" - "sadalp v26.4s, v10.8h\n" - "smull v10.8h, v2.8b, v4.8b\n" - "sadalp v27.4s, v11.8h\n" - "smull v11.8h, v3.8b, v4.8b\n" - "sadalp v28.4s, v12.8h\n" - "smull v12.8h, v0.8b, v5.8b\n" - "sadalp v29.4s, v13.8h\n" - "smull v13.8h, v1.8b, v5.8b\n" - "sadalp v30.4s, v14.8h\n" - "smull v14.8h, v2.8b, v5.8b\n" - "sadalp v31.4s, v15.8h\n" - "smull v15.8h, v3.8b, v5.8b\n" - - // Multiply-accumulate second-half, again into the same - // 16bit local accumulator registers. This is where we - // take advantage of having int8 instead of uint8 and therefore - // being able to accumulate two products into int16. - "smlal2 v8.8h, v0.16b, v4.16b\n" - "smlal2 v9.8h, v1.16b, v4.16b\n" - "smlal2 v10.8h, v2.16b, v4.16b\n" - "smlal2 v11.8h, v3.16b, v4.16b\n" - - "smlal2 v12.8h, v0.16b, v5.16b\n" - "smlal2 v13.8h, v1.16b, v5.16b\n" - "smlal2 v14.8h, v2.16b, v5.16b\n" - "smlal2 v15.8h, v3.16b, v5.16b\n" - - - - // Each iteration of this loop advances by 16 levels of depth. 
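
A scalar model of the smull/smlal2 + sadalp idiom used in this loop, for a single accumulator lane, is sketched below. The function and parameter names are invented for the illustration; this is not ruy code.

#include <cstdint>

// Scalar model of accumulating 16 depth levels of one LHS row against one
// RHS column. Two int8*int8 products are first summed within 16 bits
// (smull/smlal2); as long as one operand avoids the value -128, as e.g.
// TFLite's symmetric int8 weight quantization guarantees, the sum is at most
// 2 * 127 * 128 = 32512 and fits in int16. The 16-bit partial sums are then
// accumulated into a 32-bit accumulator (sadalp does this pairwise).
std::int32_t AccumulateDepth16(const std::int8_t lhs[16],
                               const std::int8_t rhs[16],
                               std::int32_t acc) {
  for (int i = 0; i < 8; ++i) {
    const std::int16_t pair = static_cast<std::int16_t>(
        lhs[i] * rhs[i] + lhs[i + 8] * rhs[i + 8]);
    acc += pair;
  }
  return acc;
}
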
- "add w1, w1, #16\n" - - // Loop termination condition - "cmp w1, w12\n" - - "blt 2b\n" - - "79:\n" - - "sadalp v16.4s, v8.8h\n" - "smull v8.8h, v0.8b, v6.8b\n" - "sadalp v17.4s, v9.8h\n" - "smull v9.8h, v1.8b, v6.8b\n" - "sadalp v18.4s, v10.8h\n" - "smull v10.8h, v2.8b, v6.8b\n" - "sadalp v19.4s, v11.8h\n" - "smull v11.8h, v3.8b, v6.8b\n" - "sadalp v20.4s, v12.8h\n" - "smull v12.8h, v0.8b, v7.8b\n" - "sadalp v21.4s, v13.8h\n" - "smull v13.8h, v1.8b, v7.8b\n" - "sadalp v22.4s, v14.8h\n" - "smull v14.8h, v2.8b, v7.8b\n" - "sadalp v23.4s, v15.8h\n" - "smull v15.8h, v3.8b, v7.8b\n" - - // Multiply-accumulate second-half, again into the same - // 16bit local accumulator registers. This is where we - // take advantage of having int8 instead of uint8 and therefore - // being able to accumulate two products into int16. - "smlal2 v8.8h, v0.16b, v6.16b\n" - "smlal2 v9.8h, v1.16b, v6.16b\n" - "smlal2 v10.8h, v2.16b, v6.16b\n" - "smlal2 v11.8h, v3.16b, v6.16b\n" - - "smlal2 v12.8h, v0.16b, v7.16b\n" - "smlal2 v13.8h, v1.16b, v7.16b\n" - "smlal2 v14.8h, v2.16b, v7.16b\n" - "smlal2 v15.8h, v3.16b, v7.16b\n" - - "sadalp v24.4s, v8.8h\n" - "sadalp v25.4s, v9.8h\n" - "sadalp v26.4s, v10.8h\n" - "sadalp v27.4s, v11.8h\n" - "sadalp v28.4s, v12.8h\n" - "sadalp v29.4s, v13.8h\n" - "sadalp v30.4s, v14.8h\n" - "sadalp v31.4s, v15.8h\n" - - // End of accumulation. The registers v16 -- v31 contain the final - // int32 accumulator values of the current 4x4 destination block. - // We now have to compute the final 8-bit values from these int32 - // accumulators, and advance to the next 4x4 block. We intertwine - // these two aspects whenever possible for optimal pipelining, both - // at the data flow level (prefetch data for next block as early as - // possible) and instruction pipelining level (some of the next-block - // work can dual-issue with some of the final work on the current - // block). - - // Reduce 32bit accumulators horizontally. - "addp v16.4s, v16.4s, v17.4s\n" - "addp v18.4s, v18.4s, v19.4s\n" - "addp v20.4s, v20.4s, v21.4s\n" - "addp v22.4s, v22.4s, v23.4s\n" - "addp v24.4s, v24.4s, v25.4s\n" - "addp v26.4s, v26.4s, v27.4s\n" - "addp v28.4s, v28.4s, v29.4s\n" - "addp v30.4s, v30.4s, v31.4s\n" - - // Reduce 32bit accumulators horizontally, second pass - // (each pass adds pairwise. we need to add 4-wise). - "addp v16.4s, v16.4s, v18.4s\n" - "addp v17.4s, v20.4s, v22.4s\n" - "addp v18.4s, v24.4s, v26.4s\n" - "addp v19.4s, v28.4s, v30.4s\n" - - // Logic to advance to the next block in preparation for the next - // iteration of the main loop. For now, we only want to compute - // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are - // not yet ready to update the values of row and col, as we still need - // the current values for the rest of the work on the current block. - - "cmp %w[row], w7\n" // Have we finished the last row? - "bge 4f\n" // If finished last row, go to 4 - // Not finished last row: then advance to next row. - "add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #2\n" - "b 5f\n" - "4:\n" // Finished last row... - "mov %[lhs_col_ptr], x5\n" // Go back to first row - // Now we need to advance to the next column. If we already - // finished the last column, then in principle we are done, however - // we can't just return here, as we need to allow the end work of the - // current block to complete. 
The good news is that at this point it - // doesn't matter what data we load for the next column, since - // we will exit from the main loop below before actually storing - // anything computed from that data. - "cmp %w[col], w8\n" // Have we finished the last column? - "bge 5f\n" // If yes, just carry on without updating the column pointer. - // Not finished last column: then advance to next column. - "add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #2\n" - "5:\n" - - // Set the LHS and RHS data pointers to the start of the columns just - // computed. - "mov %[lhs_ptr], %[lhs_col_ptr]\n" - "mov %[rhs_ptr], %[rhs_col_ptr]\n" - - // Load some parameters needed for the end work on current block. - RUY_MAKE_ZERO(v8) - "ldr w4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n" - "ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_PROD_ZP_DEPTH) "]\n" - "ins v13.h[4], w4\n" // dst_zero_point - "ldr x4, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n" - "ldrb w6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" - "dup v9.4s, w3\n" // create prod_zp_depth_vec - "add x5, x4, %x[row], lsl #2\n" - "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n" - "csel x4, x4, x5, eq\n" - - "ld1 {v15.4s}, [x4]\n" // multiplier_fixedpoint - - // Now we load: bias data, LHS sums data, RHS sums data. - - // First, load the base pointers from the params. - "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n" - - "add x5, x1, %x[row], lsl #2\n" - "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n" - "csel x1, x1, x5, eq\n" - - // Load 4 bias values. - "ld1 {v14.4s}, [x1]\n" - - // Now that we know what LHS and RHS data the next iteration of the - // main loop will need to load, we start loading the first 32 bytes of - // each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore - // in the rest of the work on the current block. - "ld1 {v0.16b}, [%[lhs_ptr]], #16\n" - "ld1 {v1.16b}, [%[lhs_ptr]], #16\n" - "ld1 {v2.16b}, [%[lhs_ptr]], #16\n" - "ld1 {v3.16b}, [%[lhs_ptr]], #16\n" - "ld1 {v4.16b}, [%[rhs_ptr]], #16\n" - "ld1 {v5.16b}, [%[rhs_ptr]], #16\n" - "ld1 {v6.16b}, [%[rhs_ptr]], #16\n" - "ld1 {v7.16b}, [%[rhs_ptr]], #16\n" - - // Add to the bias values the product (depth * lhs_zero_point * rhs_zero_point), - // See the term NZ1Z2 in equation (7) in https://arxiv.org/pdf/1712.05877.pdf - "add v14.4s, v14.4s, v9.4s\n" - - // Perform the bias-addition (per the above, we have just folded into - // the bias the (depth * lhs_zero_point * rhs_zero_point) term.) - "add v16.4s, v16.4s, v14.4s\n" - "add v17.4s, v17.4s, v14.4s\n" - "add v18.4s, v18.4s, v14.4s\n" - "add v19.4s, v19.4s, v14.4s\n" - - "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_RHS_SUMS) "\n" - "beq 401f\n" - "ldr x3, [%[params], #" RUY_STR(RUY_OFFSET_RHS_SUMS) "]\n" - "add x3, x3, %x[col], lsl #2\n" - "ld1 {v14.4s}, [x3]\n" - "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_ZERO_POINT) "]\n" - "dup v10.4s, w5\n" // create lhs_zero_point_vec - // Subtract rhs_sums * lhs_zero_point, per - // equation (7) in https://arxiv.org/pdf/1712.05877.pdf - "mls v16.4s, v10.4s, v14.s[0]\n" - "mls v17.4s, v10.4s, v14.s[1]\n" - "mls v18.4s, v10.4s, v14.s[2]\n" - "mls v19.4s, v10.4s, v14.s[3]\n" - "401:\n" - - "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_LHS_SUMS) "\n" - "beq 402f\n" - "ldr x2, [%[params], #" RUY_STR(RUY_OFFSET_LHS_SUMS) "]\n" - "add x2, x2, %x[row], lsl #2\n" - "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_RHS_ZERO_POINT) "]\n" - // Load 4 lhs_sums values. - "ld1 {v11.4s}, [x2]\n" - "ins v13.s[1], w5\n" // rhs_zero_point - // Compute lhs_sums * rhs_zero_point. 
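
The corrections applied in this section come from expanding the zero-point-offset product, equation (7) in the paper cited above. A scalar sketch follows; the function and parameter names are invented for the illustration.

#include <cstdint>

// Scalar version of the zero-point corrections, following equation (7) of
// https://arxiv.org/pdf/1712.05877.pdf:
//   sum_k (lhs[r][k] - Z1) * (rhs[k][c] - Z2)
//     = raw_acc + depth * Z1 * Z2 - Z1 * rhs_sums[c] - Z2 * lhs_sums[r]
std::int32_t ApplyZeroPointCorrections(std::int32_t raw_acc, std::int32_t bias,
                                       std::int32_t depth,
                                       std::int32_t lhs_zero_point,  // Z1
                                       std::int32_t rhs_zero_point,  // Z2
                                       std::int32_t lhs_sum_for_row,
                                       std::int32_t rhs_sum_for_col) {
  std::int32_t acc = raw_acc + bias;
  acc += depth * lhs_zero_point * rhs_zero_point;  // the NZ1Z2 term
  acc -= lhs_zero_point * rhs_sum_for_col;         // the mls by rhs_sums
  acc -= rhs_zero_point * lhs_sum_for_row;         // the mul + sub by lhs_sums
  return acc;
}
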
- "mul v11.4s, v11.4s, v13.s[1]\n" - // Subtract lhs_sums * rhs_zero_point, per - // equation (7) in https://arxiv.org/pdf/1712.05877.pdf - "sub v16.4s, v16.4s, v11.4s\n" - "sub v17.4s, v17.4s, v11.4s\n" - "sub v18.4s, v18.4s, v11.4s\n" - "sub v19.4s, v19.4s, v11.4s\n" - - // If the destination is int32, it means the user asks for the raw - // accumulators, no need for us to downquantize the value. - "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT32) "\n" - "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT32) "f\n" - - "402:\n" - - // At this point we have computed the final int32 values. Now we - // start down-quantizing them to obtain the final 8bit values from them. - - // As part of this down-quantization, our int32 values will be - // multiplied by a multiplier that has a fixed-point component and an - // exponent component. - - //Load the exponent part of the multiplier. - "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_EXPONENT) "]\n" - "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n" - "add x5, x1, %x[row], lsl #2\n" - "csel x1, x1, x5, eq\n" - - "ld1 {v14.4s}, [x1]\n" - - "smax v12.4s, v14.4s, v8.4s\n" - - "sshl v16.4s, v16.4s, v12.4s\n" - "sshl v17.4s, v17.4s, v12.4s\n" - "sshl v18.4s, v18.4s, v12.4s\n" - "sshl v19.4s, v19.4s, v12.4s\n" - - "smin v12.4s, v14.4s, v8.4s\n" - - // Apply the fixed-point part of the multiplier. - "sqrdmulh v16.4s, v16.4s, v15.4s\n" - "sqrdmulh v17.4s, v17.4s, v15.4s\n" - "sqrdmulh v18.4s, v18.4s, v15.4s\n" - "sqrdmulh v19.4s, v19.4s, v15.4s\n" - - // We have some rounding division-by-power-of-two to do. This should - // always use "round to nearest". We allow for some - // freedom in how ties are broken, to strike a good compromise of - // performance on given hardware vs. perfect agreement of results - // across hardware. - // - // When RUY_OPT_NATIVE_ROUNDING is enabled, we allow for implementation - // defined tie-breaks to help performance. On NEON, this means that we - // can just use the NEON rounding instructions, such as srshl. They - // happen to be breaking ties upward. - // - // When RUY_OPT_NATIVE_ROUNDING is disabled, we implement strict - // break-ties-away-from zero, as described in Appendix B of - // https://arxiv.org/pdf/1712.05877.pdf - // When we wrote that, we thought that that would be better unbiased - // than the NEON upwards tie-breaks, and we had observed some - // improvement on some model. However, that is only more unbiased for - // data centered at zero, which was likely the case in that model, - // but is not always the case. If we wanted something more consistently - // unbiased then we should try breaking ties toward-nearest-even. -#if !RUY_OPT_ENABLED(RUY_OPT_NATIVE_ROUNDING) - // Fix up values to be right-shifted, so that the (round to nearest, - // break ties upward) behavior of srshl applied to these fixed-up - // values, produces the same result as the desired (round to nearest, - // break ties away from zero) behavior on the original values. - "and v8.16b, v16.16b, v12.16b\n" - "and v9.16b, v17.16b, v12.16b\n" - "and v14.16b, v18.16b, v12.16b\n" - "and v15.16b, v19.16b, v12.16b\n" - "sshr v8.4s, v8.4s, #31\n" - "sshr v9.4s, v9.4s, #31\n" - "sshr v14.4s, v14.4s, #31\n" - "sshr v15.4s, v15.4s, #31\n" - "sqadd v16.4s, v16.4s, v8.4s\n" - "sqadd v17.4s, v17.4s, v9.4s\n" - "sqadd v18.4s, v18.4s, v14.4s\n" - "sqadd v19.4s, v19.4s, v15.4s\n" -#endif - // At this point we have reduced the problem of correctly implementing - // rounding divide-by-power-of-two, to what the SRSHL instruction can - // do. 
- "srshl v16.4s, v16.4s, v12.4s\n" - "srshl v17.4s, v17.4s, v12.4s\n" - "srshl v18.4s, v18.4s, v12.4s\n" - "srshl v19.4s, v19.4s, v12.4s\n" - - "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT16) "\n" - "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT16) "f\n" - "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT8) "\n" - "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT8) "f\n" - - RUY_STR(RUY_ASM_LABEL_STORE_UINT8) ":\n" - - // Cast-and-saturate from int32 to int16 - "sqxtn v16.4h, v16.4s\n" - "sqxtn2 v16.8h, v17.4s\n" - "sqxtn v17.4h, v18.4s\n" - "sqxtn2 v17.8h, v19.4s\n" - - // At this point, v18 -- v31 aren't used anymore for the current block, - // so we can start clearing these accumulators for the next block - // (next iteration of the main loop). - RUY_MAKE_ZERO(v18) - RUY_MAKE_ZERO(v19) - RUY_MAKE_ZERO(v20) - RUY_MAKE_ZERO(v21) - RUY_MAKE_ZERO(v22) - RUY_MAKE_ZERO(v23) - RUY_MAKE_ZERO(v24) - RUY_MAKE_ZERO(v25) - RUY_MAKE_ZERO(v26) - RUY_MAKE_ZERO(v27) - RUY_MAKE_ZERO(v28) - RUY_MAKE_ZERO(v29) - RUY_MAKE_ZERO(v30) - RUY_MAKE_ZERO(v31) - - // Add the destination zero point - "dup v14.8h, v13.h[4]\n" - "add v16.8h, v16.8h, v14.8h\n" - "add v17.8h, v17.8h, v14.8h\n" - - // Cast-and-saturate from int16 to uint8 - "sqxtun v16.8b, v16.8h\n" - "sqxtun2 v16.16b, v17.8h\n" - - // Load the clamp_min, clamp_max bounds - "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" - "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" - "dup v14.16b, w2\n" // clamp_min - "dup v15.16b, w3\n" // clamp_max - - // Apply the clamp_min bound - "umax v16.16b, v16.16b, v14.16b\n" - // Apply the clamp_max bound - "umin v16.16b, v16.16b, v15.16b\n" - - // Compute how much of the 4x4 block of destination 8bit values that - // we have computed, fit in the destination matrix. Typically, all of - // it fits, but when the destination matrix shape is not a multiple - // of 4x4, there are some 4x4 blocks along the boundaries that do - // not fit entirely. - "sub w1, %w[dst_rows], %w[row]\n" - "sub w2, %w[dst_cols], %w[col]\n" - "mov w3, #4\n" - "cmp w1, #4\n" - // Compute w1 = how many rows of the 4x4 block fit - "csel w1, w1, w3, le\n" - "cmp w2, #4\n" - // Compute w2 = how many cols of the 4x4 block fit - "csel w2, w2, w3, le\n" - - // Test if w1==4 && w2 == 4, i.e. if all of the 4x4 block fits. - "cmp w1, w3\n" - "ccmp w2, w3, 0, eq\n" - "mov x4, %[dst_ptr]\n" - // Yes, all of the 4x4 block fits, go to fast path. - "beq 30f\n" - // Not all of the 4x4 block fits. - // Store to dst_tmp_buf - "st1 {v16.16b}, [%[dst_tmp_buf]]\n" - // Slow loop copying from dst_tmp_buf to dst. - "mov x3, %[dst_tmp_buf]\n" - "mov w6, #0\n" - "50:\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov w5, #0\n" - "51:\n" - "ldrb w7, [x3, w5, uxtw]\n" - "strb w7, [x4, w5, uxtw]\n" - "add w5, w5, #1\n" - "cmp w5, w1\n" - "blt 51b\n" - "add w6, w6, #1\n" - "add x3, x3, #4\n" - "add x4, x4, x11\n" - "cmp w6, w2\n" - "blt 50b\n" - "b 31f\n" - "30:\n" - // Yes, all of the 4x4 block fits. 
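The lane-by-lane stores that follow write out these 16 bytes. Per destination value, the narrowing just performed (sqxtn to int16, zero-point addition in int16, sqxtun to uint8, then the umax/umin clamp) corresponds to this scalar sketch; the function name is illustrative.

#include <algorithm>
#include <cstdint>

inline std::uint8_t DownquantizeToUint8(std::int32_t acc, std::int32_t dst_zero_point,
                                        std::uint8_t clamp_min, std::uint8_t clamp_max) {
  // sqxtn: saturating cast int32 -> int16.
  std::int32_t v = std::min<std::int32_t>(std::max<std::int32_t>(acc, -32768), 32767);
  // add .8h: plain (wrapping) int16 addition of the destination zero point.
  v = static_cast<std::int16_t>(v + dst_zero_point);
  // sqxtun: saturating cast int16 -> uint8.
  v = std::min<std::int32_t>(std::max<std::int32_t>(v, 0), 255);
  // umax/umin: clamp to the requested output range.
  v = std::min<std::int32_t>(std::max<std::int32_t>(v, clamp_min), clamp_max);
  return static_cast<std::uint8_t>(v);
}

The int8 and int16 store paths further below differ only in the final saturating type and in the clamp bounds.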
- RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v16.b}[0], [x3], #1\n" - "add x4, x4, x11\n" - "st1 {v16.b}[1], [x3], #1\n" - "st1 {v16.b}[2], [x3], #1\n" - "st1 {v16.b}[3], [x3], #1\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v16.b}[4], [x3], #1\n" - "add x4, x4, x11\n" - "st1 {v16.b}[5], [x3], #1\n" - "st1 {v16.b}[6], [x3], #1\n" - "st1 {v16.b}[7], [x3], #1\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v16.b}[8], [x3], #1\n" - "add x4, x4, x11\n" - "st1 {v16.b}[9], [x3], #1\n" - "st1 {v16.b}[10], [x3], #1\n" - "st1 {v16.b}[11], [x3], #1\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v16.b}[12], [x3], #1\n" - "add x4, x4, x11\n" - "st1 {v16.b}[13], [x3], #1\n" - "st1 {v16.b}[14], [x3], #1\n" - "st1 {v16.b}[15], [x3], #1\n" - "31:\n" - - "add %[dst_ptr], %[dst_ptr], #4\n" - - RUY_MAKE_ZERO(v16) - RUY_MAKE_ZERO(v17) - - "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" - - RUY_STR(RUY_ASM_LABEL_STORE_INT8) ":\n" - - // Cast-and-saturate from int32 to int16 - "sqxtn v16.4h, v16.4s\n" - "sqxtn2 v16.8h, v17.4s\n" - "sqxtn v17.4h, v18.4s\n" - "sqxtn2 v17.8h, v19.4s\n" - - // At this point, v18 -- v31 aren't used anymore for the current block, - // so we can start clearing these accumulators for the next block - // (next iteration of the main loop). - RUY_MAKE_ZERO(v18) - RUY_MAKE_ZERO(v19) - RUY_MAKE_ZERO(v20) - RUY_MAKE_ZERO(v21) - RUY_MAKE_ZERO(v22) - RUY_MAKE_ZERO(v23) - RUY_MAKE_ZERO(v24) - RUY_MAKE_ZERO(v25) - RUY_MAKE_ZERO(v26) - RUY_MAKE_ZERO(v27) - RUY_MAKE_ZERO(v28) - RUY_MAKE_ZERO(v29) - RUY_MAKE_ZERO(v30) - RUY_MAKE_ZERO(v31) - - // Add the destination zero point - "dup v14.8h, v13.h[4]\n" - "add v16.8h, v16.8h, v14.8h\n" - "add v17.8h, v17.8h, v14.8h\n" - - // Cast-and-saturate from int16 to int8 - "sqxtn v16.8b, v16.8h\n" - "sqxtn2 v16.16b, v17.8h\n" - - // Load the clamp_min, clamp_max bounds - "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" - "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" - "dup v14.16b, w2\n" // clamp_min - "dup v15.16b, w3\n" // clamp_max - - // Apply the clamp_min bound - "smax v16.16b, v16.16b, v14.16b\n" - // Apply the clamp_max bound - "smin v16.16b, v16.16b, v15.16b\n" - - // Compute how much of the 4x4 block of destination 8bit values that - // we have computed, fit in the destination matrix. Typically, all of - // it fits, but when the destination matrix shape is not a multiple - // of 4x4, there are some 4x4 blocks along the boundaries that do - // not fit entirely. - "sub w1, %w[dst_rows], %w[row]\n" - "sub w2, %w[dst_cols], %w[col]\n" - "mov w3, #4\n" - "cmp w1, #4\n" - // Compute w1 = how many rows of the 4x4 block fit - "csel w1, w1, w3, le\n" - "cmp w2, #4\n" - // Compute w2 = how many cols of the 4x4 block fit - "csel w2, w2, w3, le\n" - - // Test if w1==4 && w2 == 4, i.e. if all of the 4x4 block fits. - "cmp w1, w3\n" - "ccmp w2, w3, 0, eq\n" - "mov x4, %[dst_ptr]\n" - // Yes, all of the 4x4 block fits, go to fast path. - "beq 30f\n" - // Not all of the 4x4 block fits. - // Store to dst_tmp_buf - "st1 {v16.16b}, [%[dst_tmp_buf]]\n" - // Slow loop copying from dst_tmp_buf to dst. 
- "mov x3, %[dst_tmp_buf]\n" - "mov w6, #0\n" - "50:\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov w5, #0\n" - "51:\n" - "ldrb w7, [x3, w5, uxtw]\n" - "strb w7, [x4, w5, uxtw]\n" - "add w5, w5, #1\n" - "cmp w5, w1\n" - "blt 51b\n" - "add w6, w6, #1\n" - "add x3, x3, #4\n" - "add x4, x4, x11\n" - "cmp w6, w2\n" - "blt 50b\n" - "b 31f\n" - "30:\n" - // Yes, all of the 4x4 block fits. - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v16.b}[0], [x3], #1\n" - "add x4, x4, x11\n" - "st1 {v16.b}[1], [x3], #1\n" - "st1 {v16.b}[2], [x3], #1\n" - "st1 {v16.b}[3], [x3], #1\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v16.b}[4], [x3], #1\n" - "add x4, x4, x11\n" - "st1 {v16.b}[5], [x3], #1\n" - "st1 {v16.b}[6], [x3], #1\n" - "st1 {v16.b}[7], [x3], #1\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v16.b}[8], [x3], #1\n" - "add x4, x4, x11\n" - "st1 {v16.b}[9], [x3], #1\n" - "st1 {v16.b}[10], [x3], #1\n" - "st1 {v16.b}[11], [x3], #1\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v16.b}[12], [x3], #1\n" - "add x4, x4, x11\n" - "st1 {v16.b}[13], [x3], #1\n" - "st1 {v16.b}[14], [x3], #1\n" - "st1 {v16.b}[15], [x3], #1\n" - "31:\n" - - "add %[dst_ptr], %[dst_ptr], #4\n" - - RUY_MAKE_ZERO(v16) - RUY_MAKE_ZERO(v17) - - "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" - - RUY_STR(RUY_ASM_LABEL_STORE_INT16) ":\n" - - // Add the destination zero point - "dup v14.4h, v13.h[4]\n" - "saddw v16.4s, v16.4s, v14.4h\n" - "saddw v17.4s, v17.4s, v14.4h\n" - "saddw v18.4s, v18.4s, v14.4h\n" - "saddw v19.4s, v19.4s, v14.4h\n" - - // Cast-and-saturate from int32 to int16 - "sqxtn v16.4h, v16.4s\n" - "sqxtn2 v16.8h, v17.4s\n" - "sqxtn v17.4h, v18.4s\n" - "sqxtn2 v17.8h, v19.4s\n" - - // At this point, v18 -- v31 aren't used anymore for the current block, - // so we can start clearing these accumulators for the next block - // (next iteration of the main loop). - RUY_MAKE_ZERO(v18) - RUY_MAKE_ZERO(v19) - RUY_MAKE_ZERO(v20) - RUY_MAKE_ZERO(v21) - RUY_MAKE_ZERO(v22) - RUY_MAKE_ZERO(v23) - RUY_MAKE_ZERO(v24) - RUY_MAKE_ZERO(v25) - RUY_MAKE_ZERO(v26) - RUY_MAKE_ZERO(v27) - RUY_MAKE_ZERO(v28) - RUY_MAKE_ZERO(v29) - RUY_MAKE_ZERO(v30) - RUY_MAKE_ZERO(v31) - - // Load the clamp_min, clamp_max bounds - "ldrh w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" - "ldrh w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" - "dup v14.8h, w2\n" // clamp_min - "dup v15.8h, w3\n" // clamp_max - - // Apply the clamp_min bound - "smax v16.8h, v16.8h, v14.8h\n" - "smax v17.8h, v17.8h, v14.8h\n" - // Apply the clamp_max bound - "smin v16.8h, v16.8h, v15.8h\n" - "smin v17.8h, v17.8h, v15.8h\n" - - // Compute how much of the 4x4 block of destination 8bit values that - // we have computed, fit in the destination matrix. Typically, all of - // it fits, but when the destination matrix shape is not a multiple - // of 4x4, there are some 4x4 blocks along the boundaries that do - // not fit entirely. - "sub w1, %w[dst_rows], %w[row]\n" - "sub w2, %w[dst_cols], %w[col]\n" - "mov w3, #4\n" - "cmp w1, #4\n" - // Compute w1 = how many rows of the 4x4 block fit - "csel w1, w1, w3, le\n" - "cmp w2, #4\n" - // Compute w2 = how many cols of the 4x4 block fit - "csel w2, w2, w3, le\n" - - // Test if w1==4 && w2 == 4, i.e. if all of the 8x8 block fits. - "cmp w1, w3\n" - "ccmp w2, w3, 0, eq\n" - "mov x4, %[dst_ptr]\n" - // Yes, all of the 4x4 block fits, go to fast path. - "beq 30f\n" - // Not all of the 4x4 block fits. 
- // Store to dst_tmp_buf - "str q16, [%[dst_tmp_buf], #0]\n" - "str q17, [%[dst_tmp_buf], #16]\n" - // Slow loop copying from dst_tmp_buf to dst. - "mov x3, %[dst_tmp_buf]\n" - "mov w6, #0\n" - "50:\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov w5, #0\n" - "51:\n" - "ldrh w7, [x3, x5, lsl #1]\n" - "strh w7, [x4, x5, lsl #1]\n" - "add w5, w5, #1\n" - "cmp w5, w1\n" - "blt 51b\n" - "add w6, w6, #1\n" - "add x3, x3, #8\n" - "add x4, x4, x11\n" - "cmp w6, w2\n" - "blt 50b\n" - "b 31f\n" - "30:\n" - // Yes, all of the 4x4 block fits. - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v16.h}[0], [x3], #2\n" - "add x4, x4, x11\n" - "st1 {v16.h}[1], [x3], #2\n" - "st1 {v16.h}[2], [x3], #2\n" - "st1 {v16.h}[3], [x3], #2\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v16.h}[4], [x3], #2\n" - "add x4, x4, x11\n" - "st1 {v16.h}[5], [x3], #2\n" - "st1 {v16.h}[6], [x3], #2\n" - "st1 {v16.h}[7], [x3], #2\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v17.h}[0], [x3], #2\n" - "add x4, x4, x11\n" - "st1 {v17.h}[1], [x3], #2\n" - "st1 {v17.h}[2], [x3], #2\n" - "st1 {v17.h}[3], [x3], #2\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v17.h}[4], [x3], #2\n" - "add x4, x4, x11\n" - "st1 {v17.h}[5], [x3], #2\n" - "st1 {v17.h}[6], [x3], #2\n" - "st1 {v17.h}[7], [x3], #2\n" - "31:\n" - - "add %[dst_ptr], %[dst_ptr], #8\n" - - RUY_MAKE_ZERO(v16) - RUY_MAKE_ZERO(v17) - - "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" - - RUY_STR(RUY_ASM_LABEL_STORE_INT32) ":\n" - - // Since the store type is the same as the accum type, no need for - // downcast. There's also no need for clamp by min/max. - - // At this point, v20 -- v31 aren't used anymore for the current block, - // so we can start clearing these accumulators for the next block - // (next iteration of the main loop). - RUY_MAKE_ZERO(v20) - RUY_MAKE_ZERO(v21) - RUY_MAKE_ZERO(v22) - RUY_MAKE_ZERO(v23) - RUY_MAKE_ZERO(v24) - RUY_MAKE_ZERO(v25) - RUY_MAKE_ZERO(v26) - RUY_MAKE_ZERO(v27) - RUY_MAKE_ZERO(v28) - RUY_MAKE_ZERO(v29) - RUY_MAKE_ZERO(v30) - RUY_MAKE_ZERO(v31) - - // Compute how much of the 4x4 block of destination 8bit values that - // we have computed, fit in the destination matrix. Typically, all of - // it fits, but when the destination matrix shape is not a multiple - // of 4x4, there are some 4x4 blocks along the boundaries that do - // not fit entirely. - "sub w1, %w[dst_rows], %w[row]\n" - "sub w2, %w[dst_cols], %w[col]\n" - "mov w3, #4\n" - "cmp w1, #4\n" - // Compute w1 = how many rows of the 4x4 block fit - "csel w1, w1, w3, le\n" - "cmp w2, #4\n" - // Compute w2 = how many cols of the 4x4 block fit - "csel w2, w2, w3, le\n" - - // Test if w1==4 && w2 == 4, i.e. if all of the 8x8 block fits. - "cmp w1, w3\n" - "ccmp w2, w3, 0, eq\n" - "mov x4, %[dst_ptr]\n" - // Yes, all of the 4x4 block fits, go to fast path. - "beq 30f\n" - // Not all of the 4x4 block fits. - // Store to dst_tmp_buf - "str q16, [%[dst_tmp_buf], #0]\n" - "str q17, [%[dst_tmp_buf], #16]\n" - "str q18, [%[dst_tmp_buf], #32]\n" - "str q19, [%[dst_tmp_buf], #48]\n" - // Slow loop copying from dst_tmp_buf to dst. 
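The copy loop that follows is one more instance of the edge-handling pattern shared by all the store paths above: compute how much of the 4x4 block actually fits in the destination, store it with direct strided stores when all of it fits, and otherwise spill the whole block to dst_tmp_buf and copy only the part that fits. In C++ terms the pattern is roughly the sketch below (the function name is illustrative; the element size is 1, 2 or 4 bytes depending on the store path).

#include <algorithm>
#include <cstring>

// Edge handling for one 4x4 destination block, column-major and tightly packed
// in 'block'; dst_stride is in bytes, as held in x11.
inline void StoreBlockWithEdgeHandling(const void* block, void* dst_ptr, int dst_stride,
                                       int elem_size, int dst_rows, int dst_cols,
                                       int row, int col, void* dst_tmp_buf) {
  const int rows_fit = std::min(4, dst_rows - row);
  const int cols_fit = std::min(4, dst_cols - col);
  if (rows_fit == 4 && cols_fit == 4) {
    // Fast path (label 30 above): store each whole column at its destination.
    for (int c = 0; c < 4; ++c) {
      std::memcpy(static_cast<char*>(dst_ptr) + c * dst_stride,
                  static_cast<const char*>(block) + c * 4 * elem_size, 4 * elem_size);
    }
  } else {
    // Slow path (labels 50/51 above): park the block in dst_tmp_buf, then copy
    // only the rows_fit x cols_fit part that exists in the destination.
    std::memcpy(dst_tmp_buf, block, 4 * 4 * elem_size);
    for (int c = 0; c < cols_fit; ++c) {
      std::memcpy(static_cast<char*>(dst_ptr) + c * dst_stride,
                  static_cast<const char*>(dst_tmp_buf) + c * 4 * elem_size,
                  rows_fit * elem_size);
    }
  }
}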
- "mov x3, %[dst_tmp_buf]\n" - "mov w6, #0\n" - "50:\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov w5, #0\n" - "51:\n" - "ldr w7, [x3, x5, lsl #2]\n" - "str w7, [x4, x5, lsl #2]\n" - "add w5, w5, #1\n" - "cmp w5, w1\n" - "blt 51b\n" - "add w6, w6, #1\n" - "add x3, x3, #16\n" - "add x4, x4, x11\n" - "cmp w6, w2\n" - "blt 50b\n" - "b 31f\n" - "30:\n" - // Yes, all of the 4x4 block fits. - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v16.s}[0], [x3], #4\n" - "add x4, x4, x11\n" - "st1 {v16.s}[1], [x3], #4\n" - "st1 {v16.s}[2], [x3], #4\n" - "st1 {v16.s}[3], [x3], #4\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v17.s}[0], [x3], #4\n" - "add x4, x4, x11\n" - "st1 {v17.s}[1], [x3], #4\n" - "st1 {v17.s}[2], [x3], #4\n" - "st1 {v17.s}[3], [x3], #4\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v18.s}[0], [x3], #4\n" - "add x4, x4, x11\n" - "st1 {v18.s}[1], [x3], #4\n" - "st1 {v18.s}[2], [x3], #4\n" - "st1 {v18.s}[3], [x3], #4\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v19.s}[0], [x3], #4\n" - "add x4, x4, x11\n" - "st1 {v19.s}[1], [x3], #4\n" - "st1 {v19.s}[2], [x3], #4\n" - "st1 {v19.s}[3], [x3], #4\n" - "31:\n" - - "add %[dst_ptr], %[dst_ptr], #16\n" - - RUY_MAKE_ZERO(v16) - RUY_MAKE_ZERO(v17) - RUY_MAKE_ZERO(v18) - RUY_MAKE_ZERO(v19) - - RUY_STR(RUY_ASM_LABEL_AFTER_STORE) ":\n" - - // For the next block: perform the first few multiply-adds on the data - // that we have already loaded. - "smull v8.8h, v0.8b, v4.8b\n" - "smull v9.8h, v1.8b, v4.8b\n" - "smull v10.8h, v2.8b, v4.8b\n" - "smull v11.8h, v3.8b, v4.8b\n" - "smull v12.8h, v0.8b, v5.8b\n" - "smull v13.8h, v1.8b, v5.8b\n" - "smull v14.8h, v2.8b, v5.8b\n" - "smull v15.8h, v3.8b, v5.8b\n" - "smlal2 v8.8h, v0.16b, v4.16b\n" - "smlal2 v9.8h, v1.16b, v4.16b\n" - "smlal2 v10.8h, v2.16b, v4.16b\n" - "smlal2 v11.8h, v3.16b, v4.16b\n" - "smlal2 v12.8h, v0.16b, v5.16b\n" - "smlal2 v13.8h, v1.16b, v5.16b\n" - "smlal2 v14.8h, v2.16b, v5.16b\n" - "smlal2 v15.8h, v3.16b, v5.16b\n" - - // Reload some params --- we had used x5 -- x7 for a few other things - // since the last time we had loaded them. - "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" - "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" - "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" - - // Move to the next block of the destination matrix, for the next iter - // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already - // been updated earlier. - // Have we reached the end row? - "cmp %w[row], w7\n" - "beq 20f\n" // yes, end row. - // Not end row. Move to the next row. - "add %w[row], %w[row], #4\n" - "b 21f\n" - "20:\n" - // Was already at end row. - "mov %w[row], w6\n" // Move back to first row. - "add %w[col], %w[col], #4\n" // Move to the next column. - "add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #2\n" - "mov %[dst_ptr], %[dst_col_ptr]\n" - "21:\n" - - // Main loop exit condition: have we hit the end column? - "cmp %w[col], w8\n" - - // w1 is the number of levels of depth that we have already loaded - // LHS and RHS data for. Corresponding to the initial ld1 instructions - // above, this is currently 4. 
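For orientation, the bookkeeping after RUY_ASM_LABEL_AFTER_STORE walks the destination in 4x4 blocks, rows innermost and block columns outermost, with dst_col_ptr advancing by four destination columns each time the row index wraps back to start_row. A sketch of that traversal order (the helper name is illustrative; start/last values are block-aligned, as the params guarantee):

#include <functional>

inline void ForEachDestinationBlock(int start_row, int last_row, int start_col, int last_col,
                                    const std::function<void(int, int)>& process_block) {
  // The main loop above keeps going while col <= last_col ('ble 1b' on the
  // comparison against w8); within one block column, row advances by 4 until
  // it reaches last_row, then wraps back to start_row and col advances by 4.
  for (int col = start_col; col <= last_col; col += 4) {
    for (int row = start_row; row <= last_row; row += 4) {
      process_block(row, col);
    }
  }
}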
- "mov w1, #16\n"
-
- "ble 1b\n"
-
- // clang-format on
-
- : [ lhs_col_ptr ] "+r"(lhs_col_ptr), [rhs_col_ptr] "+r"(rhs_col_ptr),
- [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
- [dst_col_ptr] "+r"(dst_col_ptr), [dst_ptr] "+r"(dst_ptr), [row] "+r"(row), [col] "+r"(col)
- : [ params ] "r"(&params), [dst_rows] "r"(params.dst_rows),
- [dst_cols] "r"(params.dst_cols), [dst_tmp_buf] "r"(params.dst_tmp_buf),
- [dst_type_id] "r"(params.dst_type_id)
- : "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc",
- "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12",
- "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25",
- "v26", "v27", "v28", "v29", "v30", "v31");
-}
-
-// Similar to existing Kernel8bitNeonOutOfOrder but specialized for the case of
-// RHS cols == 1.
-// Relevant target CPUs for this kernel include ARM Cortex-A73 and Cortex-A75,
-// since these are 64-bit, out-of-order and without dotprod support.
-void Kernel8bitNeonOutOfOrder1Col(const KernelParams8bit<4, 4>& params) {
- profiler::ScopeLabel label(
- "Kernel (kNeon, optimized for out-of-order cores)");
-
- CheckOffsetsInKernelParams8bit(params);
-
- const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
- const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
- const std::int8_t* lhs_ptr = lhs_col_ptr;
- const std::int8_t* rhs_ptr = rhs_col_ptr;
- void* dst_col_ptr = params.dst_base_ptr;
- void* dst_ptr = dst_col_ptr;
- int row = params.start_row;
- int col = params.start_col;
-
- // The asm kernel below has the following NEON register allocation:
- //
- // v16 -- v19 are int32 accumulators.
- // During accumulation, v0 -- v3 are used to load int8 data from LHS and
- // v4 from RHS:
- //
- // int8 RHS 16x1 block
- // /-----------\
- // |v4.b[0] |
- // | ... |
- // |v4.b[15] |
- // \-----------/
- // int8 LHS 4x16 block
- // /---------------------\ /-----------\
- // |v0.b[0] ... v0.b[15] | |v16.4s |
- // |v1.b[0] ... v1.b[15] | |v17.4s |
- // |v2.b[0] ... v2.b[15] | |v18.4s |
- // |v3.b[0] ... v3.b[15] | |v19.4s |
- // \---------------------/ \-----------/
- // int32 accumulators 4x1 block
- //
- // No attempt had been made so far at implementing the RUY_OPT_MAX_STREAMING
- // optimization for this kernel.
- asm volatile(
-#define RUY_MAKE_ZERO(reg) "dup " #reg ".4s, wzr\n"
-
- // clang-format off
-
- // Load some parameters into registers.
- "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
- "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
- "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
- "ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n"
- "ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n"
- "ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n"
- "ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
- "ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n"
-
- // Load the first 64 bytes of LHS and RHS data.
- "ld1 {v0.16b}, [%[lhs_ptr]], #16\n"
- "ld1 {v1.16b}, [%[lhs_ptr]], #16\n"
- "ld1 {v2.16b}, [%[lhs_ptr]], #16\n"
- "ld1 {v3.16b}, [%[lhs_ptr]], #16\n"
- "ld1 {v4.16b}, [%[rhs_ptr]], #16\n"
- "add %[rhs_ptr], %[rhs_ptr], #48\n"
-
- // Clear accumulators.
- RUY_MAKE_ZERO(v16)
- RUY_MAKE_ZERO(v17)
- RUY_MAKE_ZERO(v18)
- RUY_MAKE_ZERO(v19)
-
- // w1 is the number of levels of depth that we have already loaded
- // LHS and RHS data for. Corresponding to the initial ld1 instructions
- // above, this is currently 16.
- "mov w1, #16\n" - - // Perform the first few multiply-adds on the data that we have already - // loaded. - "smull v8.8h, v0.8b, v4.8b\n" - "smull v9.8h, v1.8b, v4.8b\n" - "smull v10.8h, v2.8b, v4.8b\n" - "smull v11.8h, v3.8b, v4.8b\n" - - // Multiply-accumulate second-half, again into the same - // 16bit local accumulator registers. This is where we - // take advantage of having int8 instead of uint8 and therefore - // being able to accumulate two products into int16. - "smlal2 v8.8h, v0.16b, v4.16b\n" - "smlal2 v9.8h, v1.16b, v4.16b\n" - "smlal2 v10.8h, v2.16b, v4.16b\n" - "smlal2 v11.8h, v3.16b, v4.16b\n" - - // Main loop of the whole GEMM, over rows and columns of the - // destination matrix. - "1:\n" - - // Reminder - w1 is how many levels of depth we have already loaded - // data for, w12 is the total depth. - "cmp w1, w12\n" - "beq 79f\n" - - "2:\n" - - // Some multiplications and 16-bit accumulation were already done above, - // so we start right away in the middle. - "sadalp v16.4s, v8.8h\n" - "ld1 {v4.16b}, [%[rhs_ptr]], #16\n" - "add %[rhs_ptr], %[rhs_ptr], #48\n" - "sadalp v17.4s, v9.8h\n" - "sadalp v18.4s, v10.8h\n" - "sadalp v19.4s, v11.8h\n" - - "ld1 {v0.16b}, [%[lhs_ptr]], #16\n" - "ld1 {v1.16b}, [%[lhs_ptr]], #16\n" - "ld1 {v2.16b}, [%[lhs_ptr]], #16\n" - "ld1 {v3.16b}, [%[lhs_ptr]], #16\n" - - "smull v8.8h, v0.8b, v4.8b\n" - "smull v9.8h, v1.8b, v4.8b\n" - "smull v10.8h, v2.8b, v4.8b\n" - "smull v11.8h, v3.8b, v4.8b\n" - - // Multiply-accumulate second-half, again into the same - // 16bit local accumulator registers. This is where we - // take advantage of having int8 instead of uint8 and therefore - // being able to accumulate two products into int16. - "smlal2 v8.8h, v0.16b, v4.16b\n" - "smlal2 v9.8h, v1.16b, v4.16b\n" - "smlal2 v10.8h, v2.16b, v4.16b\n" - "smlal2 v11.8h, v3.16b, v4.16b\n" - - // Each iteration of this loop advances by 16 levels of depth. - "add w1, w1, #16\n" - - // Loop termination condition - "cmp w1, w12\n" - - "blt 2b\n" - - "79:\n" - - "sadalp v16.4s, v8.8h\n" - "sadalp v17.4s, v9.8h\n" - "sadalp v18.4s, v10.8h\n" - "sadalp v19.4s, v11.8h\n" - - // End of accumulation. The registers v16 -- v19 contain the final - // int32 accumulator values of the current 4x1 destination block. - // We now have to compute the final 8-bit values from these int32 - // accumulators, and advance to the next 4x1 block. We intertwine - // these two aspects whenever possible for optimal pipelining, both - // at the data flow level (prefetch data for next block as early as - // possible) and instruction pipelining level (some of the next-block - // work can dual-issue with some of the final work on the current - // block). - - // Reduce 32bit accumulators horizontally. - "addp v16.4s, v16.4s, v17.4s\n" - "addp v18.4s, v18.4s, v19.4s\n" - - // Reduce 32bit accumulators horizontally, second pass - // (each pass adds pairwise. we need to add 4-wise). - "addp v16.4s, v16.4s, v18.4s\n" - - // Logic to advance to the next block in preparation for the next - // iteration of the main loop. For now, we only want to compute - // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are - // not yet ready to update the values of row and col, as we still need - // the current values for the rest of the work on the current block. - - "cmp %w[row], w7\n" // Have we finished the last row? - "bge 4f\n" // If finished last row, go to 4 - // Not finished last row: then advance to next row. 
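Numerically, the loop above computes a plain int8 dot product for each of the four destination rows, 16 levels of depth per iteration; the two addp passes then fold the four int32 lanes of each accumulator into the single per-row total. A scalar model of the raw accumulation (packed-buffer layout details are abstracted away; names are illustrative):

#include <cstdint>

// Raw int32 accumulators of the current 4x1 block, before any zero-point or
// multiplier work: acc[r] = sum over d of lhs[r][d] * rhs[d].
inline void Accumulate4x1(const std::int8_t* lhs_block /* 4 x depth, row-major */,
                          const std::int8_t* rhs_col /* depth values */,
                          int depth, std::int32_t acc[4]) {
  for (int r = 0; r < 4; ++r) {
    acc[r] = 0;
    for (int d = 0; d < depth; ++d) {
      acc[r] += static_cast<std::int32_t>(lhs_block[r * depth + d]) *
                static_cast<std::int32_t>(rhs_col[d]);
    }
  }
}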
- "add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #2\n" - "b 5f\n" - "4:\n" // Finished last row... - "mov %[lhs_col_ptr], x5\n" // Go back to first row - // Now we need to advance to the next column. If we already - // finished the last column, then in principle we are done, however - // we can't just return here, as we need to allow the end work of the - // current block to complete. The good news is that at this point it - // doesn't matter what data we load for the next column, since - // we will exit from the main loop below before actually storing - // anything computed from that data. - "cmp %w[col], w8\n" // Have we finished the last column? - "bge 5f\n" // If yes, just carry on without updating the column pointer. - // Not finished last column: then advance to next column. - // (still multiply column stride by 4 due to packing) - "add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #2\n" - "5:\n" - - // Set the LHS and RHS data pointers to the start of the columns just - // computed. - "mov %[lhs_ptr], %[lhs_col_ptr]\n" - "mov %[rhs_ptr], %[rhs_col_ptr]\n" - - // Load some parameters needed for the end work on current block. - RUY_MAKE_ZERO(v8) - "ldr w4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n" - "ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_PROD_ZP_DEPTH) "]\n" - "ins v13.h[4], w4\n" // dst_zero_point - "ldr x4, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n" - "ldrb w6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" - "dup v9.4s, w3\n" // create prod_zp_depth_vec - "add x5, x4, %x[row], lsl #2\n" - "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n" - "csel x4, x4, x5, eq\n" - - "ld1 {v15.4s}, [x4]\n" // multiplier_fixedpoint - - // Now we load: bias data, LHS sums data, RHS sums data. - - // First, load the base pointers from the params. - "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n" - - "add x5, x1, %x[row], lsl #2\n" - "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n" - "csel x1, x1, x5, eq\n" - - // Load 4 bias values. - "ld1 {v14.4s}, [x1]\n" - - // Now that we know what LHS and RHS data the next iteration of the - // main loop will need to load, we start loading the first 32 bytes of - // each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore - // in the rest of the work on the current block. - "ld1 {v0.16b}, [%[lhs_ptr]], #16\n" - "ld1 {v1.16b}, [%[lhs_ptr]], #16\n" - "ld1 {v2.16b}, [%[lhs_ptr]], #16\n" - "ld1 {v3.16b}, [%[lhs_ptr]], #16\n" - "ld1 {v4.16b}, [%[rhs_ptr]], #16\n" - "add %[rhs_ptr], %[rhs_ptr], #48\n" - - // Add to the bias values the product (depth * lhs_zero_point * rhs_zero_point), - // See the term NZ1Z2 in equation (7) in https://arxiv.org/pdf/1712.05877.pdf - "add v14.4s, v14.4s, v9.4s\n" - - // Perform the bias-addition (per the above, we have just folded into - // the bias the (depth * lhs_zero_point * rhs_zero_point) term.) 
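The additions and mls/sub that follow, together with the bias load and the prod_zp_depth fold above, apply exactly the corrections of equation (7) in https://arxiv.org/pdf/1712.05877.pdf to each raw accumulator. Per accumulator, in scalar form (a sketch with names mirroring the params fields, not code from this file):

#include <cstdint>

inline std::int32_t ApplyZeroPointCorrections(std::int32_t raw_acc, std::int32_t bias,
                                              int depth, std::int32_t lhs_zero_point,
                                              std::int32_t rhs_zero_point,
                                              std::int32_t lhs_sum_of_row,
                                              std::int32_t rhs_sum_of_col) {
  std::int32_t acc = raw_acc;
  // Bias plus the N*Z1*Z2 term (prod_zp_depth), which the code folds into the bias.
  acc += bias + depth * lhs_zero_point * rhs_zero_point;
  // Subtract rhs_sums * lhs_zero_point (only when RUY_ASM_FLAG_HAS_RHS_SUMS is set).
  acc -= lhs_zero_point * rhs_sum_of_col;
  // Subtract lhs_sums * rhs_zero_point (only when RUY_ASM_FLAG_HAS_LHS_SUMS is set).
  acc -= rhs_zero_point * lhs_sum_of_row;
  return acc;
}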
- // (all four 32-bit accumulators are in v16 at this point) - "add v16.4s, v16.4s, v14.4s\n" - - "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_RHS_SUMS) "\n" - "beq 401f\n" - "ldr x3, [%[params], #" RUY_STR(RUY_OFFSET_RHS_SUMS) "]\n" - "add x3, x3, %x[col], lsl #2\n" - "ld1 {v14.4s}, [x3]\n" - "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_ZERO_POINT) "]\n" - "dup v10.4s, w5\n" // create lhs_zero_point_vec - // Subtract rhs_sums * lhs_zero_point, per - // equation (7) in https://arxiv.org/pdf/1712.05877.pdf - "mls v16.4s, v10.4s, v14.s[0]\n" - "401:\n" - - "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_LHS_SUMS) "\n" - "beq 402f\n" - "ldr x2, [%[params], #" RUY_STR(RUY_OFFSET_LHS_SUMS) "]\n" - "add x2, x2, %x[row], lsl #2\n" - "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_RHS_ZERO_POINT) "]\n" - // Load 4 lhs_sums values. - "ld1 {v11.4s}, [x2]\n" - "ins v13.s[1], w5\n" // rhs_zero_point - // Compute lhs_sums * rhs_zero_point. - "mul v11.4s, v11.4s, v13.s[1]\n" - // Subtract lhs_sums * rhs_zero_point, per - // equation (7) in https://arxiv.org/pdf/1712.05877.pdf - "sub v16.4s, v16.4s, v11.4s\n" - - // If the destination is int32, it means the user asks for the raw - // accumulators, no need for us to downquantize the value. - "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT32) "\n" - "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT32) "f\n" - - "402:\n" - - // At this point we have computed the final int32 values. Now we - // start down-quantizing them to obtain the final 8bit values from them. - - // As part of this down-quantization, our int32 values will be - // multiplied by a multiplier that has a fixed-point component and an - // exponent component. - - //Load the exponent part of the multiplier. - "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_EXPONENT) "]\n" - "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n" - "add x5, x1, %x[row], lsl #2\n" - "csel x1, x1, x5, eq\n" - - "ld1 {v14.4s}, [x1]\n" - - "smax v12.4s, v14.4s, v8.4s\n" - - "sshl v16.4s, v16.4s, v12.4s\n" - - "smin v12.4s, v14.4s, v8.4s\n" - - // Apply the fixed-point part of the multiplier. - "sqrdmulh v16.4s, v16.4s, v15.4s\n" - - // We have some rounding division-by-power-of-two to do. This should - // always use "round to nearest". We allow for some - // freedom in how ties are broken, to strike a good compromise of - // performance on given hardware vs. perfect agreement of results - // across hardware. - // - // When RUY_OPT_NATIVE_ROUNDING is enabled, we allow for implementation - // defined tie-breaks to help performance. On NEON, this means that we - // can just use the NEON rounding instructions, such as srshl. They - // happen to be breaking ties upward. - // - // When RUY_OPT_NATIVE_ROUNDING is disabled, we implement strict - // break-ties-away-from zero, as described in Appendix B of - // https://arxiv.org/pdf/1712.05877.pdf - // When we wrote that, we thought that that would be better unbiased - // than the NEON upwards tie-breaks, and we had observed some - // improvement on some model. However, that is only more unbiased for - // data centered at zero, which was likely the case in that model, - // but is not always the case. If we wanted something more consistently - // unbiased then we should try breaking ties toward-nearest-even. 
-#if !RUY_OPT_ENABLED(RUY_OPT_NATIVE_ROUNDING) - // Fix up values to be right-shifted, so that the (round to nearest, - // break ties upward) behavior of srshl applied to these fixed-up - // values, produces the same result as the desired (round to nearest, - // break ties away from zero) behavior on the original values. - "and v8.16b, v16.16b, v12.16b\n" - "sshr v8.4s, v8.4s, #31\n" - "sqadd v16.4s, v16.4s, v8.4s\n" -#endif - // At this point we have reduced the problem of correctly implementing - // rounding divide-by-power-of-two, to what the SRSHL instruction can - // do. - "srshl v16.4s, v16.4s, v12.4s\n" - - "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT16) "\n" - "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT16) "f\n" - "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT8) "\n" - "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT8) "f\n" - - RUY_STR(RUY_ASM_LABEL_STORE_UINT8) ":\n" - - // Cast-and-saturate from int32 to int16 - // After this instruction, all data is in lower half (64-bits) of v16 - "sqxtn v16.4h, v16.4s\n" - - // At this point, v18 -- v31 aren't used anymore for the current block, - // so we can start clearing these accumulators for the next block - // (next iteration of the main loop). - RUY_MAKE_ZERO(v18) - RUY_MAKE_ZERO(v19) - - // Add the destination zero point - "dup v14.8h, v13.h[4]\n" - "add v16.8h, v16.8h, v14.8h\n" - - // Cast-and-saturate from int16 to uint8 - // Now all data is in the first 32-bits of v16 - "sqxtun v16.8b, v16.8h\n" - - // Load the clamp_min, clamp_max bounds - "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" - "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" - "dup v14.16b, w2\n" // clamp_min - "dup v15.16b, w3\n" // clamp_max - - // Apply the clamp_min bound - "umax v16.16b, v16.16b, v14.16b\n" - // Apply the clamp_max bound - "umin v16.16b, v16.16b, v15.16b\n" - - // Compute how much of the 4x1 block of destination 8bit values that - // we have computed, fit in the destination matrix. Typically, all of - // it fits, but when the destination matrix shape is not a multiple - // of 4x1, there are some 4x1 blocks along the boundaries that do - // not fit entirely. - "sub w1, %w[dst_rows], %w[row]\n" - "mov w3, #4\n" - "cmp w1, #4\n" - // Compute w1 = how many rows of the 4x1 block fit - "csel w1, w1, w3, le\n" - - // Test if w1==4, i.e. if all of the 4x1 block fits. - "cmp w1, w3\n" - - "mov x4, %[dst_ptr]\n" - // Yes, all of the 4x1 block fits, go to fast path. - "beq 30f\n" - // Not all of the 4x1 block fits. - // Store to dst_tmp_buf - "st1 {v16.16b}, [%[dst_tmp_buf]]\n" - // Slow loop copying from dst_tmp_buf to dst. - "mov x3, %[dst_tmp_buf]\n" - "mov w6, #0\n" - "50:\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov w5, #0\n" - "51:\n" - "ldrb w7, [x3, w5, uxtw]\n" - "strb w7, [x4, w5, uxtw]\n" - "add w5, w5, #1\n" - "cmp w5, w1\n" - "blt 51b\n" - "b 31f\n" - "30:\n" - // Yes, all of the 4x1 block fits. - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v16.b}[0], [x3], #1\n" - "st1 {v16.b}[1], [x3], #1\n" - "st1 {v16.b}[2], [x3], #1\n" - "st1 {v16.b}[3], [x3], #1\n" - "31:\n" - - "add %[dst_ptr], %[dst_ptr], #4\n" - - RUY_MAKE_ZERO(v16) - RUY_MAKE_ZERO(v17) - - "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" - - RUY_STR(RUY_ASM_LABEL_STORE_INT8) ":\n" - - // Cast-and-saturate from int32 to int16 - // After this, all values for output are in the lower half (64 bits) of v16. 
- "sqxtn v16.4h, v16.4s\n" - - // At this point, v18 -- v31 aren't used anymore for the current block, - // so we can start clearing these accumulators for the next block - // (next iteration of the main loop). - RUY_MAKE_ZERO(v18) - RUY_MAKE_ZERO(v19) - - // Add the destination zero point - "dup v14.8h, v13.h[4]\n" - "add v16.8h, v16.8h, v14.8h\n" - - // Cast-and-saturate from int16 to int8 - "sqxtn v16.8b, v16.8h\n" - // At this point, we only need 4 lowest 8-bit values in v16. - - // Load the clamp_min, clamp_max bounds - "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" - "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" - "dup v14.16b, w2\n" // clamp_min - "dup v15.16b, w3\n" // clamp_max - - // Apply the clamp_min bound - "smax v16.16b, v16.16b, v14.16b\n" - // Apply the clamp_max bound - "smin v16.16b, v16.16b, v15.16b\n" - - // Compute how much of the 4x4 block of destination 8bit values that - // we have computed, fit in the destination matrix. Typically, all of - // it fits, but when the destination matrix shape is not a multiple - // of 4x4, there are some 4x4 blocks along the boundaries that do - // not fit entirely. - "sub w1, %w[dst_rows], %w[row]\n" - "sub w2, %w[dst_cols], %w[col]\n" - "mov w3, #4\n" - "cmp w1, #4\n" - // Compute w1 = how many rows of the 4x1 block fit - "csel w1, w1, w3, le\n" - "cmp w2, #4\n" - - // Test if w1==4, i.e. if all of the 4x1 block fits. - "cmp w1, w3\n" - "ccmp w2, w3, 0, eq\n" - "mov x4, %[dst_ptr]\n" - // Yes, all of the 4x1 block fits, go to fast path. - "beq 30f\n" - // Not all of the 4x4 block fits. - // Store to dst_tmp_buf - "st1 {v16.16b}, [%[dst_tmp_buf]]\n" - // Slow loop copying from dst_tmp_buf to dst. - "mov x3, %[dst_tmp_buf]\n" - "mov w6, #0\n" - "50:\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov w5, #0\n" - "51:\n" - "ldrb w7, [x3, w5, uxtw]\n" - "strb w7, [x4, w5, uxtw]\n" - "add w5, w5, #1\n" - "cmp w5, w1\n" - "blt 51b\n" - "b 31f\n" - "30:\n" - // Yes, all of the 4x4 block fits. - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v16.b}[0], [x3], #1\n" - "st1 {v16.b}[1], [x3], #1\n" - "st1 {v16.b}[2], [x3], #1\n" - "st1 {v16.b}[3], [x3], #1\n" - "31:\n" - - "add %[dst_ptr], %[dst_ptr], #4\n" - - RUY_MAKE_ZERO(v16) - RUY_MAKE_ZERO(v17) - - "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" - - RUY_STR(RUY_ASM_LABEL_STORE_INT16) ":\n" - - // Add the destination zero point - "dup v14.4h, v13.h[4]\n" - "saddw v16.4s, v16.4s, v14.4h\n" - - // Cast-and-saturate from int32 to int16 - // After this instruction, all data is in lower half of v16. - "sqxtn v16.4h, v16.4s\n" - - // At this point, v18 -- v31 aren't used anymore for the current block, - // so we can start clearing these accumulators for the next block - // (next iteration of the main loop). - RUY_MAKE_ZERO(v18) - RUY_MAKE_ZERO(v19) - - // Load the clamp_min, clamp_max bounds - "ldrh w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" - "ldrh w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" - "dup v14.8h, w2\n" // clamp_min - "dup v15.8h, w3\n" // clamp_max - - // Apply the clamp_min bound - "smax v16.8h, v16.8h, v14.8h\n" - // Apply the clamp_max bound - "smin v16.8h, v16.8h, v15.8h\n" - - // Compute how much of the 4x4 block of destination 8bit values that - // we have computed, fit in the destination matrix. Typically, all of - // it fits, but when the destination matrix shape is not a multiple - // of 4x4, there are some 4x4 blocks along the boundaries that do - // not fit entirely. 
- "sub w1, %w[dst_rows], %w[row]\n" - "sub w2, %w[dst_cols], %w[col]\n" - "mov w3, #4\n" - "cmp w1, #4\n" - // Compute w1 = how many rows of the 4x4 block fit - "csel w1, w1, w3, le\n" - "cmp w2, #4\n" - - // Test if w1==4 && w2 == 4, i.e. if all of the 8x8 block fits. - "cmp w1, w3\n" - "mov x4, %[dst_ptr]\n" - // Yes, all of the 4x4 block fits, go to fast path. - "beq 30f\n" - // Not all of the 4x4 block fits. - // Store to dst_tmp_buf - "str q16, [%[dst_tmp_buf], #0]\n" - // Slow loop copying from dst_tmp_buf to dst. - "mov x3, %[dst_tmp_buf]\n" - "mov w6, #0\n" - "50:\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov w5, #0\n" - "51:\n" - "ldrh w7, [x3, x5, lsl #1]\n" - "strh w7, [x4, x5, lsl #1]\n" - "add w5, w5, #1\n" - "cmp w5, w1\n" - "blt 51b\n" - "blt 50b\n" - "b 31f\n" - "30:\n" - // Yes, all of the 4x4 block fits. - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v16.h}[0], [x3], #2\n" - "st1 {v16.h}[1], [x3], #2\n" - "st1 {v16.h}[2], [x3], #2\n" - "st1 {v16.h}[3], [x3], #2\n" - "31:\n" - - "add %[dst_ptr], %[dst_ptr], #8\n" - - RUY_MAKE_ZERO(v16) - RUY_MAKE_ZERO(v17) - - "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" - - RUY_STR(RUY_ASM_LABEL_STORE_INT32) ":\n" - - // Since the store type is the same as the accum type, no need for - // downcast. There's also no need for clamp by min/max. - - // Compute how much of the 4x4 block of destination 8bit values that - // we have computed, fit in the destination matrix. Typically, all of - // it fits, but when the destination matrix shape is not a multiple - // of 4x4, there are some 4x4 blocks along the boundaries that do - // not fit entirely. - "sub w1, %w[dst_rows], %w[row]\n" - "sub w2, %w[dst_cols], %w[col]\n" - "mov w3, #4\n" - "cmp w1, #4\n" - // Compute w1 = how many rows of the 4x4 block fit - "csel w1, w1, w3, le\n" - "cmp w2, #4\n" - - // Test if w1==4 i.e. if all of the 4x1 block fits. - "cmp w1, w3\n" - "ccmp w2, w3, 0, eq\n" - "mov x4, %[dst_ptr]\n" - // Yes, all of the 4x1 block fits, go to fast path. - "beq 30f\n" - // Not all of the 4x4 block fits. - // Store to dst_tmp_buf - "str q16, [%[dst_tmp_buf], #0]\n" - // Slow loop copying from dst_tmp_buf to dst. - "mov x3, %[dst_tmp_buf]\n" - "mov w6, #0\n" - "50:\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov w5, #0\n" - "51:\n" - "ldr w7, [x3, x5, lsl #2]\n" - "str w7, [x4, x5, lsl #2]\n" - "add w5, w5, #1\n" - "cmp w5, w1\n" - "blt 51b\n" - "b 31f\n" - "30:\n" - // Yes, all of the 4x4 block fits. - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v16.s}[0], [x3], #4\n" - "st1 {v16.s}[1], [x3], #4\n" - "st1 {v16.s}[2], [x3], #4\n" - "st1 {v16.s}[3], [x3], #4\n" - "31:\n" - - "add %[dst_ptr], %[dst_ptr], #16\n" - - RUY_MAKE_ZERO(v16) - RUY_MAKE_ZERO(v17) - RUY_MAKE_ZERO(v18) - RUY_MAKE_ZERO(v19) - - RUY_STR(RUY_ASM_LABEL_AFTER_STORE) ":\n" - - // For the next block: perform the first few multiply-adds on the data - // that we have already loaded. - "smull v8.8h, v0.8b, v4.8b\n" - "smull v9.8h, v1.8b, v4.8b\n" - "smull v10.8h, v2.8b, v4.8b\n" - "smull v11.8h, v3.8b, v4.8b\n" - "smlal2 v8.8h, v0.16b, v4.16b\n" - "smlal2 v9.8h, v1.16b, v4.16b\n" - "smlal2 v10.8h, v2.16b, v4.16b\n" - "smlal2 v11.8h, v3.16b, v4.16b\n" - - // Reload some params --- we had used x5 -- x7 for a few other things - // since the last time we had loaded them. 
- "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
- "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
- "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
-
- // Move to the next block of the destination matrix, for the next iter
- // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already
- // been updated earlier.
- // Have we reached the end row?
- "cmp %w[row], w7\n"
- "beq 20f\n" // yes, end row.
- // Not end row. Move to the next row.
- "add %w[row], %w[row], #4\n"
- "b 21f\n"
- "20:\n"
- // Was already at end row.
- "mov %w[row], w6\n" // Move back to first row.
- "add %w[col], %w[col], #4\n" // Move to the next column.
- "add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #2\n"
- "mov %[dst_ptr], %[dst_col_ptr]\n"
- "21:\n"
-
- // Main loop exit condition: have we hit the end column?
- "cmp %w[col], w8\n"
-
- // w1 is the number of levels of depth that we have already loaded
- // LHS and RHS data for. Corresponding to the initial ld1 instructions
- // above, this is currently 16.
- "mov w1, #16\n"
-
- "ble 1b\n"
-
- // clang-format on
-
- : [ lhs_col_ptr ] "+r"(lhs_col_ptr), [rhs_col_ptr] "+r"(rhs_col_ptr),
- [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
- [dst_col_ptr] "+r"(dst_col_ptr), [dst_ptr] "+r"(dst_ptr), [row] "+r"(row), [col] "+r"(col)
- : [ params ] "r"(&params), [dst_rows] "r"(params.dst_rows),
- [dst_cols] "r"(params.dst_cols), [dst_tmp_buf] "r"(params.dst_tmp_buf),
- [dst_type_id] "r"(params.dst_type_id)
- : "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc",
- "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12",
- "v13", "v14", "v15", "v16", "v17", "v18", "v19");
-}
-
-// Variant of the above Kernel8bitNeonOutOfOrder, tuned for in-order CPUs.
-// Specifically here, the relevant in-order CPUs are ARM Cortex-A53 and
-// the original Cortex-A55, since these are 64-bit and do not support dotprod.
-//
-// While this kernel does not have a direct equivalent in gemmlowp, it was
-// developed based on insights that David Mansell at ARM shared with their
-// contribution of gemmlowp kernels tuned for Cortex-A53, with very helpful
-// comments. Specifically, see this comment about tuning for Cortex-A53:
-// https://github.com/google/gemmlowp/blob/36212ad3651871bc3e9a599f1a6d5324778aea25/standalone/neon-gemm-kernel-benchmark.cc#L4215
-void Kernel8bitNeonInOrder(const KernelParams8bit<4, 4>& params) {
- profiler::ScopeLabel label("Kernel (kNeon, optimized for in-order cores)");
-
- CheckOffsetsInKernelParams8bit(params);
-
- const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
- const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
- const std::int8_t* lhs_ptr = lhs_col_ptr;
- const std::int8_t* rhs_ptr = rhs_col_ptr;
- void* dst_col_ptr = params.dst_base_ptr;
- void* dst_ptr = dst_col_ptr;
- int row = params.start_row;
- int col = params.start_col;
-
- // The asm kernel below has the following NEON register allocation:
- //
- // v16 -- v31 are int32 accumulators.
- // During accumulation, v0 -- v3 are used to load int8 data from LHS and
- // v4 -- v7 from RHS:
- //
- // int8 RHS 16x4 block
- // /-----------------------------------------\
- // |v4.b[0] ... v7.b[0] |
- // | ... ... |
- // |v4.b[15] ... v7.b[15] |
- // \-----------------------------------------/
- // int8 LHS 4x16 block
- // /---------------------\ /-----------------------------------------\
- // |v0.b[0] ... v0.b[15] | |v16.4s ... v28.4s |
- // |v1.b[0] ... v1.b[15] | |v17.4s ...
v29.4s | - // |v2.b[0] ... v2.b[15] | |v18.4s ... v30.4s | - // |v3.b[0] ... v3.b[15] | |v19.4s ... v31.4s | - // \---------------------/ \-----------------------------------------/ - // int32 accumulators 4x4 block - asm volatile( -#define RUY_MAKE_ZERO(reg) "dup " #reg ".4s, wzr\n" - - // clang-format off - - // Load some parameters into registers. - "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" - RUY_MAKE_ZERO(v16) - "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" - RUY_MAKE_ZERO(v17) - "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" - RUY_MAKE_ZERO(v18) - "ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n" - RUY_MAKE_ZERO(v19) - "ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n" - RUY_MAKE_ZERO(v20) - "ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n" - RUY_MAKE_ZERO(v21) - "ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" - RUY_MAKE_ZERO(v22) - "ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n" - RUY_MAKE_ZERO(v23) - - // Load the first 64 bytes of LHS and RHS data. - "ld1 {v0.16b}, [%[lhs_ptr]], #16\n" - RUY_MAKE_ZERO(v24) - "ld1 {v1.16b}, [%[lhs_ptr]], #16\n" - RUY_MAKE_ZERO(v25) - "ld1 {v2.16b}, [%[lhs_ptr]], #16\n" - RUY_MAKE_ZERO(v26) - "ld1 {v3.16b}, [%[lhs_ptr]], #16\n" - RUY_MAKE_ZERO(v27) - "ld1 {v4.16b}, [%[rhs_ptr]], #16\n" - RUY_MAKE_ZERO(v28) - "ld1 {v5.16b}, [%[rhs_ptr]], #16\n" - RUY_MAKE_ZERO(v29) - "ld1 {v6.16b}, [%[rhs_ptr]], #16\n" - RUY_MAKE_ZERO(v30) - "ld1 {v7.16b}, [%[rhs_ptr]], #16\n" - RUY_MAKE_ZERO(v31) - - - // w1 is the number of levels of depth that we have already loaded - // LHS and RHS data for. Corresponding to the initial ld1 instructions - // above, this is currently 16. - "mov w1, #16\n" - - // Perform the first few multiply-adds on the data that we have already - // loaded. - "smull v8.8h, v0.8b, v4.8b\n" - "smull v9.8h, v1.8b, v4.8b\n" - "smull v10.8h, v2.8b, v4.8b\n" - "smull v11.8h, v3.8b, v4.8b\n" - "smull v12.8h, v0.8b, v5.8b\n" - "smull v13.8h, v1.8b, v5.8b\n" - "smull v14.8h, v2.8b, v5.8b\n" - "smull v15.8h, v3.8b, v5.8b\n" - - // Multiply-accumulate second-half, again into the same - // 16bit local accumulator registers. This is where we - // take advantage of having int8 instead of uint8 and therefore - // being able to accumulate two products into int16. - "smlal2 v8.8h, v0.16b, v4.16b\n" - "smlal2 v9.8h, v1.16b, v4.16b\n" - "smlal2 v10.8h, v2.16b, v4.16b\n" - "smlal2 v11.8h, v3.16b, v4.16b\n" - "smlal2 v12.8h, v0.16b, v5.16b\n" - "smlal2 v13.8h, v1.16b, v5.16b\n" - "smlal2 v14.8h, v2.16b, v5.16b\n" - "smlal2 v15.8h, v3.16b, v5.16b\n" - - - // Main loop of the whole GEMM, over rows and columns of the - // destination matrix. - "1:\n" - - // Reminder - w1 is how many levels of depth we have already loaded - // data for, w12 is the total depth. - "cmp w1, w12\n" - "beq 79f\n" - - "2:\n" - - // Some multiplications and 16-bit accumulation were already done above, - // so we start right away in the middle. - "sadalp v16.4s, v8.8h\n" - "ldr d4, [%[rhs_ptr], #0]\n" - "smull v8.8h, v0.8b, v6.8b\n" - "ldr x7, [%[rhs_ptr], #8]\n" - "sadalp v17.4s, v9.8h\n" - "ldr d5, [%[rhs_ptr], #16]\n" - "smull v9.8h, v1.8b, v6.8b\n" - "ldr x8, [%[rhs_ptr], #24]\n" - "sadalp v18.4s, v10.8h\n" - "smull v10.8h, v2.8b, v6.8b\n" - "sadalp v19.4s, v11.8h\n" - "add %[lhs_ptr], %[lhs_ptr], #64\n" - "smull v11.8h, v3.8b, v6.8b\n" - "add %[rhs_ptr], %[rhs_ptr], #64\n" - "sadalp v20.4s, v12.8h\n" - // Each iteration of this loop advances by 16 levels of depth. 
- "add w1, w1, #16\n" - "smull v12.8h, v0.8b, v7.8b\n" - // Loop termination condition - "cmp w1, w12\n" - "sadalp v21.4s, v13.8h\n" - "ldr x3, [%[lhs_ptr], #-56]\n" - "smull v13.8h, v1.8b, v7.8b\n" - "ldr x4, [%[lhs_ptr], #-40]\n" - "sadalp v22.4s, v14.8h\n" - "ldr x5, [%[lhs_ptr], #-24]\n" - "smull v14.8h, v2.8b, v7.8b\n" - "ldr x6, [%[lhs_ptr], #-8]\n" - "sadalp v23.4s, v15.8h\n" - "smull v15.8h, v3.8b, v7.8b\n" - - // Multiply-accumulate second-half, again into the same - // 16bit local accumulator registers. This is where we - // take advantage of having int8 instead of uint8 and therefore - // being able to accumulate two products into int16. - "smlal2 v8.8h, v0.16b, v6.16b\n" - "smlal2 v9.8h, v1.16b, v6.16b\n" - "smlal2 v10.8h, v2.16b, v6.16b\n" - "ldr x9, [%[rhs_ptr], #-24]\n" - "smlal2 v11.8h, v3.16b, v6.16b\n" - "ldr d6, [%[rhs_ptr], #-32]\n" - "smlal2 v12.8h, v0.16b, v7.16b\n" - "ldr d0, [%[lhs_ptr], #-64]\n" - "smlal2 v13.8h, v1.16b, v7.16b\n" - "ldr d1, [%[lhs_ptr], #-48]\n" - "smlal2 v14.8h, v2.16b, v7.16b\n" - "ins v4.d[1], x7\n" - "smlal2 v15.8h, v3.16b, v7.16b\n" - "ins v5.d[1], x8\n" - - "ldr d2, [%[lhs_ptr], #-32]\n" - "ins v0.d[1], x3\n" - "sadalp v24.4s, v8.8h\n" - "ldr d3, [%[lhs_ptr], #-16]\n" - "ins v1.d[1], x4\n" - "smull v8.8h, v0.8b, v4.8b\n" - "ins v2.d[1], x5\n" - "sadalp v25.4s, v9.8h\n" - "ins v3.d[1], x6\n" - "smull v9.8h, v1.8b, v4.8b\n" - "ldr d7, [%[rhs_ptr], #-16]\n" - "sadalp v26.4s, v10.8h\n" - "ldr x10, [%[rhs_ptr], #-8]\n" - "smull v10.8h, v2.8b, v4.8b\n" - "sadalp v27.4s, v11.8h\n" - "smull v11.8h, v3.8b, v4.8b\n" - "sadalp v28.4s, v12.8h\n" - "smull v12.8h, v0.8b, v5.8b\n" - "sadalp v29.4s, v13.8h\n" - "smull v13.8h, v1.8b, v5.8b\n" - "sadalp v30.4s, v14.8h\n" - "smull v14.8h, v2.8b, v5.8b\n" - "sadalp v31.4s, v15.8h\n" - "smull v15.8h, v3.8b, v5.8b\n" - - // Multiply-accumulate second-half, again into the same - // 16bit local accumulator registers. This is where we - // take advantage of having int8 instead of uint8 and therefore - // being able to accumulate two products into int16. - "smlal2 v8.8h, v0.16b, v4.16b\n" - "smlal2 v9.8h, v1.16b, v4.16b\n" - "smlal2 v10.8h, v2.16b, v4.16b\n" - "smlal2 v11.8h, v3.16b, v4.16b\n" - - "smlal2 v12.8h, v0.16b, v5.16b\n" - "smlal2 v13.8h, v1.16b, v5.16b\n" - "ins v6.d[1], x9\n" - "smlal2 v14.8h, v2.16b, v5.16b\n" - "ins v7.d[1], x10\n" - "smlal2 v15.8h, v3.16b, v5.16b\n" - - "blt 2b\n" - - "79:\n" - - "sadalp v16.4s, v8.8h\n" - "smull v8.8h, v0.8b, v6.8b\n" - "sadalp v17.4s, v9.8h\n" - "smull v9.8h, v1.8b, v6.8b\n" - "sadalp v18.4s, v10.8h\n" - "smull v10.8h, v2.8b, v6.8b\n" - "sadalp v19.4s, v11.8h\n" - "smull v11.8h, v3.8b, v6.8b\n" - "sadalp v20.4s, v12.8h\n" - "smull v12.8h, v0.8b, v7.8b\n" - "sadalp v21.4s, v13.8h\n" - "smull v13.8h, v1.8b, v7.8b\n" - "sadalp v22.4s, v14.8h\n" - "smull v14.8h, v2.8b, v7.8b\n" - "sadalp v23.4s, v15.8h\n" - "smull v15.8h, v3.8b, v7.8b\n" - - // Multiply-accumulate second-half, again into the same - // 16bit local accumulator registers. This is where we - // take advantage of having int8 instead of uint8 and therefore - // being able to accumulate two products into int16. 
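The smlal2 group below is the second half of the scheme the comment above describes. Per int16 lane, the smull/smlal2/sadalp combination boils down to the scalar sketch that follows; sadalp additionally folds adjacent int16 lane pairs into a single int32 lane, which the sketch glosses over.

#include <cstdint>

// Two int8*int8 products summed in a (non-saturating) int16, then widened and
// added into the int32 accumulator.
inline void AccumulateTwoProducts(std::int8_t lhs_lo, std::int8_t rhs_lo,
                                  std::int8_t lhs_hi, std::int8_t rhs_hi,
                                  std::int32_t* acc) {
  std::int16_t lane = static_cast<std::int16_t>(static_cast<std::int16_t>(lhs_lo) *
                                                static_cast<std::int16_t>(rhs_lo));  // smull
  lane = static_cast<std::int16_t>(lane +
                                   static_cast<std::int16_t>(lhs_hi) *
                                       static_cast<std::int16_t>(rhs_hi));           // smlal2
  *acc += lane;                                                                      // sadalp
}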
- "smlal2 v8.8h, v0.16b, v6.16b\n" - "smlal2 v9.8h, v1.16b, v6.16b\n" - "smlal2 v10.8h, v2.16b, v6.16b\n" - "smlal2 v11.8h, v3.16b, v6.16b\n" - - "smlal2 v12.8h, v0.16b, v7.16b\n" - "smlal2 v13.8h, v1.16b, v7.16b\n" - "smlal2 v14.8h, v2.16b, v7.16b\n" - "smlal2 v15.8h, v3.16b, v7.16b\n" - - "sadalp v24.4s, v8.8h\n" - "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" - "sadalp v25.4s, v9.8h\n" - "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" - "sadalp v26.4s, v10.8h\n" - "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" - "sadalp v27.4s, v11.8h\n" - "ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n" - "sadalp v28.4s, v12.8h\n" - "ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n" - "sadalp v29.4s, v13.8h\n" - "ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n" - "sadalp v30.4s, v14.8h\n" - "sadalp v31.4s, v15.8h\n" - - // End of accumulation. The registers v16 -- v31 contain the final - // int32 accumulator values of the current 4x4 destination block. - // We now have to compute the final 8-bit values from these int32 - // accumulators, and advance to the next 4x4 block. We intertwine - // these two aspects whenever possible for optimal pipelining, both - // at the data flow level (prefetch data for next block as early as - // possible) and instruction pipelining level (some of the next-block - // work can dual-issue with some of the final work on the current - // block). - - // Reduce 32bit accumulators horizontally. - "addp v16.4s, v16.4s, v17.4s\n" - "addp v18.4s, v18.4s, v19.4s\n" - "addp v20.4s, v20.4s, v21.4s\n" - "addp v22.4s, v22.4s, v23.4s\n" - "addp v24.4s, v24.4s, v25.4s\n" - "addp v26.4s, v26.4s, v27.4s\n" - "addp v28.4s, v28.4s, v29.4s\n" - "addp v30.4s, v30.4s, v31.4s\n" - - // Reduce 32bit accumulators horizontally, second pass - // (each pass adds pairwise. we need to add 4-wise). - "addp v16.4s, v16.4s, v18.4s\n" - "addp v17.4s, v20.4s, v22.4s\n" - "addp v18.4s, v24.4s, v26.4s\n" - "addp v19.4s, v28.4s, v30.4s\n" - - // Logic to advance to the next block in preparation for the next - // iteration of the main loop. For now, we only want to compute - // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are - // not yet ready to update the values of row and col, as we still need - // the current values for the rest of the work on the current block. - - "cmp %w[row], w7\n" // Have we finished the last row? - "bge 4f\n" // If finished last row, go to 4 - // Not finished last row: then advance to next row. - "add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #2\n" - "b 5f\n" - "4:\n" // Finished last row... - "mov %[lhs_col_ptr], x5\n" // Go back to first row - // Now we need to advance to the next column. If we already - // finished the last column, then in principle we are done, however - // we can't just return here, as we need to allow the end work of the - // current block to complete. The good news is that at this point it - // doesn't matter what data we load for the next column, since - // we will exit from the main loop below before actually storing - // anything computed from that data. - "cmp %w[col], w8\n" // Have we finished the last column? - "bge 5f\n" // If yes, just carry on without updating the column pointer. - // Not finished last column: then advance to next column. - "add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #2\n" - "5:\n" - - // Set the LHS and RHS data pointers to the start of the columns just - // computed. 
- "mov %[lhs_ptr], %[lhs_col_ptr]\n" - "mov %[rhs_ptr], %[rhs_col_ptr]\n" - - // Load some parameters needed for the end work on current block. - RUY_MAKE_ZERO(v8) - "ldr w4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n" - "ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_PROD_ZP_DEPTH) "]\n" - "ins v13.h[4], w4\n" // dst_zero_point - "ldr x4, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n" - "ldrb w6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" - "dup v9.4s, w3\n" // create prod_zp_depth_vec - "add x5, x4, %x[row], lsl #2\n" - "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n" - "csel x4, x4, x5, eq\n" - - "ld1 {v15.4s}, [x4]\n" // multiplier_fixedpoint - - "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n" - "add x5, x1, %x[row], lsl #2\n" - - "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n" - "csel x1, x1, x5, eq\n" - - // Load 4 bias values. - "ld1 {v14.4s}, [x1]\n" - - // Now that we know what LHS and RHS data the next iteration of the - // main loop will need to load, we start loading the first 32 bytes of - // each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore - // in the rest of the work on the current block. - - // Add to the bias values the product (depth * lhs_zero_point * rhs_zero_point), - // See the term NZ1Z2 in equation (7) in https://arxiv.org/pdf/1712.05877.pdf - "add v14.4s, v14.4s, v9.4s\n" - "ldr d0, [%[lhs_ptr], #0]\n" - - // Perform the bias-addition (per the above, we have just folded into - // the bias the (depth * lhs_zero_point * rhs_zero_point) term.) - "add v16.4s, v16.4s, v14.4s\n" - "ldr d1, [%[lhs_ptr], #16]\n" - "add v17.4s, v17.4s, v14.4s\n" - "ldr d2, [%[lhs_ptr], #32]\n" - "add v18.4s, v18.4s, v14.4s\n" - "ldr d3, [%[lhs_ptr], #48]\n" - "add v19.4s, v19.4s, v14.4s\n" - "ldr d4, [%[rhs_ptr], #0]\n" - "ldr d5, [%[rhs_ptr], #16]\n" - "ldr d6, [%[rhs_ptr], #32]\n" - "ldr d7, [%[rhs_ptr], #48]\n" - - "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_RHS_SUMS) "\n" - "beq 401f\n" - "ldr x3, [%[params], #" RUY_STR(RUY_OFFSET_RHS_SUMS) "]\n" - "add x3, x3, %x[col], lsl #2\n" - "ld1 {v14.4s}, [x3]\n" - "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_ZERO_POINT) "]\n" - "dup v10.4s, w5\n" // create lhs_zero_point_vec - // Subtract rhs_sums * lhs_zero_point, per - // equation (7) in https://arxiv.org/pdf/1712.05877.pdf - "mls v16.4s, v10.4s, v14.s[0]\n" - "mls v17.4s, v10.4s, v14.s[1]\n" - "mls v18.4s, v10.4s, v14.s[2]\n" - "mls v19.4s, v10.4s, v14.s[3]\n" - "401:\n" - - "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_LHS_SUMS) "\n" - "beq 402f\n" - "ldr x2, [%[params], #" RUY_STR(RUY_OFFSET_LHS_SUMS) "]\n" - "add x2, x2, %x[row], lsl #2\n" - "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_RHS_ZERO_POINT) "]\n" - // Load 4 lhs_sums values. - "ld1 {v11.4s}, [x2]\n" - "ins v13.s[1], w5\n" // rhs_zero_point - // Compute lhs_sums * rhs_zero_point. - "mul v11.4s, v11.4s, v13.s[1]\n" - // Subtract lhs_sums * rhs_zero_point, per - // equation (7) in https://arxiv.org/pdf/1712.05877.pdf - "sub v16.4s, v16.4s, v11.4s\n" - "sub v17.4s, v17.4s, v11.4s\n" - "sub v18.4s, v18.4s, v11.4s\n" - "sub v19.4s, v19.4s, v11.4s\n" - - // If the destination is int32, it means the user asks for the raw - // accumulators, no need for us to downquantize the value. - "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT32) "\n" - "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT32) "f\n" - - "402:\n" - - // At this point we have computed the final int32 values. Now we - // start down-quantizing them to obtain the final 8bit values from them. 
- - // As part of this down-quantization, our int32 values will be - // multiplied by a multiplier that has a fixed-point component and an - // exponent component. - - - "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_EXPONENT) "]\n" - "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n" - "add x5, x1, %x[row], lsl #2\n" - "csel x1, x1, x5, eq\n" - - "ld1 {v14.4s}, [x1]\n" - - "smax v12.4s, v14.4s, v8.4s\n" - "ldr x1, [%[lhs_ptr], #8]\n" - - "sshl v16.4s, v16.4s, v12.4s\n" - "ldr x2, [%[lhs_ptr], #24]\n" - "sshl v17.4s, v17.4s, v12.4s\n" - "ldr x3, [%[lhs_ptr], #40]\n" - "sshl v18.4s, v18.4s, v12.4s\n" - "ldr x4, [%[lhs_ptr], #56]\n" - "sshl v19.4s, v19.4s, v12.4s\n" - - "smin v12.4s, v14.4s, v8.4s\n" - - // Apply the fixed-point part of the multiplier. - "ins v0.d[1], x1\n" - "ldr x1, [%[rhs_ptr], #8]\n" - "sqrdmulh v16.4s, v16.4s, v15.4s\n" - "ins v1.d[1], x2\n" - "ldr x2, [%[rhs_ptr], #24]\n" - "sqrdmulh v17.4s, v17.4s, v15.4s\n" - "ins v2.d[1], x3\n" - "ldr x3, [%[rhs_ptr], #40]\n" - "sqrdmulh v18.4s, v18.4s, v15.4s\n" - "ins v3.d[1], x4\n" - "ldr x4, [%[rhs_ptr], #56]\n" - "sqrdmulh v19.4s, v19.4s, v15.4s\n" - - // We have some rounding division-by-power-of-two to do. This should - // always use "round to nearest". We allow for some - // freedom in how ties are broken, to strike a good compromise of - // performance on given hardware vs. perfect agreement of results - // across hardware. - // - // When RUY_OPT_NATIVE_ROUNDING is enabled, we allow for implementation - // defined tie-breaks to help performance. On NEON, this means that we - // can just use the NEON rounding instructions, such as srshl. They - // happen to be breaking ties upward. - // - // When RUY_OPT_NATIVE_ROUNDING is disabled, we implement strict - // break-ties-away-from zero, as described in Appendix B of - // https://arxiv.org/pdf/1712.05877.pdf - // When we wrote that, we thought that that would be better unbiased - // than the NEON upwards tie-breaks, and we had observed some - // improvement on some model. However, that is only more unbiased for - // data centered at zero, which was likely the case in that model, - // but is not always the case. If we wanted something more consistently - // unbiased then we should try breaking ties toward-nearest-even. -#if !RUY_OPT_ENABLED(RUY_OPT_NATIVE_ROUNDING) - // Fix up values to be right-shifted, so that the (round to nearest, - // break ties upward) behavior of srshl applied to these fixed-up - // values, produces the same result as the desired (round to nearest, - // break ties away from zero) behavior on the original values. - "and v8.16b, v16.16b, v12.16b\n" - "and v9.16b, v17.16b, v12.16b\n" - "and v14.16b, v18.16b, v12.16b\n" - "and v15.16b, v19.16b, v12.16b\n" - "sshr v8.4s, v8.4s, #31\n" - "sshr v9.4s, v9.4s, #31\n" - "sshr v14.4s, v14.4s, #31\n" - "sshr v15.4s, v15.4s, #31\n" - "sqadd v16.4s, v16.4s, v8.4s\n" - "sqadd v17.4s, v17.4s, v9.4s\n" - "sqadd v18.4s, v18.4s, v14.4s\n" - "sqadd v19.4s, v19.4s, v15.4s\n" -#endif - // At this point we have reduced the problem of correctly implementing - // rounding divide-by-power-of-two, to what the SRSHL instruction can - // do. 
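In NEON intrinsics, the fixup block above plus the srshl that follows amounts to the function below, with v12 playing the role of the non-positive shift vector. This mirrors the corresponding gemmlowp fixed-point helper and is only a sketch.

#include <arm_neon.h>

// Rounding right shift with ties broken away from zero, built on top of the
// ties-upward srshl/vrshlq: the same fixup as the #if block above.
inline int32x4_t RoundingShiftTiesAwayFromZero(int32x4_t x, int32x4_t shift /* elements <= 0 */) {
  const int32x4_t fixup = vshrq_n_s32(vandq_s32(x, shift), 31);  // and + sshr #31
  const int32x4_t fixed = vqaddq_s32(x, fixup);                  // sqadd
  return vrshlq_s32(fixed, shift);                               // srshl
}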
- "srshl v16.4s, v16.4s, v12.4s\n" - "srshl v17.4s, v17.4s, v12.4s\n" - "srshl v18.4s, v18.4s, v12.4s\n" - "srshl v19.4s, v19.4s, v12.4s\n" - - "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT16) "\n" - "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT16) "f\n" - "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT8) "\n" - "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT8) "f\n" - - RUY_STR(RUY_ASM_LABEL_STORE_UINT8) ":\n" - - "ins v4.d[1], x1\n" - "sqxtn v16.4h, v16.4s\n" - "ins v5.d[1], x2\n" - "sqxtn2 v16.8h, v17.4s\n" - "ins v6.d[1], x3\n" - "sqxtn v17.4h, v18.4s\n" - "ins v7.d[1], x4\n" - RUY_MAKE_ZERO(v18) - "sqxtn2 v17.8h, v19.4s\n" - - // At this point, v18 -- v31 aren't used anymore for the current block, - // so we can start clearing these accumulators for the next block - // (next iteration of the main loop). - RUY_MAKE_ZERO(v19) - - // Add the destination zero point - "add %[lhs_ptr], %[lhs_ptr], #64\n" - "dup v14.8h, v13.h[4]\n" - RUY_MAKE_ZERO(v20) - "add %[rhs_ptr], %[rhs_ptr], #64\n" - "add v16.8h, v16.8h, v14.8h\n" - RUY_MAKE_ZERO(v21) - "add v17.8h, v17.8h, v14.8h\n" - RUY_MAKE_ZERO(v22) - - // Cast-and-saturate from int16 to uint8 - "sqxtun v16.8b, v16.8h\n" - RUY_MAKE_ZERO(v23) - "sqxtun2 v16.16b, v17.8h\n" - RUY_MAKE_ZERO(v24) - - // Load the clamp_min, clamp_max bounds - "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" - RUY_MAKE_ZERO(v25) - "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" - RUY_MAKE_ZERO(v26) - "dup v14.16b, w2\n" // clamp_min - RUY_MAKE_ZERO(v27) - "dup v15.16b, w3\n" // clamp_max - RUY_MAKE_ZERO(v28) - - // Apply the clamp_min bound - "umax v16.16b, v16.16b, v14.16b\n" - RUY_MAKE_ZERO(v29) - // Apply the clamp_max bound - "umin v16.16b, v16.16b, v15.16b\n" - RUY_MAKE_ZERO(v30) - - // Compute how much of the 4x4 block of destination 8bit values that - // we have computed, fit in the destination matrix. Typically, all of - // it fits, but when the destination matrix shape is not a multiple - // of 4x4, there are some 4x4 blocks along the boundaries that do - // not fit entirely. - "sub w1, %w[dst_rows], %w[row]\n" - RUY_MAKE_ZERO(v31) - "sub w2, %w[dst_cols], %w[col]\n" - "mov w3, #4\n" - "cmp w1, #4\n" - // Compute w1 = how many rows of the 4x4 block fit - "csel w1, w1, w3, le\n" - "cmp w2, #4\n" - // Compute w2 = how many cols of the 4x4 block fit - "csel w2, w2, w3, le\n" - - // Test if w1==4 && w2 == 4, i.e. if all of the 8x8 block fits. - "cmp w1, w3\n" - "ccmp w2, w3, 0, eq\n" - "mov x4, %[dst_ptr]\n" - // Yes, all of the 4x4 block fits, go to fast path. - "beq 30f\n" - // Not all of the 4x4 block fits. - // Store to dst_tmp_buf - "st1 {v16.16b}, [%[dst_tmp_buf]]\n" - // Slow loop copying from dst_tmp_buf to dst. - "mov x3, %[dst_tmp_buf]\n" - "mov w6, #0\n" - "50:\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov w5, #0\n" - "51:\n" - "ldrb w7, [x3, w5, uxtw]\n" - "strb w7, [x4, w5, uxtw]\n" - "add w5, w5, #1\n" - "cmp w5, w1\n" - "blt 51b\n" - "add w6, w6, #1\n" - "add x3, x3, #4\n" - "add x4, x4, x11\n" - "cmp w6, w2\n" - "blt 50b\n" - "b 31f\n" - "30:\n" - // Yes, all of the 4x4 block fits. 
- RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v16.b}[0], [x3], #1\n" - "add x4, x4, x11\n" - "st1 {v16.b}[1], [x3], #1\n" - "st1 {v16.b}[2], [x3], #1\n" - "st1 {v16.b}[3], [x3], #1\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v16.b}[4], [x3], #1\n" - "add x4, x4, x11\n" - "st1 {v16.b}[5], [x3], #1\n" - "st1 {v16.b}[6], [x3], #1\n" - "st1 {v16.b}[7], [x3], #1\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v16.b}[8], [x3], #1\n" - "add x4, x4, x11\n" - "st1 {v16.b}[9], [x3], #1\n" - "st1 {v16.b}[10], [x3], #1\n" - "st1 {v16.b}[11], [x3], #1\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v16.b}[12], [x3], #1\n" - "add x4, x4, x11\n" - "st1 {v16.b}[13], [x3], #1\n" - "st1 {v16.b}[14], [x3], #1\n" - "st1 {v16.b}[15], [x3], #1\n" - "31:\n" - - "add %[dst_ptr], %[dst_ptr], #4\n" - - RUY_MAKE_ZERO(v16) - RUY_MAKE_ZERO(v17) - - "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" - - RUY_STR(RUY_ASM_LABEL_STORE_INT8) ":\n" - - "ins v4.d[1], x1\n" - "sqxtn v16.4h, v16.4s\n" - "ins v5.d[1], x2\n" - "sqxtn2 v16.8h, v17.4s\n" - "ins v6.d[1], x3\n" - "sqxtn v17.4h, v18.4s\n" - "ins v7.d[1], x4\n" - RUY_MAKE_ZERO(v18) - "sqxtn2 v17.8h, v19.4s\n" - - // At this point, v18 -- v31 aren't used anymore for the current block, - // so we can start clearing these accumulators for the next block - // (next iteration of the main loop). - RUY_MAKE_ZERO(v19) - - // Add the destination zero point - "add %[lhs_ptr], %[lhs_ptr], #64\n" - "dup v14.8h, v13.h[4]\n" - RUY_MAKE_ZERO(v20) - "add %[rhs_ptr], %[rhs_ptr], #64\n" - "add v16.8h, v16.8h, v14.8h\n" - RUY_MAKE_ZERO(v21) - "add v17.8h, v17.8h, v14.8h\n" - RUY_MAKE_ZERO(v22) - - // Cast-and-saturate from int16 to uint8 - "sqxtn v16.8b, v16.8h\n" - RUY_MAKE_ZERO(v23) - "sqxtn2 v16.16b, v17.8h\n" - RUY_MAKE_ZERO(v24) - - // Load the clamp_min, clamp_max bounds - "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" - RUY_MAKE_ZERO(v25) - "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" - RUY_MAKE_ZERO(v26) - "dup v14.16b, w2\n" // clamp_min - RUY_MAKE_ZERO(v27) - "dup v15.16b, w3\n" // clamp_max - RUY_MAKE_ZERO(v28) - - // Apply the clamp_min bound - "smax v16.16b, v16.16b, v14.16b\n" - RUY_MAKE_ZERO(v29) - // Apply the clamp_max bound - "smin v16.16b, v16.16b, v15.16b\n" - RUY_MAKE_ZERO(v30) - - // Compute how much of the 4x4 block of destination 8bit values that - // we have computed, fit in the destination matrix. Typically, all of - // it fits, but when the destination matrix shape is not a multiple - // of 4x4, there are some 4x4 blocks along the boundaries that do - // not fit entirely. - "sub w1, %w[dst_rows], %w[row]\n" - RUY_MAKE_ZERO(v31) - "sub w2, %w[dst_cols], %w[col]\n" - "mov w3, #4\n" - "cmp w1, #4\n" - // Compute w1 = how many rows of the 4x4 block fit - "csel w1, w1, w3, le\n" - "cmp w2, #4\n" - // Compute w2 = how many cols of the 4x4 block fit - "csel w2, w2, w3, le\n" - - // Test if w1==4 && w2 == 4, i.e. if all of the 8x8 block fits. - "cmp w1, w3\n" - "ccmp w2, w3, 0, eq\n" - "mov x4, %[dst_ptr]\n" - // Yes, all of the 4x4 block fits, go to fast path. - "beq 30f\n" - // Not all of the 4x4 block fits. - // Store to dst_tmp_buf - "st1 {v16.16b}, [%[dst_tmp_buf]]\n" - // Slow loop copying from dst_tmp_buf to dst. 
- "mov x3, %[dst_tmp_buf]\n" - "mov w6, #0\n" - "50:\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov w5, #0\n" - "51:\n" - "ldrb w7, [x3, w5, uxtw]\n" - "strb w7, [x4, w5, uxtw]\n" - "add w5, w5, #1\n" - "cmp w5, w1\n" - "blt 51b\n" - "add w6, w6, #1\n" - "add x3, x3, #4\n" - "add x4, x4, x11\n" - "cmp w6, w2\n" - "blt 50b\n" - "b 31f\n" - "30:\n" - // Yes, all of the 4x4 block fits. - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v16.b}[0], [x3], #1\n" - "add x4, x4, x11\n" - "st1 {v16.b}[1], [x3], #1\n" - "st1 {v16.b}[2], [x3], #1\n" - "st1 {v16.b}[3], [x3], #1\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v16.b}[4], [x3], #1\n" - "add x4, x4, x11\n" - "st1 {v16.b}[5], [x3], #1\n" - "st1 {v16.b}[6], [x3], #1\n" - "st1 {v16.b}[7], [x3], #1\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v16.b}[8], [x3], #1\n" - "add x4, x4, x11\n" - "st1 {v16.b}[9], [x3], #1\n" - "st1 {v16.b}[10], [x3], #1\n" - "st1 {v16.b}[11], [x3], #1\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v16.b}[12], [x3], #1\n" - "add x4, x4, x11\n" - "st1 {v16.b}[13], [x3], #1\n" - "st1 {v16.b}[14], [x3], #1\n" - "st1 {v16.b}[15], [x3], #1\n" - "31:\n" - - "add %[dst_ptr], %[dst_ptr], #4\n" - - RUY_MAKE_ZERO(v16) - RUY_MAKE_ZERO(v17) - - "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" - - RUY_STR(RUY_ASM_LABEL_STORE_INT16) ":\n" - - // Add the destination zero point - "dup v14.4h, v13.h[4]\n" - "saddw v16.4s, v16.4s, v14.4h\n" - "saddw v17.4s, v17.4s, v14.4h\n" - "saddw v18.4s, v18.4s, v14.4h\n" - "saddw v19.4s, v19.4s, v14.4h\n" - - // Cast-and-saturate from int32 to int16 - "ins v4.d[1], x1\n" - "sqxtn v16.4h, v16.4s\n" - "ins v5.d[1], x2\n" - "sqxtn2 v16.8h, v17.4s\n" - "ins v6.d[1], x3\n" - "sqxtn v17.4h, v18.4s\n" - "ins v7.d[1], x4\n" - RUY_MAKE_ZERO(v18) - "sqxtn2 v17.8h, v19.4s\n" - - // At this point, v18 -- v31 aren't used anymore for the current block, - // so we can start clearing these accumulators for the next block - // (next iteration of the main loop). - RUY_MAKE_ZERO(v19) - - "add %[lhs_ptr], %[lhs_ptr], #64\n" - RUY_MAKE_ZERO(v20) - "add %[rhs_ptr], %[rhs_ptr], #64\n" - RUY_MAKE_ZERO(v21) - RUY_MAKE_ZERO(v22) - - RUY_MAKE_ZERO(v23) - RUY_MAKE_ZERO(v24) - - // Load the clamp_min, clamp_max bounds - "ldrh w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" - RUY_MAKE_ZERO(v25) - "ldrh w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" - RUY_MAKE_ZERO(v26) - "dup v14.8h, w2\n" // clamp_min - RUY_MAKE_ZERO(v27) - "dup v15.8h, w3\n" // clamp_max - RUY_MAKE_ZERO(v28) - - // Apply the clamp_min bound - "smax v16.8h, v16.8h, v14.8h\n" - "smax v17.8h, v17.8h, v14.8h\n" - RUY_MAKE_ZERO(v29) - // Apply the clamp_max bound - "smin v16.8h, v16.8h, v15.8h\n" - "smin v17.8h, v17.8h, v15.8h\n" - RUY_MAKE_ZERO(v30) - - // Compute how much of the 4x4 block of destination 8bit values that - // we have computed, fit in the destination matrix. Typically, all of - // it fits, but when the destination matrix shape is not a multiple - // of 4x4, there are some 4x4 blocks along the boundaries that do - // not fit entirely. - "sub w1, %w[dst_rows], %w[row]\n" - RUY_MAKE_ZERO(v31) - "sub w2, %w[dst_cols], %w[col]\n" - "mov w3, #4\n" - "cmp w1, #4\n" - // Compute w1 = how many rows of the 4x4 block fit - "csel w1, w1, w3, le\n" - "cmp w2, #4\n" - // Compute w2 = how many cols of the 4x4 block fit - "csel w2, w2, w3, le\n" - - // Test if w1==4 && w2 == 4, i.e. if all of the 4x4 block fits. 
- "cmp w1, w3\n" - "ccmp w2, w3, 0, eq\n" - "mov x4, %[dst_ptr]\n" - // Yes, all of the 4x4 block fits, go to fast path. - "beq 30f\n" - // Not all of the 4x4 block fits. - // Store to dst_tmp_buf - "str q16, [%[dst_tmp_buf], #0]\n" - "str q17, [%[dst_tmp_buf], #16]\n" - // Slow loop copying from dst_tmp_buf to dst. - "mov x3, %[dst_tmp_buf]\n" - "mov w6, #0\n" - "50:\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov w5, #0\n" - "51:\n" - "ldrh w7, [x3, x5, lsl #1]\n" - "strh w7, [x4, x5, lsl #1]\n" - "add w5, w5, #1\n" - "cmp w5, w1\n" - "blt 51b\n" - "add w6, w6, #1\n" - "add x3, x3, #8\n" - "add x4, x4, x11\n" - "cmp w6, w2\n" - "blt 50b\n" - "b 31f\n" - "30:\n" - // Yes, all of the 4x4 block fits. - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v16.h}[0], [x3], #2\n" - "add x4, x4, x11\n" - "st1 {v16.h}[1], [x3], #2\n" - "st1 {v16.h}[2], [x3], #2\n" - "st1 {v16.h}[3], [x3], #2\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v16.h}[4], [x3], #2\n" - "add x4, x4, x11\n" - "st1 {v16.h}[5], [x3], #2\n" - "st1 {v16.h}[6], [x3], #2\n" - "st1 {v16.h}[7], [x3], #2\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v17.h}[0], [x3], #2\n" - "add x4, x4, x11\n" - "st1 {v17.h}[1], [x3], #2\n" - "st1 {v17.h}[2], [x3], #2\n" - "st1 {v17.h}[3], [x3], #2\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v17.h}[4], [x3], #2\n" - "add x4, x4, x11\n" - "st1 {v17.h}[5], [x3], #2\n" - "st1 {v17.h}[6], [x3], #2\n" - "st1 {v17.h}[7], [x3], #2\n" - "31:\n" - - "add %[dst_ptr], %[dst_ptr], #8\n" - - RUY_MAKE_ZERO(v16) - RUY_MAKE_ZERO(v17) - - "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" - - RUY_STR(RUY_ASM_LABEL_STORE_INT32) ":\n" - - "ldr x1, [%[lhs_ptr], #8]\n" - "ldr x2, [%[lhs_ptr], #24]\n" - "ldr x3, [%[lhs_ptr], #40]\n" - "ldr x4, [%[lhs_ptr], #56]\n" - - "ins v0.d[1], x1\n" - "ldr x1, [%[rhs_ptr], #8]\n" - "ins v1.d[1], x2\n" - "ldr x2, [%[rhs_ptr], #24]\n" - "ins v2.d[1], x3\n" - "ldr x3, [%[rhs_ptr], #40]\n" - "ins v3.d[1], x4\n" - "ldr x4, [%[rhs_ptr], #56]\n" - "ins v4.d[1], x1\n" - "ins v5.d[1], x2\n" - "ins v6.d[1], x3\n" - "ins v7.d[1], x4\n" - - // Since the store type is the same as the accum type, no need for - // downcast. There's also no need for clamp by min/max. - - // At this point, v20 -- v31 aren't used anymore for the current block, - // so we can start clearing these accumulators for the next block - // (next iteration of the main loop). - - RUY_MAKE_ZERO(v20) - "add %[lhs_ptr], %[lhs_ptr], #64\n" - RUY_MAKE_ZERO(v21) - "add %[rhs_ptr], %[rhs_ptr], #64\n" - RUY_MAKE_ZERO(v22) - - RUY_MAKE_ZERO(v23) - RUY_MAKE_ZERO(v24) - RUY_MAKE_ZERO(v25) - RUY_MAKE_ZERO(v26) - RUY_MAKE_ZERO(v27) - RUY_MAKE_ZERO(v28) - RUY_MAKE_ZERO(v29) - RUY_MAKE_ZERO(v30) - - // Compute how much of the 4x4 block of destination 8bit values that - // we have computed, fit in the destination matrix. Typically, all of - // it fits, but when the destination matrix shape is not a multiple - // of 4x4, there are some 4x4 blocks along the boundaries that do - // not fit entirely. - "sub w1, %w[dst_rows], %w[row]\n" - RUY_MAKE_ZERO(v31) - "sub w2, %w[dst_cols], %w[col]\n" - "mov w3, #4\n" - "cmp w1, #4\n" - // Compute w1 = how many rows of the 4x4 block fit - "csel w1, w1, w3, le\n" - "cmp w2, #4\n" - // Compute w2 = how many cols of the 4x4 block fit - "csel w2, w2, w3, le\n" - - // Test if w1==4 && w2 == 4, i.e. if all of the 8x8 block fits. 
- "cmp w1, w3\n" - "ccmp w2, w3, 0, eq\n" - "mov x4, %[dst_ptr]\n" - // Yes, all of the 4x4 block fits, go to fast path. - "beq 30f\n" - // Not all of the 4x4 block fits. - // Store to dst_tmp_buf - "str q16, [%[dst_tmp_buf], #0]\n" - "str q17, [%[dst_tmp_buf], #16]\n" - "str q18, [%[dst_tmp_buf], #32]\n" - "str q19, [%[dst_tmp_buf], #48]\n" - // Slow loop copying from dst_tmp_buf to dst. - "mov x3, %[dst_tmp_buf]\n" - "mov w6, #0\n" - "50:\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov w5, #0\n" - "51:\n" - "ldr w7, [x3, x5, lsl #2]\n" - "str w7, [x4, x5, lsl #2]\n" - "add w5, w5, #1\n" - "cmp w5, w1\n" - "blt 51b\n" - "add w6, w6, #1\n" - "add x3, x3, #16\n" - "add x4, x4, x11\n" - "cmp w6, w2\n" - "blt 50b\n" - "b 31f\n" - "30:\n" - // Yes, all of the 4x4 block fits. - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v16.s}[0], [x3], #4\n" - "add x4, x4, x11\n" - "st1 {v16.s}[1], [x3], #4\n" - "st1 {v16.s}[2], [x3], #4\n" - "st1 {v16.s}[3], [x3], #4\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v17.s}[0], [x3], #4\n" - "add x4, x4, x11\n" - "st1 {v17.s}[1], [x3], #4\n" - "st1 {v17.s}[2], [x3], #4\n" - "st1 {v17.s}[3], [x3], #4\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v18.s}[0], [x3], #4\n" - "add x4, x4, x11\n" - "st1 {v18.s}[1], [x3], #4\n" - "st1 {v18.s}[2], [x3], #4\n" - "st1 {v18.s}[3], [x3], #4\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v19.s}[0], [x3], #4\n" - "add x4, x4, x11\n" - "st1 {v19.s}[1], [x3], #4\n" - "st1 {v19.s}[2], [x3], #4\n" - "st1 {v19.s}[3], [x3], #4\n" - "31:\n" - - "add %[dst_ptr], %[dst_ptr], #16\n" - - RUY_MAKE_ZERO(v16) - RUY_MAKE_ZERO(v17) - RUY_MAKE_ZERO(v18) - RUY_MAKE_ZERO(v19) - - RUY_STR(RUY_ASM_LABEL_AFTER_STORE) ":\n" - - // For the next block: perform the first few multiply-adds on the data - // that we have already loaded. - "smull v8.8h, v0.8b, v4.8b\n" - "smull v9.8h, v1.8b, v4.8b\n" - "smull v10.8h, v2.8b, v4.8b\n" - // Reload some params --- we had used x5 -- x7 for a few other things - // since the last time we had loaded them. - "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" - "smull v11.8h, v3.8b, v4.8b\n" - "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" - "smull v12.8h, v0.8b, v5.8b\n" - "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" - "smull v13.8h, v1.8b, v5.8b\n" - "smull v14.8h, v2.8b, v5.8b\n" - "smull v15.8h, v3.8b, v5.8b\n" - // Move to the next block of the destination matrix, for the next iter - // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already - // been updated earlier. - // Have we reached the end row? - "cmp %w[row], w7\n" - "smlal2 v8.8h, v0.16b, v4.16b\n" - "smlal2 v9.8h, v1.16b, v4.16b\n" - "smlal2 v10.8h, v2.16b, v4.16b\n" - "smlal2 v11.8h, v3.16b, v4.16b\n" - "smlal2 v12.8h, v0.16b, v5.16b\n" - "smlal2 v13.8h, v1.16b, v5.16b\n" - "smlal2 v14.8h, v2.16b, v5.16b\n" - "smlal2 v15.8h, v3.16b, v5.16b\n" - - - "beq 20f\n" // yes, end row. - // Not end row. Move to the next row. - "add %w[row], %w[row], #4\n" - "b 21f\n" - "20:\n" - // Was already at end row. - "mov %w[row], w6\n" // Move back to first row. - "add %w[col], %w[col], #4\n" // Move to the next column. - "add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #2\n" - "mov %[dst_ptr], %[dst_col_ptr]\n" - "21:\n" - - // Main loop exit condition: have we hit the end column? - "cmp %w[col], w8\n" - - // w1 is the number of levels of depth that we have already loaded - // LHS and RHS data for. 
Corresponding to the initial ld1 instructions - // above, this is currently 16. - "mov w1, #16\n" - - "ble 1b\n" - - // clang-format on - - : [ lhs_col_ptr ] "+r"(lhs_col_ptr), [rhs_col_ptr] "+r"(rhs_col_ptr), - [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), - [dst_col_ptr] "+r"(dst_col_ptr), [dst_ptr] "+r"(dst_ptr), [row] "+r"(row), [col] "+r"(col) - : [ params ] "r"(&params), [dst_rows] "r"(params.dst_rows), - [dst_cols] "r"(params.dst_cols), [dst_tmp_buf] "r"(params.dst_tmp_buf), - [dst_type_id] "r"(params.dst_type_id) - : "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc", - "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", - "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", - "v26", "v27", "v28", "v29", "v30", "v31"); -} - -// Kernel taking advantage of the optional dotprod instruction. -// This is very similar to (and directly inspired by) this gemmlowp kernel -// which was contributed by David Mansell at ARM: -// NEON_64bit_GEMM_Uint8Operands_Uint32Accumulators_dotproduct -// https://github.com/google/gemmlowp/blob/36212ad3651871bc3e9a599f1a6d5324778aea25/standalone/neon-gemm-kernel-benchmark.cc#L3391 -// -// Besides the ruy-ification, the main difference here is that we use a 8x8 -// instead of 12x8 width, so as to stick to power-of-two widths. This slightly -// narrower kernel layout is still wide enough to achieve high performance -// although we haven't actually performed a real comparison to know exactly -// how this compares to ARM's aforementioned kernel. -// -// Relevant target CPUs for this kernel include ARM Cortex-A76, -// since these are 64-bit, out-of-order and with dotprod support. -void Kernel8bitNeonDotprodOutOfOrder(const KernelParams8bit<8, 8>& params) { - profiler::ScopeLabel label( - "Kernel (kNeonDotprod, optimized for out-of-order cores)"); - - CheckOffsetsInKernelParams8bit(params); - - const std::int8_t* lhs_col_ptr = params.lhs_base_ptr; - const std::int8_t* rhs_col_ptr = params.rhs_base_ptr; - const std::int8_t* lhs_ptr = lhs_col_ptr; - const std::int8_t* rhs_ptr = rhs_col_ptr; - void* dst_col_ptr = params.dst_base_ptr; - void* dst_ptr = dst_col_ptr; - int row = params.start_row; - int col = params.start_col; - - // The asm kernel below has the following NEON register allocation: - // - // v16 -- v31 are int32 accumulators. - // During accumulation, v0 -- v15 are used to load int8 data from LHS and - // RHS. At least v0 and v1 are used to load a 8x4 block of LHS, and v2 and - // v3 are used to load a 4x8 block of RHS, like this: - // - // int8 RHS 4x8 block - // /-----------------------------------------\ - // |v2.b[0] ... v2.b[12] v3.b[0] ... v3.b[12]| - // | ... ... | - // |v2.b[3] ... v2.b[15] v3.b[3] ... v3.b[15]| - // \-----------------------------------------/ - // int8 LHS 8x4 block - // /---------------------\ /-----------------------------------------\ - // |v0.b[0] ... v0.b[3] | |v16.s[0] ... v30.s[0]| - // | ... ... | | ... ... | - // |v0.b[12] ... v0.b[15]| |v16.s[3] ... v30.s[3]| - // |v1.b[0] ... v1.b[3] | |v17.s[0] ... v31.s[0]| - // | ... ... | | ... ... | - // |v1.b[12] ... v1.b[15]| |v17.s[3] ...
v31.s[3]| - // \---------------------/ \-----------------------------------------/ - // int32 accumulators 8x8 block - // - // In the RUY_OPT_MAX_STREAMING part of the kernel, this elementary step - // is repeated 4 times, using 4x more registers for LHS and RHS, so that - // is where instead of using v0 -- v3 for LHS and RHS, we use v0 -- v15. - // - // Outside of the RUY_OPT_MAX_STREAMING part of the kernel, v4 -- v7 are - // unused, and v8 -- v15 are used for loading parameters used for the - // post-accumulation part of the kernel. - asm volatile( -#define RUY_MAKE_ZERO(reg) "dup " #reg ".4s, wzr\n" - - // clang-format off - - // Load some parameters into registers. - "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" - "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" - "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" - "ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n" - "ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n" - "ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n" - "ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" - "ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n" - - // Load the first 32 bytes of LHS and RHS data. - "ld1 {v0.16b}, [%[lhs_ptr]], #16\n" - "ld1 {v1.16b}, [%[lhs_ptr]], #16\n" - "ld1 {v2.16b}, [%[rhs_ptr]], #16\n" - "ld1 {v3.16b}, [%[rhs_ptr]], #16\n" - - // Clear accumulators. - RUY_MAKE_ZERO(v16) - RUY_MAKE_ZERO(v17) - RUY_MAKE_ZERO(v18) - RUY_MAKE_ZERO(v19) - RUY_MAKE_ZERO(v20) - RUY_MAKE_ZERO(v21) - RUY_MAKE_ZERO(v22) - RUY_MAKE_ZERO(v23) - RUY_MAKE_ZERO(v24) - RUY_MAKE_ZERO(v25) - RUY_MAKE_ZERO(v26) - RUY_MAKE_ZERO(v27) - RUY_MAKE_ZERO(v28) - RUY_MAKE_ZERO(v29) - RUY_MAKE_ZERO(v30) - RUY_MAKE_ZERO(v31) - - // w1 is the number of levels of depth that we have already loaded - // LHS and RHS data for. Corresponding to the initial ld1 instructions - // above, this is currently 4. - "mov w1, #4\n" - - // Perform the first few multiply-adds on the data that we have already - // loaded. - ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" - ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n" - ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n" - ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n" - - // Main loop of the whole GEMM, over rows and columns of the - // destination matrix. - "1:\n" - - // Optional, maximally-streaming, partial-unrolling (4x unrolled) - // optimization of the kernel inner loop (over depth). For more - // comments, see the non-unrolled loop below after the #endif. 
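The multiply-adds in this kernel are emitted as raw .word encodings of the dotprod instruction sdot, with the intended mnemonic kept in each comment (presumably so the file still assembles with toolchains that do not know the dotprod extension). The by-element form used throughout, sdot vd.4s, vn.16b, vm.4b[idx], behaves like the scalar sketch below; the function name is illustrative only.

#include <cstdint>

// Scalar model of "sdot vd.4s, vn.16b, vm.4b[idx]": each of the 4 int32 lanes
// of vd accumulates the dot product of its group of 4 int8 values from vn
// with the 4 int8 values of group 'idx' of vm.
void SdotByElement(std::int32_t vd[4], const std::int8_t vn[16],
                   const std::int8_t vm[16], int idx) {
  for (int lane = 0; lane < 4; ++lane) {
    std::int32_t sum = 0;
    for (int k = 0; k < 4; ++k) {
      sum += static_cast<std::int32_t>(vn[4 * lane + k]) *
             static_cast<std::int32_t>(vm[4 * idx + k]);
    }
    vd[lane] += sum;
  }
}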
-#if RUY_OPT_ENABLED(RUY_OPT_MAX_STREAMING) - "cmp w12, #32\n" - "blt 78f\n" - - "ld1 {v4.16b}, [%[lhs_ptr]], #16\n" - "ld1 {v5.16b}, [%[lhs_ptr]], #16\n" - "ld1 {v6.16b}, [%[rhs_ptr]], #16\n" - "ld1 {v7.16b}, [%[rhs_ptr]], #16\n" - "ld1 {v8.16b}, [%[lhs_ptr]], #16\n" - "ld1 {v9.16b}, [%[lhs_ptr]], #16\n" - "ld1 {v10.16b}, [%[rhs_ptr]], #16\n" - "ld1 {v11.16b}, [%[rhs_ptr]], #16\n" - "ld1 {v12.16b}, [%[lhs_ptr]], #16\n" - "ld1 {v13.16b}, [%[lhs_ptr]], #16\n" - "ld1 {v14.16b}, [%[rhs_ptr]], #16\n" - "ld1 {v15.16b}, [%[rhs_ptr]], #16\n" - "mov w1, #16\n" - - "and w3, w12, #-16\n" - "81:\n" - "add w1, w1, #16\n" - - ".word 0x4f83e018 // sdot v24.4s, v0.16b, v3.4b[0]\n" - ".word 0x4fa3e01a // sdot v26.4s, v0.16b, v3.4b[1]\n" - ".word 0x4f83e81c // sdot v28.4s, v0.16b, v3.4b[2]\n" - ".word 0x4fa3e81e // sdot v30.4s, v0.16b, v3.4b[3]\n" - "ldr q0, [%[lhs_ptr], #0]\n" - ".word 0x4f82e031 // sdot v17.4s, v1.16b, v2.4b[0]\n" - ".word 0x4fa2e033 // sdot v19.4s, v1.16b, v2.4b[1]\n" - ".word 0x4f82e835 // sdot v21.4s, v1.16b, v2.4b[2]\n" - ".word 0x4fa2e837 // sdot v23.4s, v1.16b, v2.4b[3]\n" - "ldr q2, [%[rhs_ptr], #0]\n" - ".word 0x4f83e039 // sdot v25.4s, v1.16b, v3.4b[0]\n" - ".word 0x4fa3e03b // sdot v27.4s, v1.16b, v3.4b[1]\n" - ".word 0x4f83e83d // sdot v29.4s, v1.16b, v3.4b[2]\n" - ".word 0x4fa3e83f // sdot v31.4s, v1.16b, v3.4b[3]\n" - "ldr q1, [%[lhs_ptr], #16]\n" - - ".word 0x4f87e098 // sdot v24.4s, v4.16b, v7.4b[0]\n" - ".word 0x4fa7e09a // sdot v26.4s, v4.16b, v7.4b[1]\n" - "ldr q3, [%[rhs_ptr], #16]\n" - ".word 0x4f87e89c // sdot v28.4s, v4.16b, v7.4b[2]\n" - ".word 0x4fa7e89e // sdot v30.4s, v4.16b, v7.4b[3]\n" - ".word 0x4f86e0b1 // sdot v17.4s, v5.16b, v6.4b[0]\n" - ".word 0x4fa6e0b3 // sdot v19.4s, v5.16b, v6.4b[1]\n" - ".word 0x4f86e8b5 // sdot v21.4s, v5.16b, v6.4b[2]\n" - ".word 0x4fa6e8b7 // sdot v23.4s, v5.16b, v6.4b[3]\n" - ".word 0x4f87e0b9 // sdot v25.4s, v5.16b, v7.4b[0]\n" - ".word 0x4fa7e0bb // sdot v27.4s, v5.16b, v7.4b[1]\n" - ".word 0x4f87e8bd // sdot v29.4s, v5.16b, v7.4b[2]\n" - ".word 0x4fa7e8bf // sdot v31.4s, v5.16b, v7.4b[3]\n" - "ldr q5, [%[lhs_ptr], #48]\n" - ".word 0x4f86e090 // sdot v16.4s, v4.16b, v6.4b[0]\n" - ".word 0x4fa6e092 // sdot v18.4s, v4.16b, v6.4b[1]\n" - "ldr q7, [%[rhs_ptr], #48]\n" - ".word 0x4f86e894 // sdot v20.4s, v4.16b, v6.4b[2]\n" - ".word 0x4fa6e896 // sdot v22.4s, v4.16b, v6.4b[3]\n" - "ldr q4, [%[lhs_ptr], #32]\n" - - ".word 0x4f8be118 // sdot v24.4s, v8.16b, v11.4b[0]\n" - ".word 0x4fabe11a // sdot v26.4s, v8.16b, v11.4b[1]\n" - "ldr q6, [%[rhs_ptr], #32]\n" - ".word 0x4f8be91c // sdot v28.4s, v8.16b, v11.4b[2]\n" - ".word 0x4fabe91e // sdot v30.4s, v8.16b, v11.4b[3]\n" - ".word 0x4f8ae131 // sdot v17.4s, v9.16b, v10.4b[0]\n" - ".word 0x4faae133 // sdot v19.4s, v9.16b, v10.4b[1]\n" - ".word 0x4f8ae935 // sdot v21.4s, v9.16b, v10.4b[2]\n" - ".word 0x4faae937 // sdot v23.4s, v9.16b, v10.4b[3]\n" - ".word 0x4f8be139 // sdot v25.4s, v9.16b, v11.4b[0]\n" - ".word 0x4fabe13b // sdot v27.4s, v9.16b, v11.4b[1]\n" - ".word 0x4f8be93d // sdot v29.4s, v9.16b, v11.4b[2]\n" - ".word 0x4fabe93f // sdot v31.4s, v9.16b, v11.4b[3]\n" - "ldr q9, [%[lhs_ptr], #80]\n" - ".word 0x4f8ae110 // sdot v16.4s, v8.16b, v10.4b[0]\n" - ".word 0x4faae112 // sdot v18.4s, v8.16b, v10.4b[1]\n" - "ldr q11, [%[rhs_ptr], #80]\n" - ".word 0x4f8ae914 // sdot v20.4s, v8.16b, v10.4b[2]\n" - ".word 0x4faae916 // sdot v22.4s, v8.16b, v10.4b[3]\n" - "ldr q8, [%[lhs_ptr], #64]\n" - - ".word 0x4f8fe198 // sdot v24.4s, v12.16b, v15.4b[0]\n" - ".word 0x4fafe19a // sdot v26.4s, 
v12.16b, v15.4b[1]\n" - "ldr q10, [%[rhs_ptr], #64]\n" - ".word 0x4f8fe99c // sdot v28.4s, v12.16b, v15.4b[2]\n" - ".word 0x4fafe99e // sdot v30.4s, v12.16b, v15.4b[3]\n" - "add %[lhs_ptr], %[lhs_ptr], #128\n" - ".word 0x4f8ee1b1 // sdot v17.4s, v13.16b, v14.4b[0]\n" - ".word 0x4faee1b3 // sdot v19.4s, v13.16b, v14.4b[1]\n" - "add %[rhs_ptr], %[rhs_ptr], #128\n" - ".word 0x4f8ee9b5 // sdot v21.4s, v13.16b, v14.4b[2]\n" - ".word 0x4faee9b7 // sdot v23.4s, v13.16b, v14.4b[3]\n" - ".word 0x4f8fe1b9 // sdot v25.4s, v13.16b, v15.4b[0]\n" - ".word 0x4fafe1bb // sdot v27.4s, v13.16b, v15.4b[1]\n" - "cmp w1, w3\n" - ".word 0x4f8fe9bd // sdot v29.4s, v13.16b, v15.4b[2]\n" - ".word 0x4fafe9bf // sdot v31.4s, v13.16b, v15.4b[3]\n" - "ldr q13, [%[lhs_ptr], #-16]\n" - ".word 0x4f8ee190 // sdot v16.4s, v12.16b, v14.4b[0]\n" - ".word 0x4faee192 // sdot v18.4s, v12.16b, v14.4b[1]\n" - "ldr q15, [%[rhs_ptr], #-16]\n" - ".word 0x4f8ee994 // sdot v20.4s, v12.16b, v14.4b[2]\n" - ".word 0x4faee996 // sdot v22.4s, v12.16b, v14.4b[3]\n" - "ldr q12, [%[lhs_ptr], #-32]\n" - - ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" - ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n" - "ldr q14, [%[rhs_ptr], #-32]\n" - ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n" - ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n" - - "blt 81b\n" - - ".word 0x4f87e098 // sdot v24.4s, v4.16b, v7.4b[0]\n" - ".word 0x4fa7e09a // sdot v26.4s, v4.16b, v7.4b[1]\n" - ".word 0x4f87e89c // sdot v28.4s, v4.16b, v7.4b[2]\n" - ".word 0x4fa7e89e // sdot v30.4s, v4.16b, v7.4b[3]\n" - ".word 0x4f86e0b1 // sdot v17.4s, v5.16b, v6.4b[0]\n" - ".word 0x4fa6e0b3 // sdot v19.4s, v5.16b, v6.4b[1]\n" - ".word 0x4f86e8b5 // sdot v21.4s, v5.16b, v6.4b[2]\n" - ".word 0x4fa6e8b7 // sdot v23.4s, v5.16b, v6.4b[3]\n" - ".word 0x4f87e0b9 // sdot v25.4s, v5.16b, v7.4b[0]\n" - ".word 0x4fa7e0bb // sdot v27.4s, v5.16b, v7.4b[1]\n" - ".word 0x4f87e8bd // sdot v29.4s, v5.16b, v7.4b[2]\n" - ".word 0x4fa7e8bf // sdot v31.4s, v5.16b, v7.4b[3]\n" - ".word 0x4f86e090 // sdot v16.4s, v4.16b, v6.4b[0]\n" - ".word 0x4fa6e092 // sdot v18.4s, v4.16b, v6.4b[1]\n" - ".word 0x4f86e894 // sdot v20.4s, v4.16b, v6.4b[2]\n" - ".word 0x4fa6e896 // sdot v22.4s, v4.16b, v6.4b[3]\n" - - ".word 0x4f8be118 // sdot v24.4s, v8.16b, v11.4b[0]\n" - ".word 0x4fabe11a // sdot v26.4s, v8.16b, v11.4b[1]\n" - ".word 0x4f8be91c // sdot v28.4s, v8.16b, v11.4b[2]\n" - ".word 0x4fabe91e // sdot v30.4s, v8.16b, v11.4b[3]\n" - ".word 0x4f8ae131 // sdot v17.4s, v9.16b, v10.4b[0]\n" - ".word 0x4faae133 // sdot v19.4s, v9.16b, v10.4b[1]\n" - ".word 0x4f8ae935 // sdot v21.4s, v9.16b, v10.4b[2]\n" - ".word 0x4faae937 // sdot v23.4s, v9.16b, v10.4b[3]\n" - ".word 0x4f8be139 // sdot v25.4s, v9.16b, v11.4b[0]\n" - ".word 0x4fabe13b // sdot v27.4s, v9.16b, v11.4b[1]\n" - ".word 0x4f8be93d // sdot v29.4s, v9.16b, v11.4b[2]\n" - ".word 0x4fabe93f // sdot v31.4s, v9.16b, v11.4b[3]\n" - ".word 0x4f8ae110 // sdot v16.4s, v8.16b, v10.4b[0]\n" - ".word 0x4faae112 // sdot v18.4s, v8.16b, v10.4b[1]\n" - ".word 0x4f8ae914 // sdot v20.4s, v8.16b, v10.4b[2]\n" - ".word 0x4faae916 // sdot v22.4s, v8.16b, v10.4b[3]\n" - - ".word 0x4f8fe198 // sdot v24.4s, v12.16b, v15.4b[0]\n" - ".word 0x4fafe19a // sdot v26.4s, v12.16b, v15.4b[1]\n" - ".word 0x4f8fe99c // sdot v28.4s, v12.16b, v15.4b[2]\n" - ".word 0x4fafe99e // sdot v30.4s, v12.16b, v15.4b[3]\n" - ".word 0x4f8ee1b1 // sdot v17.4s, v13.16b, v14.4b[0]\n" - ".word 0x4faee1b3 // sdot v19.4s, v13.16b, v14.4b[1]\n" - ".word 0x4f8ee9b5 // sdot v21.4s, v13.16b, 
v14.4b[2]\n" - ".word 0x4faee9b7 // sdot v23.4s, v13.16b, v14.4b[3]\n" - ".word 0x4f8fe1b9 // sdot v25.4s, v13.16b, v15.4b[0]\n" - ".word 0x4fafe1bb // sdot v27.4s, v13.16b, v15.4b[1]\n" - ".word 0x4f8fe9bd // sdot v29.4s, v13.16b, v15.4b[2]\n" - ".word 0x4fafe9bf // sdot v31.4s, v13.16b, v15.4b[3]\n" - ".word 0x4f8ee190 // sdot v16.4s, v12.16b, v14.4b[0]\n" - ".word 0x4faee192 // sdot v18.4s, v12.16b, v14.4b[1]\n" - ".word 0x4f8ee994 // sdot v20.4s, v12.16b, v14.4b[2]\n" - ".word 0x4faee996 // sdot v22.4s, v12.16b, v14.4b[3]\n" - - "78:\n" - -#endif // #if RUY_OPT_ENABLED(RUY_OPT_MAX_STREAMING) - - // Ordinary kernel inner loop (over depth), the simpler loop that the - // above was an equivalent 4x-partially-unrolled version of. - - // Reminder - w1 is how many levels of depth we have already loaded - // data for, w12 is the total depth. - "cmp w1, w12\n" - "beq 79f\n" - - "2:\n" - - // Because of the data that we have already loaded, we can start the - // loop body right away with some multiply-adds. - ".word 0x4f83e018 // sdot v24.4s, v0.16b, v3.4b[0]\n" - ".word 0x4fa3e01a // sdot v26.4s, v0.16b, v3.4b[1]\n" - // Each iteration of this loop advances by 4 levels of depth. - "add w1, w1, #4\n" - ".word 0x4f83e81c // sdot v28.4s, v0.16b, v3.4b[2]\n" - ".word 0x4fa3e81e // sdot v30.4s, v0.16b, v3.4b[3]\n" - "ld1 {v0.16b}, [%[lhs_ptr]], #16\n" - ".word 0x4f82e031 // sdot v17.4s, v1.16b, v2.4b[0]\n" - ".word 0x4fa2e033 // sdot v19.4s, v1.16b, v2.4b[1]\n" - // Loop termination condition. - "cmp w1, w12\n" - ".word 0x4f82e835 // sdot v21.4s, v1.16b, v2.4b[2]\n" - ".word 0x4fa2e837 // sdot v23.4s, v1.16b, v2.4b[3]\n" - "ld1 {v2.16b}, [%[rhs_ptr]], #16\n" - ".word 0x4f83e039 // sdot v25.4s, v1.16b, v3.4b[0]\n" - ".word 0x4fa3e03b // sdot v27.4s, v1.16b, v3.4b[1]\n" - ".word 0x4f83e83d // sdot v29.4s, v1.16b, v3.4b[2]\n" - ".word 0x4fa3e83f // sdot v31.4s, v1.16b, v3.4b[3]\n" - "ld1 {v3.16b}, [%[rhs_ptr]], #16\n" - ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" - ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n" - ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n" - ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n" - "ld1 {v1.16b}, [%[lhs_ptr]], #16\n" - - "blt 2b\n" - - "79:\n" - // End of the inner loop on depth. Now perform the remaining - // multiply-adds of the last 4 levels of depth, for which the LHS - // and RHS data is already loaded. - - ".word 0x4f83e018 // sdot v24.4s, v0.16b, v3.4b[0]\n" - ".word 0x4fa3e01a // sdot v26.4s, v0.16b, v3.4b[1]\n" - ".word 0x4f83e81c // sdot v28.4s, v0.16b, v3.4b[2]\n" - ".word 0x4fa3e81e // sdot v30.4s, v0.16b, v3.4b[3]\n" - ".word 0x4f82e031 // sdot v17.4s, v1.16b, v2.4b[0]\n" - ".word 0x4fa2e033 // sdot v19.4s, v1.16b, v2.4b[1]\n" - ".word 0x4f82e835 // sdot v21.4s, v1.16b, v2.4b[2]\n" - ".word 0x4fa2e837 // sdot v23.4s, v1.16b, v2.4b[3]\n" - ".word 0x4f83e039 // sdot v25.4s, v1.16b, v3.4b[0]\n" - ".word 0x4fa3e03b // sdot v27.4s, v1.16b, v3.4b[1]\n" - ".word 0x4f83e83d // sdot v29.4s, v1.16b, v3.4b[2]\n" - ".word 0x4fa3e83f // sdot v31.4s, v1.16b, v3.4b[3]\n" - - // End of accumulation. The registers v16 -- v31 contain the final - // int32 accumulator values of the current 8x8 destination block. - // We now have to compute the final 8-bit values from these int32 - // accumulators, and advance to the next 8x8 block. 
We intertwine - // these two aspects whenever possible for optimal pipelining, both - // at the data flow level (prefetch data for next block as early as - // possible) and instruction pipelining level (some of the next-block - // work can dual-issue with some of the final work on the current - // block). - - // Logic to advance to the next block in preparation for the next - // iteration of the main loop. For now, we only want to compute - // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are - // not yet ready to update the values of row and col, as we still need - // the current values for the rest of the work on the current block. - - "cmp %w[row], w7\n" // Have we finished the last row? - "bge 4f\n" // If finished last row, go to 4 - // Not finished last row: then advance to next row. - "add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #3\n" - "b 5f\n" - "4:\n" // Finished last row... - "mov %[lhs_col_ptr], x5\n" // Go back to first row - // Now we need to advance to the next column. If we already - // finished the last column, then in principle we are done, however - // we can't just return here, as we need to allow the end work of the - // current block to complete. The good news is that at this point it - // doesn't matter what data we load for the next column, since - // we will exit from the main loop below before actually storing - // anything computed from that data. - "cmp %w[col], w8\n" // Have we finished the last column? - "bge 5f\n" // If yes, just carry on without updating the column pointer. - // Not finished last column: then advance to next column. - "add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #3\n" - "5:\n" - - // Set the LHS and RHS data pointers to the start of the columns just - // computed. - "mov %[lhs_ptr], %[lhs_col_ptr]\n" - "mov %[rhs_ptr], %[rhs_col_ptr]\n" - - // Load some parameters needed for the end work on current block. - RUY_MAKE_ZERO(v8) - "ldr w4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n" - "ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_PROD_ZP_DEPTH) "]\n" - "ins v13.h[4], w4\n" // dst_zero_point - "ldr x4, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n" - "ldrb w6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" - "dup v9.4s, w3\n" // create prod_zp_depth_vec - "add x5, x4, %x[row], lsl #2\n" - "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n" - "csel x4, x4, x5, eq\n" - - "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n" - "add x5, x1, %x[row], lsl #2\n" - - "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n" - "csel x1, x1, x5, eq\n" - - // Load 8 bias values. - "ld1 {v14.4s}, [x1], #16\n" - "ld1 {v15.4s}, [x1]\n" - - // Now that we know what LHS and RHS data the next iteration of the - // main loop will need to load, we start loading the first 32 bytes of - // each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore - // in the rest of the work on the current block. - "ld1 {v0.16b}, [%[lhs_ptr]], #16\n" - "ld1 {v1.16b}, [%[lhs_ptr]], #16\n" - "ld1 {v2.16b}, [%[rhs_ptr]], #16\n" - "ld1 {v3.16b}, [%[rhs_ptr]], #16\n" - - // Add to the bias values the product (depth * lhs_zero_point * rhs_zero_point), - // See the term NZ1Z2 in equation (7) in https://arxiv.org/pdf/1712.05877.pdf - "add v14.4s, v14.4s, v9.4s\n" - "add v15.4s, v15.4s, v9.4s\n" - - // Perform the bias-addition (per the above, we have just folded into - // the bias the (depth * lhs_zero_point * rhs_zero_point) term.) 
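Per destination element, the bias addition just below together with the optional RHS_SUMS and LHS_SUMS corrections that follow implements equation (7) of https://arxiv.org/pdf/1712.05877.pdf. A scalar sketch with illustrative parameter names (not ruy's actual code path):

#include <cstdint>

// The accumulator holds sum_k lhs(r,k) * rhs(k,c) over raw int8 values; the
// corrected value is the same sum over (lhs - lhs_zero_point) and
// (rhs - rhs_zero_point), plus the bias.
std::int32_t ApplyZeroPointCorrections(std::int32_t accum, std::int32_t bias,
                                       std::int32_t lhs_sum_of_row,
                                       std::int32_t rhs_sum_of_col,
                                       std::int32_t lhs_zero_point,
                                       std::int32_t rhs_zero_point,
                                       std::int32_t depth) {
  // prod_zp_depth = depth * lhs_zero_point * rhs_zero_point, the NZ1Z2 term,
  // which the code above folds into the bias vector before the bias-addition.
  const std::int32_t prod_zp_depth = depth * lhs_zero_point * rhs_zero_point;
  std::int32_t result = accum + bias + prod_zp_depth;
  // Subtract rhs_sums * lhs_zero_point and lhs_sums * rhs_zero_point, as the
  // MLS and MUL+SUB sequences below do vector-wide.
  result -= lhs_zero_point * rhs_sum_of_col;
  result -= rhs_zero_point * lhs_sum_of_row;
  return result;
}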
- "add v16.4s, v16.4s, v14.4s\n" - "add v17.4s, v17.4s, v15.4s\n" - "add v18.4s, v18.4s, v14.4s\n" - "add v19.4s, v19.4s, v15.4s\n" - "add v20.4s, v20.4s, v14.4s\n" - "add v21.4s, v21.4s, v15.4s\n" - "add v22.4s, v22.4s, v14.4s\n" - "add v23.4s, v23.4s, v15.4s\n" - "add v24.4s, v24.4s, v14.4s\n" - "add v25.4s, v25.4s, v15.4s\n" - "add v26.4s, v26.4s, v14.4s\n" - "add v27.4s, v27.4s, v15.4s\n" - "add v28.4s, v28.4s, v14.4s\n" - "add v29.4s, v29.4s, v15.4s\n" - "add v30.4s, v30.4s, v14.4s\n" - "add v31.4s, v31.4s, v15.4s\n" - - "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_RHS_SUMS) "\n" - "beq 401f\n" - "ldr x3, [%[params], #" RUY_STR(RUY_OFFSET_RHS_SUMS) "]\n" - "add x3, x3, %x[col], lsl #2\n" - "ld1 {v14.4s}, [x3], #16\n" - "ld1 {v15.4s}, [x3]\n" - "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_ZERO_POINT) "]\n" - "dup v10.4s, w5\n" // create lhs_zero_point_vec - // Subtract rhs_sums * lhs_zero_point, per - // equation (7) in https://arxiv.org/pdf/1712.05877.pdf - "mls v16.4s, v10.4s, v14.s[0]\n" - "mls v17.4s, v10.4s, v14.s[0]\n" - "mls v18.4s, v10.4s, v14.s[1]\n" - "mls v19.4s, v10.4s, v14.s[1]\n" - "mls v20.4s, v10.4s, v14.s[2]\n" - "mls v21.4s, v10.4s, v14.s[2]\n" - "mls v22.4s, v10.4s, v14.s[3]\n" - "mls v23.4s, v10.4s, v14.s[3]\n" - "mls v24.4s, v10.4s, v15.s[0]\n" - "mls v25.4s, v10.4s, v15.s[0]\n" - "mls v26.4s, v10.4s, v15.s[1]\n" - "mls v27.4s, v10.4s, v15.s[1]\n" - "mls v28.4s, v10.4s, v15.s[2]\n" - "mls v29.4s, v10.4s, v15.s[2]\n" - "mls v30.4s, v10.4s, v15.s[3]\n" - "mls v31.4s, v10.4s, v15.s[3]\n" - "401:\n" - - "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_LHS_SUMS) "\n" - "beq 402f\n" - "ldr x2, [%[params], #" RUY_STR(RUY_OFFSET_LHS_SUMS) "]\n" - "add x2, x2, %x[row], lsl #2\n" - "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_RHS_ZERO_POINT) "]\n" - // Load 4 lhs_sums values. - "ld1 {v11.4s}, [x2], #16\n" - "ld1 {v12.4s}, [x2]\n" - "ins v13.s[1], w5\n" // rhs_zero_point - // Compute lhs_sums * rhs_zero_point. - "mul v11.4s, v11.4s, v13.s[1]\n" - "mul v12.4s, v12.4s, v13.s[1]\n" - // Subtract lhs_sums * rhs_zero_point, per - // equation (7) in https://arxiv.org/pdf/1712.05877.pdf - "sub v16.4s, v16.4s, v11.4s\n" - "sub v17.4s, v17.4s, v12.4s\n" - "sub v18.4s, v18.4s, v11.4s\n" - "sub v19.4s, v19.4s, v12.4s\n" - "sub v20.4s, v20.4s, v11.4s\n" - "sub v21.4s, v21.4s, v12.4s\n" - "sub v22.4s, v22.4s, v11.4s\n" - "sub v23.4s, v23.4s, v12.4s\n" - "sub v24.4s, v24.4s, v11.4s\n" - "sub v25.4s, v25.4s, v12.4s\n" - "sub v26.4s, v26.4s, v11.4s\n" - "sub v27.4s, v27.4s, v12.4s\n" - "sub v28.4s, v28.4s, v11.4s\n" - "sub v29.4s, v29.4s, v12.4s\n" - "sub v30.4s, v30.4s, v11.4s\n" - "sub v31.4s, v31.4s, v12.4s\n" - - "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT32) "\n" - "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT32) "f\n" - - "402:\n" - - // At this point we have computed the final int32 values. Now we - // start down-quantizing them to obtain the final 8bit values from them. - - // As part of this down-quantization, our int32 values will be - // multiplied by a multiplier that has a fixed-point component and an - // exponent component. - - //Load the exponent part of the multiplier. 
- "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_EXPONENT) "]\n" - "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n" - "add x5, x1, %x[row], lsl #2\n" - "csel x1, x1, x5, eq\n" - - "ldr q9, [x1]\n" - "ldr q10, [x1, #16]\n" - - "tst w6, #" RUY_STR(RUY_ASM_FLAG_NEEDS_LEFT_SHIFT) "\n" - "beq 403f\n" - "smax v11.4s, v9.4s, v8.4s\n" - "smax v12.4s, v10.4s, v8.4s\n" - "sshl v16.4s, v16.4s, v11.4s\n" - "sshl v17.4s, v17.4s, v12.4s\n" - "sshl v18.4s, v18.4s, v11.4s\n" - "sshl v19.4s, v19.4s, v12.4s\n" - "sshl v20.4s, v20.4s, v11.4s\n" - "sshl v21.4s, v21.4s, v12.4s\n" - "sshl v22.4s, v22.4s, v11.4s\n" - "sshl v23.4s, v23.4s, v12.4s\n" - "sshl v24.4s, v24.4s, v11.4s\n" - "sshl v25.4s, v25.4s, v12.4s\n" - "sshl v26.4s, v26.4s, v11.4s\n" - "sshl v27.4s, v27.4s, v12.4s\n" - "sshl v28.4s, v28.4s, v11.4s\n" - "sshl v29.4s, v29.4s, v12.4s\n" - "sshl v30.4s, v30.4s, v11.4s\n" - "sshl v31.4s, v31.4s, v12.4s\n" - "403:\n" - - "ldr q14, [x4]\n" // multiplier_fixedpoint - "ldr q15, [x4, #16]\n" // multiplier_fixedpoint - - "smin v11.4s, v9.4s, v8.4s\n" - "smin v12.4s, v10.4s, v8.4s\n" - - // Apply the fixed-point part of the multiplier. - "sqrdmulh v16.4s, v16.4s, v14.4s\n" - "sqrdmulh v17.4s, v17.4s, v15.4s\n" - "sqrdmulh v18.4s, v18.4s, v14.4s\n" - "sqrdmulh v19.4s, v19.4s, v15.4s\n" - "sqrdmulh v20.4s, v20.4s, v14.4s\n" - "sqrdmulh v21.4s, v21.4s, v15.4s\n" - "sqrdmulh v22.4s, v22.4s, v14.4s\n" - "sqrdmulh v23.4s, v23.4s, v15.4s\n" - "sqrdmulh v24.4s, v24.4s, v14.4s\n" - "sqrdmulh v25.4s, v25.4s, v15.4s\n" - "sqrdmulh v26.4s, v26.4s, v14.4s\n" - "sqrdmulh v27.4s, v27.4s, v15.4s\n" - "sqrdmulh v28.4s, v28.4s, v14.4s\n" - "sqrdmulh v29.4s, v29.4s, v15.4s\n" - "sqrdmulh v30.4s, v30.4s, v14.4s\n" - "sqrdmulh v31.4s, v31.4s, v15.4s\n" - - // We have some rounding division-by-power-of-two to do. This should - // always use "round to nearest". We allow for some - // freedom in how ties are broken, to strike a good compromise of - // performance on given hardware vs. perfect agreement of results - // across hardware. - // - // When RUY_OPT_NATIVE_ROUNDING is enabled, we allow for implementation - // defined tie-breaks to help performance. On NEON, this means that we - // can just use the NEON rounding instructions, such as srshl. They - // happen to be breaking ties upward. - // - // When RUY_OPT_NATIVE_ROUNDING is disabled, we implement strict - // break-ties-away-from zero, as described in Appendix B of - // https://arxiv.org/pdf/1712.05877.pdf - // When we wrote that, we thought that that would be better unbiased - // than the NEON upwards tie-breaks, and we had observed some - // improvement on some model. However, that is only more unbiased for - // data centered at zero, which was likely the case in that model, - // but is not always the case. If we wanted something more consistently - // unbiased then we should try breaking ties toward-nearest-even. -#if !RUY_OPT_ENABLED(RUY_OPT_NATIVE_ROUNDING) - // Fix up values to be right-shifted, so that the (round to nearest, - // break ties upward) behavior of srshl applied to these fixed-up - // values, produces the same result as the desired (round to nearest, - // break ties away from zero) behavior on the original values. 
- "and v8.16b, v16.16b, v11.16b\n" - "and v9.16b, v17.16b, v12.16b\n" - "and v14.16b, v18.16b, v11.16b\n" - "and v15.16b, v19.16b, v12.16b\n" - "sshr v8.4s, v8.4s, #31\n" - "sshr v9.4s, v9.4s, #31\n" - "sshr v14.4s, v14.4s, #31\n" - "sshr v15.4s, v15.4s, #31\n" - "sqadd v16.4s, v16.4s, v8.4s\n" - "sqadd v17.4s, v17.4s, v9.4s\n" - "sqadd v18.4s, v18.4s, v14.4s\n" - "sqadd v19.4s, v19.4s, v15.4s\n" - "and v8.16b, v20.16b, v11.16b\n" - "and v9.16b, v21.16b, v12.16b\n" - "and v14.16b, v22.16b, v11.16b\n" - "and v15.16b, v23.16b, v12.16b\n" - "sshr v8.4s, v8.4s, #31\n" - "sshr v9.4s, v9.4s, #31\n" - "sshr v14.4s, v14.4s, #31\n" - "sshr v15.4s, v15.4s, #31\n" - "sqadd v20.4s, v20.4s, v8.4s\n" - "sqadd v21.4s, v21.4s, v9.4s\n" - "sqadd v22.4s, v22.4s, v14.4s\n" - "sqadd v23.4s, v23.4s, v15.4s\n" - "and v8.16b, v24.16b, v11.16b\n" - "and v9.16b, v25.16b, v12.16b\n" - "and v14.16b, v26.16b, v11.16b\n" - "and v15.16b, v27.16b, v12.16b\n" - "sshr v8.4s, v8.4s, #31\n" - "sshr v9.4s, v9.4s, #31\n" - "sshr v14.4s, v14.4s, #31\n" - "sshr v15.4s, v15.4s, #31\n" - "sqadd v24.4s, v24.4s, v8.4s\n" - "sqadd v25.4s, v25.4s, v9.4s\n" - "sqadd v26.4s, v26.4s, v14.4s\n" - "sqadd v27.4s, v27.4s, v15.4s\n" - "and v8.16b, v28.16b, v11.16b\n" - "and v9.16b, v29.16b, v12.16b\n" - "and v14.16b, v30.16b, v11.16b\n" - "and v15.16b, v31.16b, v12.16b\n" - "sshr v8.4s, v8.4s, #31\n" - "sshr v9.4s, v9.4s, #31\n" - "sshr v14.4s, v14.4s, #31\n" - "sshr v15.4s, v15.4s, #31\n" - "sqadd v28.4s, v28.4s, v8.4s\n" - "sqadd v29.4s, v29.4s, v9.4s\n" - "sqadd v30.4s, v30.4s, v14.4s\n" - "sqadd v31.4s, v31.4s, v15.4s\n" -#endif - // At this point we have reduced the problem of correctly implementing - // rounding divide-by-power-of-two, to what the SRSHL instruction can - // do. - "srshl v16.4s, v16.4s, v11.4s\n" - "srshl v17.4s, v17.4s, v12.4s\n" - "srshl v18.4s, v18.4s, v11.4s\n" - "srshl v19.4s, v19.4s, v12.4s\n" - "srshl v20.4s, v20.4s, v11.4s\n" - "srshl v21.4s, v21.4s, v12.4s\n" - "srshl v22.4s, v22.4s, v11.4s\n" - "srshl v23.4s, v23.4s, v12.4s\n" - "srshl v24.4s, v24.4s, v11.4s\n" - "srshl v25.4s, v25.4s, v12.4s\n" - "srshl v26.4s, v26.4s, v11.4s\n" - "srshl v27.4s, v27.4s, v12.4s\n" - "srshl v28.4s, v28.4s, v11.4s\n" - "srshl v29.4s, v29.4s, v12.4s\n" - "srshl v30.4s, v30.4s, v11.4s\n" - "srshl v31.4s, v31.4s, v12.4s\n" - - "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT16) "\n" - "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT16) "f\n" - "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT8) "\n" - "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT8) "f\n" - - RUY_STR(RUY_ASM_LABEL_STORE_UINT8) ":\n" - - // Cast-and-saturate from int32 to int16 - "sqxtn v16.4h, v16.4s\n" - "sqxtn2 v16.8h, v17.4s\n" - "sqxtn v17.4h, v18.4s\n" - "sqxtn2 v17.8h, v19.4s\n" - "sqxtn v18.4h, v20.4s\n" - "sqxtn2 v18.8h, v21.4s\n" - "sqxtn v19.4h, v22.4s\n" - "sqxtn2 v19.8h, v23.4s\n" - "sqxtn v20.4h, v24.4s\n" - "sqxtn2 v20.8h, v25.4s\n" - "sqxtn v21.4h, v26.4s\n" - "sqxtn2 v21.8h, v27.4s\n" - "sqxtn v22.4h, v28.4s\n" - "sqxtn2 v22.8h, v29.4s\n" - "sqxtn v23.4h, v30.4s\n" - "sqxtn2 v23.8h, v31.4s\n" - - // At this point, v24 -- v31 aren't used anymore for the current block, - // so we can start clearing these accumulators for the next block - // (next iteration of the main loop). 
- RUY_MAKE_ZERO(v24) - RUY_MAKE_ZERO(v25) - RUY_MAKE_ZERO(v26) - RUY_MAKE_ZERO(v27) - RUY_MAKE_ZERO(v28) - RUY_MAKE_ZERO(v29) - RUY_MAKE_ZERO(v30) - RUY_MAKE_ZERO(v31) - - // Add the destination zero point - "dup v14.8h, v13.h[4]\n" - "add v16.8h, v16.8h, v14.8h\n" - "add v17.8h, v17.8h, v14.8h\n" - "add v18.8h, v18.8h, v14.8h\n" - "add v19.8h, v19.8h, v14.8h\n" - "add v20.8h, v20.8h, v14.8h\n" - "add v21.8h, v21.8h, v14.8h\n" - "add v22.8h, v22.8h, v14.8h\n" - "add v23.8h, v23.8h, v14.8h\n" - - // Cast-and-saturate from int16 to uint8 - "sqxtun v16.8b, v16.8h\n" - "sqxtun2 v16.16b, v17.8h\n" - "sqxtun v17.8b, v18.8h\n" - "sqxtun2 v17.16b, v19.8h\n" - "sqxtun v18.8b, v20.8h\n" - "sqxtun2 v18.16b, v21.8h\n" - "sqxtun v19.8b, v22.8h\n" - "sqxtun2 v19.16b, v23.8h\n" - - // Load the clamp_min, clamp_max bounds - "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" - "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" - "dup v14.16b, w2\n" // clamp_min - "dup v15.16b, w3\n" // clamp_max - - // Apply the clamp_min bound - "umax v16.16b, v16.16b, v14.16b\n" - "umax v17.16b, v17.16b, v14.16b\n" - "umax v18.16b, v18.16b, v14.16b\n" - "umax v19.16b, v19.16b, v14.16b\n" - - // Apply the clamp_max bound - "umin v16.16b, v16.16b, v15.16b\n" - "umin v17.16b, v17.16b, v15.16b\n" - "umin v18.16b, v18.16b, v15.16b\n" - "umin v19.16b, v19.16b, v15.16b\n" - - // Make it so that all of the final 8bit values are stored in the - // first 64bits of 128bit NEON registers, so they can be stored - // by 64bit st1 store instructions with byte alignment. - "dup d20, v16.d[1]\n" - "dup d21, v17.d[1]\n" - "dup d22, v18.d[1]\n" - "dup d23, v19.d[1]\n" - - // Compute how much of the 8x8 block of destination 8bit values that - // we have computed, fit in the destination matrix. Typically, all of - // it fits, but when the destination matrix shape is not a multiple - // of 8x8, there are some 8x8 blocks along the boundaries that do - // not fit entirely. - "sub w1, %w[dst_rows], %w[row]\n" - "sub w2, %w[dst_cols], %w[col]\n" - "mov w3, #8\n" - "cmp w1, #8\n" - // Compute w1 = how many rows of the 8x8 block fit - "csel w1, w1, w3, le\n" - "cmp w2, #8\n" - // Compute w2 = how many cols of the 8x8 block fit - "csel w2, w2, w3, le\n" - - // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits. - "cmp w1, w3\n" - "ccmp w2, w3, 0, eq\n" - // Yes, all of the 8x8 block fits, go to fast path. - "beq 30f\n" - // Not all of the 8x8 block fits. - // Set (x3 address, x4 stride) to write to dst_tmp_buf - "mov x3, %[dst_tmp_buf]\n" - "mov x4, #8\n" - "b 31f\n" - "30:\n" - // Yes, all of the 8x8 block fits. - // Set (x3 address, x4 stride) to write directly to destination matrix. - "mov x3, %[dst_ptr]\n" - "mov x4, x11\n" - "31:\n" - - // Write our 8bit values to the destination described by - // (x3 address, x4 stride). 
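The w1/w2 computation and fast/slow path selection above, together with the scalar copy loop further below, handle 8x8 blocks that straddle the destination matrix boundary. In scalar form the logic is roughly the sketch below, where rows_left = dst_rows - row, cols_left = dst_cols - col, the destination is column-major with dst_stride bytes between columns, and all names are illustrative rather than ruy's.

#include <algorithm>
#include <cstdint>
#include <cstring>

void StoreBlockClipped(const std::uint8_t block[8 * 8], std::uint8_t* dst_ptr,
                       int dst_stride, int rows_left, int cols_left) {
  constexpr int kBlockDim = 8;
  // w1 / w2 in the asm: how many rows / cols of the 8x8 block actually fit.
  const int rows_fit = std::min(kBlockDim, rows_left);
  const int cols_fit = std::min(kBlockDim, cols_left);
  if (rows_fit == kBlockDim && cols_fit == kBlockDim) {
    // Fast path: the whole block fits; store it column by column.
    for (int c = 0; c < kBlockDim; ++c) {
      std::memcpy(dst_ptr + c * dst_stride, block + c * kBlockDim, kBlockDim);
    }
  } else {
    // Slow path: the asm first stages the block in dst_tmp_buf, then copies
    // element by element only the part that fits, much like this loop.
    for (int c = 0; c < cols_fit; ++c) {
      for (int r = 0; r < rows_fit; ++r) {
        dst_ptr[r + c * dst_stride] = block[r + c * kBlockDim];
      }
    }
  }
}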
- RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v16.8b}, [x3], x4\n" - RUY_MAKE_ZERO(v16) - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v20.8b}, [x3], x4\n" - RUY_MAKE_ZERO(v20) - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v17.8b}, [x3], x4\n" - RUY_MAKE_ZERO(v17) - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v21.8b}, [x3], x4\n" - RUY_MAKE_ZERO(v21) - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v18.8b}, [x3], x4\n" - RUY_MAKE_ZERO(v18) - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v22.8b}, [x3], x4\n" - RUY_MAKE_ZERO(v22) - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v19.8b}, [x3], x4\n" - RUY_MAKE_ZERO(v19) - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v23.8b}, [x3], x4\n" - RUY_MAKE_ZERO(v23) - - // For the next block: perform the first few multiply-adds on the data - // that we have already loaded. - ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" - ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n" - ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n" - ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n" - - // If all of the 8x8 block fits, we just finished writing it to the - // destination, so we skip the next part. - "beq 41f\n" - // Not all of the 8x8 block fits in the destination matrix. We just - // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over - // it to copy into the destination matrix the part that fits. - "mov x3, %[dst_tmp_buf]\n" - "mov x4, %[dst_ptr]\n" - "mov w6, #0\n" - "50:\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov w5, #0\n" - "51:\n" - "ldrb w7, [x3, w5, uxtw]\n" - "strb w7, [x4, w5, uxtw]\n" - "add w5, w5, #1\n" - "cmp w5, w1\n" - "blt 51b\n" - "add w6, w6, #1\n" - "add x3, x3, #8\n" - "add x4, x4, x11\n" - "cmp w6, w2\n" - "blt 50b\n" - "41:\n" - "add %[dst_ptr], %[dst_ptr], #8\n" - // At this point we have completely finished writing values to the - // destination matrix for the current block. - - "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" - - RUY_STR(RUY_ASM_LABEL_STORE_INT8) ":\n" - - // Cast-and-saturate from int32 to int16 - "sqxtn v16.4h, v16.4s\n" - "sqxtn2 v16.8h, v17.4s\n" - "sqxtn v17.4h, v18.4s\n" - "sqxtn2 v17.8h, v19.4s\n" - "sqxtn v18.4h, v20.4s\n" - "sqxtn2 v18.8h, v21.4s\n" - "sqxtn v19.4h, v22.4s\n" - "sqxtn2 v19.8h, v23.4s\n" - "sqxtn v20.4h, v24.4s\n" - "sqxtn2 v20.8h, v25.4s\n" - "sqxtn v21.4h, v26.4s\n" - "sqxtn2 v21.8h, v27.4s\n" - "sqxtn v22.4h, v28.4s\n" - "sqxtn2 v22.8h, v29.4s\n" - "sqxtn v23.4h, v30.4s\n" - "sqxtn2 v23.8h, v31.4s\n" - - // At this point, v24 -- v31 aren't used anymore for the current block, - // so we can start clearing these accumulators for the next block - // (next iteration of the main loop). 
- RUY_MAKE_ZERO(v24) - RUY_MAKE_ZERO(v25) - RUY_MAKE_ZERO(v26) - RUY_MAKE_ZERO(v27) - RUY_MAKE_ZERO(v28) - RUY_MAKE_ZERO(v29) - RUY_MAKE_ZERO(v30) - RUY_MAKE_ZERO(v31) - - // Add the destination zero point - "dup v14.8h, v13.h[4]\n" - "add v16.8h, v16.8h, v14.8h\n" - "add v17.8h, v17.8h, v14.8h\n" - "add v18.8h, v18.8h, v14.8h\n" - "add v19.8h, v19.8h, v14.8h\n" - "add v20.8h, v20.8h, v14.8h\n" - "add v21.8h, v21.8h, v14.8h\n" - "add v22.8h, v22.8h, v14.8h\n" - "add v23.8h, v23.8h, v14.8h\n" - - // Cast-and-saturate from int16 to uint8 - "sqxtn v16.8b, v16.8h\n" - "sqxtn2 v16.16b, v17.8h\n" - "sqxtn v17.8b, v18.8h\n" - "sqxtn2 v17.16b, v19.8h\n" - "sqxtn v18.8b, v20.8h\n" - "sqxtn2 v18.16b, v21.8h\n" - "sqxtn v19.8b, v22.8h\n" - "sqxtn2 v19.16b, v23.8h\n" - - // Load the clamp_min, clamp_max bounds - "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" - "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" - "dup v14.16b, w2\n" // clamp_min - "dup v15.16b, w3\n" // clamp_max - - // Apply the clamp_min bound - "smax v16.16b, v16.16b, v14.16b\n" - "smax v17.16b, v17.16b, v14.16b\n" - "smax v18.16b, v18.16b, v14.16b\n" - "smax v19.16b, v19.16b, v14.16b\n" - - // Apply the clamp_max bound - "smin v16.16b, v16.16b, v15.16b\n" - "smin v17.16b, v17.16b, v15.16b\n" - "smin v18.16b, v18.16b, v15.16b\n" - "smin v19.16b, v19.16b, v15.16b\n" - - // Make it so that all of the final 8bit values are stored in the - // first 64bits of 128bit NEON registers, so they can be stored - // by 64bit st1 store instructions with byte alignment. - "dup d20, v16.d[1]\n" - "dup d21, v17.d[1]\n" - "dup d22, v18.d[1]\n" - "dup d23, v19.d[1]\n" - - // Compute how much of the 8x8 block of destination 8bit values that - // we have computed, fit in the destination matrix. Typically, all of - // it fits, but when the destination matrix shape is not a multiple - // of 8x8, there are some 8x8 blocks along the boundaries that do - // not fit entirely. - "sub w1, %w[dst_rows], %w[row]\n" - "sub w2, %w[dst_cols], %w[col]\n" - "mov w3, #8\n" - "cmp w1, #8\n" - // Compute w1 = how many rows of the 8x8 block fit - "csel w1, w1, w3, le\n" - "cmp w2, #8\n" - // Compute w2 = how many cols of the 8x8 block fit - "csel w2, w2, w3, le\n" - - // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits. - "cmp w1, w3\n" - "ccmp w2, w3, 0, eq\n" - // Yes, all of the 8x8 block fits, go to fast path. - "beq 130f\n" - // Not all of the 8x8 block fits. - // Set (x3 address, x4 stride) to write to dst_tmp_buf - "mov x3, %[dst_tmp_buf]\n" - "mov x4, #8\n" - "b 131f\n" - "130:\n" - // Yes, all of the 8x8 block fits. - // Set (x3 address, x4 stride) to write directly to destination matrix. - "mov x3, %[dst_ptr]\n" - "mov x4, x11\n" - "131:\n" - - // Write our 8bit values to the destination described by - // (x3 address, x4 stride). 
- RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v16.8b}, [x3], x4\n" - RUY_MAKE_ZERO(v16) - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v20.8b}, [x3], x4\n" - RUY_MAKE_ZERO(v20) - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v17.8b}, [x3], x4\n" - RUY_MAKE_ZERO(v17) - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v21.8b}, [x3], x4\n" - RUY_MAKE_ZERO(v21) - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v18.8b}, [x3], x4\n" - RUY_MAKE_ZERO(v18) - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v22.8b}, [x3], x4\n" - RUY_MAKE_ZERO(v22) - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v19.8b}, [x3], x4\n" - RUY_MAKE_ZERO(v19) - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v23.8b}, [x3], x4\n" - RUY_MAKE_ZERO(v23) - - // For the next block: perform the first few multiply-adds on the data - // that we have already loaded. - ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" - ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n" - ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n" - ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n" - - // If all of the 8x8 block fits, we just finished writing it to the - // destination, so we skip the next part. - "beq 141f\n" - // Not all of the 8x8 block fits in the destination matrix. We just - // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over - // it to copy into the destination matrix the part that fits. - "mov x3, %[dst_tmp_buf]\n" - "mov x4, %[dst_ptr]\n" - "mov w6, #0\n" - "150:\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov w5, #0\n" - "151:\n" - "ldrb w7, [x3, w5, uxtw]\n" - "strb w7, [x4, w5, uxtw]\n" - "add w5, w5, #1\n" - "cmp w5, w1\n" - "blt 151b\n" - "add w6, w6, #1\n" - "add x3, x3, #8\n" - "add x4, x4, x11\n" - "cmp w6, w2\n" - "blt 150b\n" - "141:\n" - "add %[dst_ptr], %[dst_ptr], #8\n" - // At this point we have completely finished writing values to the - // destination matrix for the current block. - - "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" - - RUY_STR(RUY_ASM_LABEL_STORE_INT16) ":\n" - - // Add the destination zero point - "dup v14.8h, v13.h[4]\n" - "saddw v16.4s, v16.4s, v14.4h\n" - "saddw v17.4s, v17.4s, v14.4h\n" - "saddw v18.4s, v18.4s, v14.4h\n" - "saddw v19.4s, v19.4s, v14.4h\n" - "saddw v20.4s, v20.4s, v14.4h\n" - "saddw v21.4s, v21.4s, v14.4h\n" - "saddw v22.4s, v22.4s, v14.4h\n" - "saddw v23.4s, v23.4s, v14.4h\n" - "saddw v24.4s, v24.4s, v14.4h\n" - "saddw v25.4s, v25.4s, v14.4h\n" - "saddw v26.4s, v26.4s, v14.4h\n" - "saddw v27.4s, v27.4s, v14.4h\n" - "saddw v28.4s, v28.4s, v14.4h\n" - "saddw v29.4s, v29.4s, v14.4h\n" - "saddw v30.4s, v30.4s, v14.4h\n" - "saddw v31.4s, v31.4s, v14.4h\n" - - // Cast-and-saturate from int32 to int16 - "sqxtn v16.4h, v16.4s\n" - "sqxtn2 v16.8h, v17.4s\n" - "sqxtn v17.4h, v18.4s\n" - "sqxtn2 v17.8h, v19.4s\n" - "sqxtn v18.4h, v20.4s\n" - "sqxtn2 v18.8h, v21.4s\n" - "sqxtn v19.4h, v22.4s\n" - "sqxtn2 v19.8h, v23.4s\n" - "sqxtn v20.4h, v24.4s\n" - "sqxtn2 v20.8h, v25.4s\n" - "sqxtn v21.4h, v26.4s\n" - "sqxtn2 v21.8h, v27.4s\n" - "sqxtn v22.4h, v28.4s\n" - "sqxtn2 v22.8h, v29.4s\n" - "sqxtn v23.4h, v30.4s\n" - "sqxtn2 v23.8h, v31.4s\n" - - // At this point, v24 -- v31 aren't used anymore for the current block, - // so we can start clearing these accumulators for the next block - // (next iteration of the main loop). 
- RUY_MAKE_ZERO(v24) - RUY_MAKE_ZERO(v25) - RUY_MAKE_ZERO(v26) - RUY_MAKE_ZERO(v27) - RUY_MAKE_ZERO(v28) - RUY_MAKE_ZERO(v29) - RUY_MAKE_ZERO(v30) - RUY_MAKE_ZERO(v31) - - // Load the clamp_min, clamp_max bounds - "ldrsh w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" - "ldrsh w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" - "dup v14.8h, w2\n" // clamp_min - "dup v15.8h, w3\n" // clamp_max - - // Apply the clamp_min bound - "smax v16.8h, v16.8h, v14.8h\n" - "smax v17.8h, v17.8h, v14.8h\n" - "smax v18.8h, v18.8h, v14.8h\n" - "smax v19.8h, v19.8h, v14.8h\n" - "smax v20.8h, v20.8h, v14.8h\n" - "smax v21.8h, v21.8h, v14.8h\n" - "smax v22.8h, v22.8h, v14.8h\n" - "smax v23.8h, v23.8h, v14.8h\n" - // Apply the clamp_max bound - "smin v16.8h, v16.8h, v15.8h\n" - "smin v17.8h, v17.8h, v15.8h\n" - "smin v18.8h, v18.8h, v15.8h\n" - "smin v19.8h, v19.8h, v15.8h\n" - "smin v20.8h, v20.8h, v15.8h\n" - "smin v21.8h, v21.8h, v15.8h\n" - "smin v22.8h, v22.8h, v15.8h\n" - "smin v23.8h, v23.8h, v15.8h\n" - - // Compute how much of the 8x8 block of destination 16bit values that - // we have computed, fit in the destination matrix. Typically, all of - // it fits, but when the destination matrix shape is not a multiple - // of 8x8, there are some 8x8 blocks along the boundaries that do - // not fit entirely. - "sub w1, %w[dst_rows], %w[row]\n" - "sub w2, %w[dst_cols], %w[col]\n" - "mov w3, #8\n" - "cmp w1, #8\n" - // Compute w1 = how many rows of the 8x8 block fit - "csel w1, w1, w3, le\n" - "cmp w2, #8\n" - // Compute w2 = how many cols of the 8x8 block fit - "csel w2, w2, w3, le\n" - - // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits. - "cmp w1, w3\n" - "ccmp w2, w3, 0, eq\n" - // Yes, all of the 8x8 block fits, go to fast path. - "beq 230f\n" - // Not all of the 8x8 block fits. - // Set (x3 address, x4 stride) to write to dst_tmp_buf - "mov x3, %[dst_tmp_buf]\n" - "mov x4, #16\n" - "b 231f\n" - "230:\n" - // Yes, all of the 8x8 block fits. - // Set (x3 address, x4 stride) to write directly to destination matrix. - "mov x3, %[dst_ptr]\n" - "mov x4, x11\n" - "231:\n" - - // Write our 16bit values to the destination described by - // (x3 address, x4 stride). - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v16.8h}, [x3], x4\n" - RUY_MAKE_ZERO(v16) - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v17.8h}, [x3], x4\n" - RUY_MAKE_ZERO(v17) - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v18.8h}, [x3], x4\n" - RUY_MAKE_ZERO(v18) - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v19.8h}, [x3], x4\n" - RUY_MAKE_ZERO(v19) - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v20.8h}, [x3], x4\n" - RUY_MAKE_ZERO(v20) - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v21.8h}, [x3], x4\n" - RUY_MAKE_ZERO(v21) - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v22.8h}, [x3], x4\n" - RUY_MAKE_ZERO(v22) - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v23.8h}, [x3], x4\n" - RUY_MAKE_ZERO(v23) - - // For the next block: perform the first few multiply-adds on the data - // that we have already loaded. - ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" - ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n" - ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n" - ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n" - - // If all of the 8x8 block fits, we just finished writing it to the - // destination, so we skip the next part. - "beq 241f\n" - // Not all of the 8x8 block fits in the destination matrix. We just - // wrote it to dst_tmp_buf. 
Now we perform the slow scalar loop over - // it to copy into the destination matrix the part that fits. - "mov x3, %[dst_tmp_buf]\n" - "mov x4, %[dst_ptr]\n" - "mov w6, #0\n" - "250:\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov w5, #0\n" - "251:\n" - "ldrsh w7, [x3, x5, lsl #1]\n" - "strh w7, [x4, x5, lsl #1]\n" - "add w5, w5, #1\n" - "cmp w5, w1\n" - "blt 251b\n" - "add w6, w6, #1\n" - "add x3, x3, #16\n" - "add x4, x4, x11\n" - "cmp w6, w2\n" - "blt 250b\n" - "241:\n" - "add %[dst_ptr], %[dst_ptr], #16\n" - // At this point we have completely finished writing values to the - // destination matrix for the current block. - - "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" - - RUY_STR(RUY_ASM_LABEL_STORE_INT32) ":\n" - - // Since the store type is the same as the accum type, no need for - // downcast. There's also no need for clamp by min/max. - - // Compute how much of the 8x8 block of destination 32bit values that - // we have computed, fit in the destination matrix. Typically, all of - // it fits, but when the destination matrix shape is not a multiple - // of 8x8, there are some 8x8 blocks along the boundaries that do - // not fit entirely. - "sub w1, %w[dst_rows], %w[row]\n" - "sub w2, %w[dst_cols], %w[col]\n" - "mov w3, #8\n" - "cmp w1, #8\n" - // Compute w1 = how many rows of the 8x8 block fit - "csel w1, w1, w3, le\n" - "cmp w2, #8\n" - // Compute w2 = how many cols of the 8x8 block fit - "csel w2, w2, w3, le\n" - - // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits. - "cmp w1, w3\n" - "ccmp w2, w3, 0, eq\n" - // Yes, all of the 8x8 block fits, go to fast path. - "beq 330f\n" - // Not all of the 8x8 block fits. - // Write to dst_tmp_buf - "mov x3, %[dst_tmp_buf]\n" - "st1 {v16.4s}, [x3], #16\n" - RUY_MAKE_ZERO(v16) - "st1 {v17.4s}, [x3], #16\n" - RUY_MAKE_ZERO(v17) - "st1 {v18.4s}, [x3], #16\n" - RUY_MAKE_ZERO(v18) - "st1 {v19.4s}, [x3], #16\n" - RUY_MAKE_ZERO(v19) - "st1 {v20.4s}, [x3], #16\n" - RUY_MAKE_ZERO(v20) - "st1 {v21.4s}, [x3], #16\n" - RUY_MAKE_ZERO(v21) - "st1 {v22.4s}, [x3], #16\n" - RUY_MAKE_ZERO(v22) - "st1 {v23.4s}, [x3], #16\n" - RUY_MAKE_ZERO(v23) - "st1 {v24.4s}, [x3], #16\n" - RUY_MAKE_ZERO(v24) - "st1 {v25.4s}, [x3], #16\n" - RUY_MAKE_ZERO(v25) - "st1 {v26.4s}, [x3], #16\n" - RUY_MAKE_ZERO(v26) - "st1 {v27.4s}, [x3], #16\n" - RUY_MAKE_ZERO(v27) - "st1 {v28.4s}, [x3], #16\n" - RUY_MAKE_ZERO(v28) - "st1 {v29.4s}, [x3], #16\n" - RUY_MAKE_ZERO(v29) - "st1 {v30.4s}, [x3], #16\n" - RUY_MAKE_ZERO(v30) - "st1 {v31.4s}, [x3], #16\n" - RUY_MAKE_ZERO(v31) - - "b 331f\n" - - "330:\n" - // Yes, all of the 8x8 block fits. 
- "mov x4, %[dst_ptr]\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v16.4s, v17.4s}, [x3], #32\n" - RUY_MAKE_ZERO(v16) - RUY_MAKE_ZERO(v17) - "add x4, x4, x11\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v18.4s, v19.4s}, [x3], #32\n" - RUY_MAKE_ZERO(v18) - RUY_MAKE_ZERO(v19) - "add x4, x4, x11\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v20.4s, v21.4s}, [x3], #32\n" - RUY_MAKE_ZERO(v20) - RUY_MAKE_ZERO(v21) - "add x4, x4, x11\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v22.4s, v23.4s}, [x3], #32\n" - RUY_MAKE_ZERO(v22) - RUY_MAKE_ZERO(v23) - "add x4, x4, x11\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v24.4s, v25.4s}, [x3], #32\n" - RUY_MAKE_ZERO(v24) - RUY_MAKE_ZERO(v25) - "add x4, x4, x11\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v26.4s, v27.4s}, [x3], #32\n" - RUY_MAKE_ZERO(v26) - RUY_MAKE_ZERO(v27) - "add x4, x4, x11\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v28.4s, v29.4s}, [x3], #32\n" - RUY_MAKE_ZERO(v28) - RUY_MAKE_ZERO(v29) - "add x4, x4, x11\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v30.4s, v31.4s}, [x3], #32\n" - RUY_MAKE_ZERO(v30) - RUY_MAKE_ZERO(v31) - - "331:\n" - - // For the next block: perform the first few multiply-adds on the data - // that we have already loaded. - ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" - ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n" - ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n" - ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n" - - // If all of the 8x8 block fits, we just finished writing it to the - // destination, so we skip the next part. - "beq 341f\n" - - // Not all of the 8x8 block fits in the destination matrix. We just - // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over - // it to copy into the destination matrix the part that fits. - "mov x3, %[dst_tmp_buf]\n" - "mov x4, %[dst_ptr]\n" - "mov w6, #0\n" - "350:\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov w5, #0\n" - "351:\n" - "ldr w7, [x3, x5, lsl #2]\n" - "str w7, [x4, x5, lsl #2]\n" - "add w5, w5, #1\n" - "cmp w5, w1\n" - "blt 351b\n" - "add w6, w6, #1\n" - "add x3, x3, #32\n" - "add x4, x4, x11\n" - "cmp w6, w2\n" - "blt 350b\n" - "341:\n" - "add %[dst_ptr], %[dst_ptr], #32\n" - // At this point we have completely finished writing values to the - // destination matrix for the current block. - - RUY_STR(RUY_ASM_LABEL_AFTER_STORE) ":\n" - - // Reload some params --- we had used x5 -- x7 for a few other things - // since the last time we had loaded them. - "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" - "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" - "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" - - // Move to the next block of the destination matrix, for the next iter - // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already - // been updated earlier. - // Have we reached the end row? - "cmp %w[row], w7\n" - "beq 20f\n" // yes, end row. - // Not end row. Move to the next row. - "add %w[row], %w[row], #8\n" - "b 21f\n" - "20:\n" - // Was already at end row. - "mov %w[row], w6\n" // Move back to first row. - "add %w[col], %w[col], #8\n" // Move to the next column. - "add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #3\n" - "mov %[dst_ptr], %[dst_col_ptr]\n" - "21:\n" - - // Main loop exit condition: have we hit the end column? 
- "cmp %w[col], w8\n" - - // w1 is the number of levels of depth that we have already loaded - // LHS and RHS data for. Corresponding to the initial ld1 instructions - // above, this is currently 4. - "mov w1, #4\n" - - "ble 1b\n" - - // clang-format on - - : [ lhs_col_ptr ] "+r"(lhs_col_ptr), [rhs_col_ptr] "+r"(rhs_col_ptr), - [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), - [dst_col_ptr] "+r"(dst_col_ptr), [dst_ptr] "+r"(dst_ptr), [row] "+r"(row), [col] "+r"(col) - : [ params ] "r"(&params), [dst_rows] "r"(params.dst_rows), - [dst_cols] "r"(params.dst_cols), [dst_tmp_buf] "r"(params.dst_tmp_buf), - [dst_type_id] "r"(params.dst_type_id) - : "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc", - "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", - "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", - "v26", "v27", "v28", "v29", "v30", "v31"); -} - -// Similar to the above 8-bit dotprod kernel, but specialized for the case of -// RHS cols == 1. -// Relevant target CPUs for this kernel include ARM Cortex-A76, -// since these are 64-bit, out-of-order and with dotprod support. -void Kernel8bitNeonDotprodOutOfOrder1Col(const KernelParams8bit<8, 8>& params) { - profiler::ScopeLabel label( - "Kernel (kNeonDotprod, optimized for out-of-order cores)"); - - CheckOffsetsInKernelParams8bit(params); - - const std::int8_t* lhs_col_ptr = params.lhs_base_ptr; - const std::int8_t* rhs_col_ptr = params.rhs_base_ptr; - const std::int8_t* lhs_ptr = lhs_col_ptr; - const std::int8_t* rhs_ptr = rhs_col_ptr; - void* dst_col_ptr = params.dst_base_ptr; - void* dst_ptr = dst_col_ptr; - int row = params.start_row; - int col = params.start_col; - - // The asm kernel below has the following NEON register allocation: - // - // v16 -- v31 are int32 accumulators. - // During accumulation, v0 -- v15 are used to load int8 data from LHS and - // RHS. At least v0 and v1 are used to load a 8x4 block of LHS, and v2 and - // v3 are used to load a 4x8 block of RHS, like this: - // - // int8 RHS 4x1 block - // /-------\ - // |v2.b[0]| - // | ... | - // |v2.b[3]| - // \-------/ - // int8 LHS 8x4 block - // /---------------------\ /--------\ - // |v0.b[0] ... v0.b[3] | |v16.s[0]| - // | ... ... | | ... | - // |v0.b[12] ... v0.b[15]| |v16.s[3]| - // |v1.b[0] ... v1.b[3] | |v17.s[0]| - // | ... ... | | ... | - // |v1.b[12] ... v1.b[15]| |v17.s[3]| - // \---------------------/ \--------/ - // int32 accumulators 8x1 block - // - // In the RUY_OPT_MAX_STREAMING part of the kernel, this elementary step - // is repeated 4 times, using 4x more registers for LHS and RHS, so that - // is where instead of using v0 -- v3 for LHS and RHS, we use v0 -- v15. - // - // Outside of the RUY_OPT_MAX_STREAMING part of the kernel, v4 -- v7 are - // unused, and v8 -- v15 are used for loading parameters used for the - // post-accumulation part of the kernel. - asm volatile( -#define RUY_MAKE_ZERO(reg) "dup " #reg ".4s, wzr\n" - - // clang-format off - - // Load some parameters into registers. 
- "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" - "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" - "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" - "ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n" - "ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n" - "ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n" - "ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" - "ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n" - - // Load the first 32 bytes of LHS and RHS data. - "ld1 {v0.16b}, [%[lhs_ptr]], #16\n" - "ld1 {v1.16b}, [%[lhs_ptr]], #16\n" - "ld1 {v2.8b}, [%[rhs_ptr]]\n" - "add %[rhs_ptr], %[rhs_ptr], #32\n" - - // Clear accumulators. - RUY_MAKE_ZERO(v16) - RUY_MAKE_ZERO(v17) - - // w1 is the number of levels of depth that we have already loaded - // LHS and RHS data for. Corresponding to the initial ld1 instructions - // above, this is currently 4. - "mov w1, #4\n" - - // Perform the first few multiply-adds on the data that we have already - // loaded. - ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" - - // Main loop of the whole GEMM, over rows and columns of the - // destination matrix. - "1:\n" - - // Ordinary kernel inner loop (over depth), the simpler loop that the - // above was an equivalent 4x-partially-unrolled version of. - - // Reminder - w1 is how many levels of depth we have already loaded - // data for, w12 is the total depth. - "cmp w1, w12\n" - "beq 79f\n" - - "2:\n" - - // Because of the data that we have already loaded, we can start the - // loop body right away with some multiply-adds. - // Each iteration of this loop advances by 4 levels of depth. - "add w1, w1, #4\n" - "ld1 {v0.16b}, [%[lhs_ptr]], #16\n" - ".word 0x4f82e031 // sdot v17.4s, v1.16b, v2.4b[0]\n" - // Loop termination condition. - "cmp w1, w12\n" - "ld1 {v2.8b}, [%[rhs_ptr]]\n" - "add %[rhs_ptr], %[rhs_ptr], #32\n" - ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" - "ld1 {v1.16b}, [%[lhs_ptr]], #16\n" - - "blt 2b\n" - - "79:\n" - // End of the inner loop on depth. Now perform the remaining - // multiply-adds of the last 4 levels of depth, for which the LHS - // and RHS data is already loaded. - - ".word 0x4f82e031 // sdot v17.4s, v1.16b, v2.4b[0]\n" - - // End of accumulation. The registers v16 -- v31 contain the final - // int32 accumulator values of the current 8x8 destination block. - // We now have to compute the final 8-bit values from these int32 - // accumulators, and advance to the next 8x8 block. We intertwine - // these two aspects whenever possible for optimal pipelining, both - // at the data flow level (prefetch data for next block as early as - // possible) and instruction pipelining level (some of the next-block - // work can dual-issue with some of the final work on the current - // block). - - // Logic to advance to the next block in preparation for the next - // iteration of the main loop. For now, we only want to compute - // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are - // not yet ready to update the values of row and col, as we still need - // the current values for the rest of the work on the current block. - - "cmp %w[row], w7\n" // Have we finished the last row? - "bge 4f\n" // If finished last row, go to 4 - // Not finished last row: then advance to next row. - "add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #3\n" - "b 5f\n" - "4:\n" // Finished last row... - "mov %[lhs_col_ptr], x5\n" // Go back to first row - // Now we need to advance to the next column. 
If we already - // finished the last column, then in principle we are done, however - // we can't just return here, as we need to allow the end work of the - // current block to complete. The good news is that at this point it - // doesn't matter what data we load for the next column, since - // we will exit from the main loop below before actually storing - // anything computed from that data. - "cmp %w[col], w8\n" // Have we finished the last column? - "bge 5f\n" // If yes, just carry on without updating the column pointer. - // Not finished last column: then advance to next column. - "add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #3\n" - "5:\n" - - // Set the LHS and RHS data pointers to the start of the columns just - // computed. - "mov %[lhs_ptr], %[lhs_col_ptr]\n" - "mov %[rhs_ptr], %[rhs_col_ptr]\n" - - // Load some parameters needed for the end work on current block. - RUY_MAKE_ZERO(v8) - "ldr w4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n" - "ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_PROD_ZP_DEPTH) "]\n" - "ins v13.h[4], w4\n" // dst_zero_point - "ldr x4, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n" - "ldrb w6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" - "dup v9.4s, w3\n" // create prod_zp_depth_vec - "add x5, x4, %x[row], lsl #2\n" - "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n" - "csel x4, x4, x5, eq\n" - - "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n" - "add x5, x1, %x[row], lsl #2\n" - - "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n" - "csel x1, x1, x5, eq\n" - - // Load 8 bias values. - "ld1 {v14.4s}, [x1], #16\n" - "ld1 {v15.4s}, [x1]\n" - - // Now that we know what LHS and RHS data the next iteration of the - // main loop will need to load, we start loading the first 32 bytes of - // each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore - // in the rest of the work on the current block. - "ld1 {v0.16b}, [%[lhs_ptr]], #16\n" - "ld1 {v1.16b}, [%[lhs_ptr]], #16\n" - "ld1 {v2.8b}, [%[rhs_ptr]]\n" - "add %[rhs_ptr], %[rhs_ptr], #32\n" - - // Add to the bias values the product (depth * lhs_zero_point * rhs_zero_point), - // See the term NZ1Z2 in equation (7) in https://arxiv.org/pdf/1712.05877.pdf - "add v14.4s, v14.4s, v9.4s\n" - "add v15.4s, v15.4s, v9.4s\n" - - // Perform the bias-addition (per the above, we have just folded into - // the bias the (depth * lhs_zero_point * rhs_zero_point) term.) - "add v16.4s, v16.4s, v14.4s\n" - "add v17.4s, v17.4s, v15.4s\n" - - "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_RHS_SUMS) "\n" - "beq 401f\n" - "ldr x3, [%[params], #" RUY_STR(RUY_OFFSET_RHS_SUMS) "]\n" - "add x3, x3, %x[col], lsl #2\n" - "ld1 {v14.4s}, [x3], #16\n" - "ld1 {v15.4s}, [x3]\n" - "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_ZERO_POINT) "]\n" - "dup v10.4s, w5\n" // create lhs_zero_point_vec - // Subtract rhs_sums * lhs_zero_point, per - // equation (7) in https://arxiv.org/pdf/1712.05877.pdf - "mls v16.4s, v10.4s, v14.s[0]\n" - "mls v17.4s, v10.4s, v14.s[0]\n" - "401:\n" - - "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_LHS_SUMS) "\n" - "beq 402f\n" - "ldr x2, [%[params], #" RUY_STR(RUY_OFFSET_LHS_SUMS) "]\n" - "add x2, x2, %x[row], lsl #2\n" - "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_RHS_ZERO_POINT) "]\n" - // Load 4 lhs_sums values. - "ld1 {v11.4s}, [x2], #16\n" - "ld1 {v12.4s}, [x2]\n" - "ins v13.s[1], w5\n" // rhs_zero_point - // Compute lhs_sums * rhs_zero_point. 
- "mul v11.4s, v11.4s, v13.s[1]\n" - "mul v12.4s, v12.4s, v13.s[1]\n" - // Subtract lhs_sums * rhs_zero_point, per - // equation (7) in https://arxiv.org/pdf/1712.05877.pdf - "sub v16.4s, v16.4s, v11.4s\n" - "sub v17.4s, v17.4s, v12.4s\n" - - "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT32) "\n" - "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT32) "f\n" - - "402:\n" - - // At this point we have computed the final int32 values. Now we - // start down-quantizing them to obtain the final 8bit values from them. - - // As part of this down-quantization, our int32 values will be - // multiplied by a multiplier that has a fixed-point component and an - // exponent component. - - //Load the exponent part of the multiplier. - "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_EXPONENT) "]\n" - "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n" - "add x5, x1, %x[row], lsl #2\n" - "csel x1, x1, x5, eq\n" - - "ldr q9, [x1]\n" - "ldr q10, [x1, #16]\n" - - "tst w6, #" RUY_STR(RUY_ASM_FLAG_NEEDS_LEFT_SHIFT) "\n" - "beq 403f\n" - "smax v11.4s, v9.4s, v8.4s\n" - "smax v12.4s, v10.4s, v8.4s\n" - "sshl v16.4s, v16.4s, v11.4s\n" - "sshl v17.4s, v17.4s, v12.4s\n" - "403:\n" - - "ldr q14, [x4]\n" // multiplier_fixedpoint - "ldr q15, [x4, #16]\n" // multiplier_fixedpoint - - "smin v11.4s, v9.4s, v8.4s\n" - "smin v12.4s, v10.4s, v8.4s\n" - - // Apply the fixed-point part of the multiplier. - "sqrdmulh v16.4s, v16.4s, v14.4s\n" - "sqrdmulh v17.4s, v17.4s, v15.4s\n" - - // We have some rounding division-by-power-of-two to do. This should - // always use "round to nearest". We allow for some - // freedom in how ties are broken, to strike a good compromise of - // performance on given hardware vs. perfect agreement of results - // across hardware. - // - // When RUY_OPT_NATIVE_ROUNDING is enabled, we allow for implementation - // defined tie-breaks to help performance. On NEON, this means that we - // can just use the NEON rounding instructions, such as srshl. They - // happen to be breaking ties upward. - // - // When RUY_OPT_NATIVE_ROUNDING is disabled, we implement strict - // break-ties-away-from zero, as described in Appendix B of - // https://arxiv.org/pdf/1712.05877.pdf - // When we wrote that, we thought that that would be better unbiased - // than the NEON upwards tie-breaks, and we had observed some - // improvement on some model. However, that is only more unbiased for - // data centered at zero, which was likely the case in that model, - // but is not always the case. If we wanted something more consistently - // unbiased then we should try breaking ties toward-nearest-even. -#if !RUY_OPT_ENABLED(RUY_OPT_NATIVE_ROUNDING) - // Fix up values to be right-shifted, so that the (round to nearest, - // break ties upward) behavior of srshl applied to these fixed-up - // values, produces the same result as the desired (round to nearest, - // break ties away from zero) behavior on the original values. - "and v8.16b, v16.16b, v11.16b\n" - "and v9.16b, v17.16b, v12.16b\n" - "sshr v8.4s, v8.4s, #31\n" - "sshr v9.4s, v9.4s, #31\n" - "sqadd v16.4s, v16.4s, v8.4s\n" - "sqadd v17.4s, v17.4s, v9.4s\n" - -#endif - // At this point we have reduced the problem of correctly implementing - // rounding divide-by-power-of-two, to what the SRSHL instruction can - // do. 
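The rounding behavior described in the comments above can be summarized by a small scalar model. The sketch below is illustrative only (it is not ruy code, and the function name is made up); it takes the right-shift amount as a non-negative exponent, whereas the kernel encodes it as a negative shift count for SRSHL.

    #include <cstdint>

    // Scalar model of "round to nearest, break ties away from zero" division by a
    // power of two, i.e. the behavior the fixup + SRSHL sequence implements when
    // RUY_OPT_NATIVE_ROUNDING is disabled. Assumes 0 <= exponent <= 31.
    std::int32_t RoundingDivideByPOT(std::int32_t x, int exponent) {
      const std::int32_t mask =
          static_cast<std::int32_t>((std::int64_t{1} << exponent) - 1);
      const std::int32_t remainder = x & mask;  // low bits that get shifted out
      // For negative x the threshold is raised by one, so exact ties round
      // away from zero (downward for negative values).
      const std::int32_t threshold = (mask >> 1) + (x < 0 ? 1 : 0);
      return (x >> exponent) + (remainder > threshold ? 1 : 0);
    }

With RUY_OPT_NATIVE_ROUNDING enabled, the fixup is skipped and SRSHL's native break-ties-upward rounding is accepted instead, as the comments above explain.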
- "srshl v16.4s, v16.4s, v11.4s\n" - "srshl v17.4s, v17.4s, v12.4s\n" - - "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT16) "\n" - "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT16) "f\n" - "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT8) "\n" - "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT8) "f\n" - - RUY_STR(RUY_ASM_LABEL_STORE_UINT8) ":\n" - - // Cast-and-saturate from int32 to int16 - "sqxtn v16.4h, v16.4s\n" - "sqxtn2 v16.8h, v17.4s\n" - // All data in v16 at this point. - - // Add the destination zero point - "dup v14.8h, v13.h[4]\n" - "add v16.8h, v16.8h, v14.8h\n" - - // Cast-and-saturate from int16 to uint8, leaving all data in the - // lower half of v16. - "sqxtun v16.8b, v16.8h\n" - - // Load the clamp_min, clamp_max bounds - "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" - "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" - "dup v14.16b, w2\n" // clamp_min - "dup v15.16b, w3\n" // clamp_max - - // Apply the clamp_min bound - "umax v16.16b, v16.16b, v14.16b\n" - - // Apply the clamp_max bound - "umin v16.16b, v16.16b, v15.16b\n" - - // Make it so that all of the final 8bit values are stored in the - // first 64bits of 128bit NEON registers, so they can be stored - // by 64bit st1 store instructions with byte alignment. - "dup d20, v16.d[1]\n" - - // Compute how much of the 8x1 block of destination 8bit values that - // we have computed, fit in the destination matrix. Typically, all of - // it fits, but when the destination matrix shape is not a multiple - // of 8x1, there are some 8x1 blocks along the boundaries that do - // not fit entirely. - "sub w1, %w[dst_rows], %w[row]\n" - "sub w2, %w[dst_cols], %w[col]\n" - "mov w3, #8\n" - "cmp w1, #8\n" - // Compute w1 = how many rows of the 8x1 block fit - "csel w1, w1, w3, le\n" - "cmp w2, #8\n" - - // Test if w1==8, i.e. if all of the 8x1 block fits. - "cmp w1, w3\n" - // Yes, all of the 8x1 block fits, go to fast path. - "beq 30f\n" - // Not all of the 8x1 block fits. - // Set (x3 address, x4 stride) to write to dst_tmp_buf - "mov x3, %[dst_tmp_buf]\n" - "mov x4, #8\n" - "b 31f\n" - "30:\n" - // Yes, all of the 8x1 block fits. - // Set (x3 address, x4 stride) to write directly to destination matrix. - "mov x3, %[dst_ptr]\n" - "mov x4, x11\n" - "31:\n" - - // Write our 8bit values to the destination - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "st1 {v16.8b}, [x3]\n" - RUY_MAKE_ZERO(v16) - RUY_MAKE_ZERO(v17) - - // For the next block: perform the first few multiply-adds on the data - // that we have already loaded. - ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" - - // If all of the 8x8 block fits, we just finished writing it to the - // destination, so we skip the next part. - "beq 41f\n" - // Not all of the 8x8 block fits in the destination matrix. We just - // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over - // it to copy into the destination matrix the part that fits. - "mov x3, %[dst_tmp_buf]\n" - "mov x4, %[dst_ptr]\n" - "mov w6, #0\n" - "50:\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov w5, #0\n" - "51:\n" - "ldrb w7, [x3, w5, uxtw]\n" - "strb w7, [x4, w5, uxtw]\n" - "add w5, w5, #1\n" - "cmp w5, w1\n" - "blt 51b\n" - "41:\n" - "add %[dst_ptr], %[dst_ptr], #8\n" - // At this point we have completely finished writing values to the - // destination matrix for the current block. 
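In scalar terms, the uint8 store path above amounts to the following per-element sequence once the multiplier and rounding shift have been applied. This is a rough sketch with made-up names, not ruy's actual code; in particular the kernel performs the zero-point addition in int16 lanes, which is modeled here in 32-bit arithmetic.

    #include <algorithm>
    #include <cstdint>

    // int32 -> saturate to int16 -> add dst_zero_point -> saturate to uint8 -> clamp.
    std::uint8_t DownQuantizeToUint8(std::int32_t scaled_acc,
                                     std::int16_t dst_zero_point,
                                     std::uint8_t clamp_min, std::uint8_t clamp_max) {
      std::int32_t v = std::min<std::int32_t>(
          std::max<std::int32_t>(scaled_acc, -32768), 32767);        // sqxtn
      v += dst_zero_point;                                           // add ..., v14.8h
      v = std::min<std::int32_t>(std::max<std::int32_t>(v, 0), 255); // sqxtun
      v = std::max<std::int32_t>(v, clamp_min);                      // umax (clamp_min)
      v = std::min<std::int32_t>(v, clamp_max);                      // umin (clamp_max)
      return static_cast<std::uint8_t>(v);
    }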
- - "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" - - RUY_STR(RUY_ASM_LABEL_STORE_INT8) ":\n" - - // Cast-and-saturate from int32 to int16 - "sqxtn v16.4h, v16.4s\n" - "sqxtn2 v16.8h, v17.4s\n" - - - // Add the destination zero point - "dup v14.8h, v13.h[4]\n" - "add v16.8h, v16.8h, v14.8h\n" - - // Cast-and-saturate from int16 to uint8 - "sqxtn v16.8b, v16.8h\n" - - // Load the clamp_min, clamp_max bounds - "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" - "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" - "dup v14.16b, w2\n" // clamp_min - "dup v15.16b, w3\n" // clamp_max - - // Apply the clamp_min bound - "smax v16.16b, v16.16b, v14.16b\n" - - // Apply the clamp_max bound - "smin v16.16b, v16.16b, v15.16b\n" - - // Make it so that all of the final 8bit values are stored in the - // first 64bits of 128bit NEON registers, so they can be stored - // by 64bit st1 store instructions with byte alignment. - "dup d20, v16.d[1]\n" - - // Compute how much of the 8x1 block of destination 8bit values that - // we have computed, fit in the destination matrix. Typically, all of - // it fits, but when the destination matrix shape is not a multiple - // of 8x8, there are some 8x8 blocks along the boundaries that do - // not fit entirely. - "sub w1, %w[dst_rows], %w[row]\n" - "sub w2, %w[dst_cols], %w[col]\n" - "mov w3, #8\n" - "cmp w1, #8\n" - // Compute w1 = how many rows of the 8x1 block fit - "csel w1, w1, w3, le\n" - "cmp w2, #8\n" - - // Test if w1==8, i.e. if all of the 8x1 block fits. - "cmp w1, w3\n" - // Yes, all of the 8x1 block fits, go to fast path. - "beq 130f\n" - // Not all of the 8x1 block fits. - // Set (x3 address, x4 stride) to write to dst_tmp_buf - "mov x3, %[dst_tmp_buf]\n" - "mov x4, #8\n" - "b 131f\n" - "130:\n" - // Yes, all of the 8x8 block fits. - // Set (x3 address, x4 stride) to write directly to destination matrix. - "mov x3, %[dst_ptr]\n" - "mov x4, x11\n" - "131:\n" - - // Write our 8bit values to the destination - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "st1 {v16.8b}, [x3]\n" - RUY_MAKE_ZERO(v16) - RUY_MAKE_ZERO(v17) - - // For the next block: perform the first few multiply-adds on the data - // that we have already loaded. - ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" - - // If all of the 8x8 block fits, we just finished writing it to the - // destination, so we skip the next part. - "beq 141f\n" - // Not all of the 8x8 block fits in the destination matrix. We just - // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over - // it to copy into the destination matrix the part that fits. - "mov x3, %[dst_tmp_buf]\n" - "mov x4, %[dst_ptr]\n" - "mov w6, #0\n" - "150:\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov w5, #0\n" - "151:\n" - "ldrb w7, [x3, w5, uxtw]\n" - "strb w7, [x4, w5, uxtw]\n" - "add w5, w5, #1\n" - "cmp w5, w1\n" - "blt 151b\n" - "141:\n" - "add %[dst_ptr], %[dst_ptr], #8\n" - // At this point we have completely finished writing values to the - // destination matrix for the current block. 
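All of the store paths above share the same edge handling: work out how many rows and columns of the just-computed block actually fall inside the destination, take a fast path when the whole block fits, and otherwise spill the block to dst_tmp_buf and copy back only the part that fits. A C++ sketch of that logic for an 8-bit destination follows; the names and the column-major layout are illustrative assumptions, not ruy's API.

    #include <algorithm>
    #include <cstdint>
    #include <cstring>

    // block: kRows x kCols values, column-major, as produced by the kernel.
    // dst:   top-left of the destination block; dst_stride is in elements.
    void StoreBlockEdgeAware(const std::uint8_t* block, std::uint8_t* dst,
                             int dst_stride, int rows_left, int cols_left) {
      constexpr int kRows = 8;
      constexpr int kCols = 8;
      const int rows = std::min(rows_left, kRows);  // "csel w1, ..."
      const int cols = std::min(cols_left, kCols);  // "csel w2, ..."
      if (rows == kRows && cols == kCols) {
        // Fast path: the whole block fits; store each column directly.
        for (int c = 0; c < kCols; ++c) {
          std::memcpy(dst + c * dst_stride, block + c * kRows, kRows);
        }
      } else {
        // Slow path: the kernel first writes the whole block to dst_tmp_buf,
        // then copies element by element only the part that fits.
        for (int c = 0; c < cols; ++c) {
          for (int r = 0; r < rows; ++r) {
            dst[c * dst_stride + r] = block[c * kRows + r];
          }
        }
      }
    }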
- - "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" - - RUY_STR(RUY_ASM_LABEL_STORE_INT16) ":\n" - - // Add the destination zero point - "dup v14.8h, v13.h[4]\n" - "saddw v16.4s, v16.4s, v14.4h\n" - "saddw v17.4s, v17.4s, v14.4h\n" - - // Cast-and-saturate from int32 to int16 - "sqxtn v16.4h, v16.4s\n" - "sqxtn2 v16.8h, v17.4s\n" - - // Load the clamp_min, clamp_max bounds - "ldrsh w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" - "ldrsh w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" - "dup v14.8h, w2\n" // clamp_min - "dup v15.8h, w3\n" // clamp_max - - // Apply the clamp_min bound - "smax v16.8h, v16.8h, v14.8h\n" - // Apply the clamp_max bound - "smin v16.8h, v16.8h, v15.8h\n" - - // Compute how much of the 8x1 block of destination 16bit values that - // we have computed, fit in the destination matrix. Typically, all of - // it fits, but when the destination matrix shape is not a multiple - // of 8x8, there are some 8x1 blocks along the boundaries that do - // not fit entirely. - "sub w1, %w[dst_rows], %w[row]\n" - "sub w2, %w[dst_cols], %w[col]\n" - "mov w3, #8\n" - "cmp w1, #8\n" - // Compute w1 = how many rows of the 8x1 block fit - "csel w1, w1, w3, le\n" - "cmp w2, #8\n" - - // Test if w1==8, i.e. if all of the 8x8 block fits. - "cmp w1, w3\n" - // Yes, all of the 8x1 block fits, go to fast path. - "beq 230f\n" - // Not all of the 8x1 block fits. - // Set (x3 address, x4 stride) to write to dst_tmp_buf - "mov x3, %[dst_tmp_buf]\n" - "mov x4, #16\n" - "b 231f\n" - "230:\n" - // Yes, all of the 8x1 block fits. - // Set (x3 address, x4 stride) to write directly to destination matrix. - "mov x3, %[dst_ptr]\n" - "mov x4, x11\n" - "231:\n" - - // Write our 16bit values to the destination - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "st1 {v16.8h}, [x3]\n" - RUY_MAKE_ZERO(v16) - RUY_MAKE_ZERO(v17) - - // For the next block: perform the first few multiply-adds on the data - // that we have already loaded. - ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" - - // If all of the 8x1 block fits, we just finished writing it to the - // destination, so we skip the next part. - "beq 241f\n" - // Not all of the 8x1 block fits in the destination matrix. We just - // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over - // it to copy into the destination matrix the part that fits. - "mov x3, %[dst_tmp_buf]\n" - "mov x4, %[dst_ptr]\n" - "mov w6, #0\n" - "250:\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov w5, #0\n" - "251:\n" - "ldrsh w7, [x3, x5, lsl #1]\n" - "strh w7, [x4, x5, lsl #1]\n" - "add w5, w5, #1\n" - "cmp w5, w1\n" - "blt 251b\n" - "241:\n" - "add %[dst_ptr], %[dst_ptr], #16\n" - // At this point we have completely finished writing values to the - // destination matrix for the current block. - - "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" - - RUY_STR(RUY_ASM_LABEL_STORE_INT32) ":\n" - - // Since the store type is the same as the accum type, no need for - // downcast. There's also no need for clamp by min/max. - - // Compute how much of the 8x1 block of destination 32 bit values that - // we have computed, fit in the destination matrix. Typically, all of - // it fits, but when the destination matrix shape is not a multiple - // of 8x1, there are some 8x1 blocks along the boundaries that do - // not fit entirely. 
- "sub w1, %w[dst_rows], %w[row]\n" - "sub w2, %w[dst_cols], %w[col]\n" - "mov w3, #8\n" - "cmp w1, #8\n" - // Compute w1 = how many rows of the 8x1 block fit - "csel w1, w1, w3, le\n" - "cmp w2, #8\n" - // Compute w2 = how many cols of the 8x1 block fit - "csel w2, w2, w3, le\n" - - // Test if w1==8, i.e. if all of the 8x1 block fits. - "cmp w1, w3\n" - // Yes, all of the 8x1 block fits, go to fast path. - "beq 330f\n" - // Not all of the 8x1 block fits. - // Set (x3 address, x4 stride) to write to dst_tmp_buf - "mov x3, %[dst_tmp_buf]\n" - "mov x4, #16\n" - - // Write our 32bit values to the destination described by - // (x3 address, x4 stride). - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v16.4s}, [x3], x4\n" - RUY_MAKE_ZERO(v16) - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v17.4s}, [x3], x4\n" - RUY_MAKE_ZERO(v17) - - "b 331f\n" - - "330:\n" - // Yes, all of the 8x1 block fits. - // Set (x3 address, x4 stride) to write directly to destination matrix. - "mov x4, %[dst_ptr]\n" - "mov x3, x4\n" - - // Write our 32bit values to the destination described by - // (x3 address, x4 stride). - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v16.4s, v17.4s}, [x3], #32\n" - RUY_MAKE_ZERO(v16) - RUY_MAKE_ZERO(v17) - - "331:\n" - - // For the next block: perform the first few multiply-adds on the data - // that we have already loaded. - ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" - - // If all of the 8x1 block fits, we just finished writing it to the - // destination, so we skip the next part. - "beq 341f\n" - - // Not all of the 8x1 block fits in the destination matrix. We just - // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over - // it to copy into the destination matrix the part that fits. - "mov x3, %[dst_tmp_buf]\n" - "mov x4, %[dst_ptr]\n" - "mov w6, #0\n" - "350:\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "mov w5, #0\n" - "351:\n" - "ldr w7, [x3, x5, lsl #2]\n" - "str w7, [x4, x5, lsl #2]\n" - "add w5, w5, #1\n" - "cmp w5, w1\n" - "blt 351b\n" - "341:\n" - "add %[dst_ptr], %[dst_ptr], #32\n" - // At this point we have completely finished writing values to the - // destination matrix for the current block. - - RUY_STR(RUY_ASM_LABEL_AFTER_STORE) ":\n" - - // Reload some params --- we had used x5 -- x7 for a few other things - // since the last time we had loaded them. - "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" - "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" - "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" - - // Move to the next block of the destination matrix, for the next iter - // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already - // been updated earlier. - // Have we reached the end row? - "cmp %w[row], w7\n" - "beq 20f\n" // yes, end row. - // Not end row. Move to the next row. - "add %w[row], %w[row], #8\n" - "b 21f\n" - "20:\n" - // Was already at end row. - "mov %w[row], w6\n" // Move back to first row. - "add %w[col], %w[col], #8\n" // Move to the next column. - "add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #3\n" - "mov %[dst_ptr], %[dst_col_ptr]\n" - "21:\n" - - // Main loop exit condition: have we hit the end column? - "cmp %w[col], w8\n" - - // w1 is the number of levels of depth that we have already loaded - // LHS and RHS data for. Corresponding to the initial ld1 instructions - // above, this is currently 4. 
- "mov w1, #4\n" - - "ble 1b\n" - - // clang-format on - - : [ lhs_col_ptr ] "+r"(lhs_col_ptr), [rhs_col_ptr] "+r"(rhs_col_ptr), - [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), - [dst_col_ptr] "+r"(dst_col_ptr), [dst_ptr] "+r"(dst_ptr), [row] "+r"(row), [col] "+r"(col) - : [ params ] "r"(&params), [dst_rows] "r"(params.dst_rows), - [dst_cols] "r"(params.dst_cols), [dst_tmp_buf] "r"(params.dst_tmp_buf), - [dst_type_id] "r"(params.dst_type_id) - : "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc", - "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", - "v13", "v14", "v15", "v16", "v17"); -} - -// Variant of the above Kernel8bitNeonDotprodOutOfOrder, tuned for in-order -// CPUs. Specifically here, the relevant in-order CPUs are ARM Cortex-A55r1, -// since these are 64-bit and support dotprod. -// -// While this kernel does not have a direct equivalent in gemmlowp, it was -// developed based on insights that David Mansell at ARM shared with their -// contribution of gemmlowp kernels tuned for Cortex-A55r1, with very helpful -// comments. Specifically, see this comment about tuning for Cortex-A55r1: -// https://github.com/google/gemmlowp/blob/36212ad3651871bc3e9a599f1a6d5324778aea25/standalone/neon-gemm-kernel-benchmark.cc#L4412 -void Kernel8bitNeonDotprodInOrder(const KernelParams8bit<8, 8>& params) { - profiler::ScopeLabel label( - "Kernel (kNeonDotprod, optimized for in-order cores)"); - - CheckOffsetsInKernelParams8bit(params); - - const std::int8_t* lhs_col_ptr = params.lhs_base_ptr; - const std::int8_t* rhs_col_ptr = params.rhs_base_ptr; - const std::int8_t* lhs_ptr = lhs_col_ptr; - const std::int8_t* rhs_ptr = rhs_col_ptr; - void* dst_col_ptr = params.dst_base_ptr; - void* dst_ptr = dst_col_ptr; - int row = params.start_row; - int col = params.start_col; - - // The asm kernel below has the following NEON register allocation: - // - // v16 -- v31 are int32 accumulators. - // During accumulation, v0 -- v3 are used to load int8 data from LHS and - // RHS. - // - // int8 RHS 4x8 block - // /-----------------------------------------\ - // |v2.b[0] ... v2.b[12] v3.b[0] ... v3.b[12]| - // | ... ... | - // |v2.b[3] ... v2.b[15] v3.b[3] ... v3.b[15]| - // \-----------------------------------------/ - // int8 LHS 8x4 block - // /---------------------\ /-----------------------------------------\ - // |v0.b[0] ... v0.b[3] | |v16.s[0] ... v30.s[0]| - // | ... ... | | ... ... | - // |v0.b[12] ... v0.b[15]| |v16.s[3] ... v30.s[3]| - // |v1.b[0] ... v1.b[3] | |v17.s[0] ... v31.s[0]| - // | ... ... | | ... ... | - // |v1.b[12] ... v1.b[15]| |v17.s[3] ... v31.s[3]| - // \---------------------/ \-----------------------------------------/ - // int32 accumulators 8x8 block - // - // There is no RUY_OPT_MAX_STREAMING 4x-unrolled part in this kernel because - // we did not observe a benefit of such partial unrolling on in-order CPUs. - // - // v4 -- v7 are unused, and v8 -- v15 are used for loading parameters used for - // the post-accumulation part of the kernel. - asm volatile( -#define RUY_MAKE_ZERO(reg) "dup " #reg ".4s, wzr\n" - - // clang-format off - - // Load some parameters into registers. 
- "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" - RUY_MAKE_ZERO(v16) - "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" - RUY_MAKE_ZERO(v17) - "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" - RUY_MAKE_ZERO(v18) - "ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n" - RUY_MAKE_ZERO(v19) - "ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n" - RUY_MAKE_ZERO(v20) - "ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n" - RUY_MAKE_ZERO(v21) - "ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" - RUY_MAKE_ZERO(v22) - "ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n" - - // Load the first 32 bytes of LHS and RHS data. - "ld1 {v0.16b}, [%[lhs_ptr]], #16\n" - "ld1 {v1.16b}, [%[lhs_ptr]], #16\n" - "ld1 {v2.16b}, [%[rhs_ptr]], #16\n" - "ld1 {v3.16b}, [%[rhs_ptr]], #16\n" - - // Clear accumulators. - RUY_MAKE_ZERO(v23) - RUY_MAKE_ZERO(v24) - RUY_MAKE_ZERO(v25) - RUY_MAKE_ZERO(v26) - RUY_MAKE_ZERO(v27) - // Perform the first few multiply-adds on the data that we have already - // loaded. - ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" - RUY_MAKE_ZERO(v28) - ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n" - RUY_MAKE_ZERO(v29) - ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n" - RUY_MAKE_ZERO(v30) - ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n" - RUY_MAKE_ZERO(v31) - - - "1:\n" - - "add x5, %[lhs_ptr], x12, lsl #3\n" - "sub x5, x5, #32\n" - "cmp %[lhs_ptr], x5\n" - - "beq 79f\n" - - // Main accumulation loop - "2:\n" - ".word 0x4f83e018 // sdot v24.4s, v0.16b, v3.4b[0]\n" - "ldr x1, [%[lhs_ptr], #8]\n" - ".word 0x4fa3e01a // sdot v26.4s, v0.16b, v3.4b[1]\n" - "ldr x3, [%[rhs_ptr], #8]\n" - ".word 0x4f83e81c // sdot v28.4s, v0.16b, v3.4b[2]\n" - "ldr x4, [%[rhs_ptr], #24]\n" - ".word 0x4fa3e81e // sdot v30.4s, v0.16b, v3.4b[3]\n" - "ldr d0, [%[lhs_ptr], #0]\n" - ".word 0x4f82e031 // sdot v17.4s, v1.16b, v2.4b[0]\n" - "ins v0.d[1], x1\n" - ".word 0x4fa2e033 // sdot v19.4s, v1.16b, v2.4b[1]\n" - "ldr x2, [%[lhs_ptr], #24]\n" - ".word 0x4f82e835 // sdot v21.4s, v1.16b, v2.4b[2]\n" - "add %[lhs_ptr], %[lhs_ptr], #32\n" - ".word 0x4fa2e837 // sdot v23.4s, v1.16b, v2.4b[3]\n" - "ldr d2, [%[rhs_ptr], #0]\n" - ".word 0x4f83e039 // sdot v25.4s, v1.16b, v3.4b[0]\n" - "ins v2.d[1], x3\n" - ".word 0x4fa3e03b // sdot v27.4s, v1.16b, v3.4b[1]\n" - "cmp %[lhs_ptr], x5\n" - ".word 0x4f83e83d // sdot v29.4s, v1.16b, v3.4b[2]\n" - "add %[rhs_ptr], %[rhs_ptr], #32\n" - ".word 0x4fa3e83f // sdot v31.4s, v1.16b, v3.4b[3]\n" - "ldr d3, [%[rhs_ptr], #-16]\n" - ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" - "ldr d1, [%[lhs_ptr], #-16]\n" - ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n" - "ins v3.d[1], x4\n" - ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n" - "ins v1.d[1], x2\n" - ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n" - "blt 2b\n" - - // Last accumulation steps, nothing left to load. - "79:\n" - ".word 0x4f83e018 // sdot v24.4s, v0.16b, v3.4b[0]\n" - "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" - ".word 0x4fa3e01a // sdot v26.4s, v0.16b, v3.4b[1]\n" - "cmp %w[row], w7\n" // Have we finished the last row? 
- ".word 0x4f83e81c // sdot v28.4s, v0.16b, v3.4b[2]\n" - ".word 0x4fa3e81e // sdot v30.4s, v0.16b, v3.4b[3]\n" - ".word 0x4f82e031 // sdot v17.4s, v1.16b, v2.4b[0]\n" - ".word 0x4fa2e033 // sdot v19.4s, v1.16b, v2.4b[1]\n" - ".word 0x4f82e835 // sdot v21.4s, v1.16b, v2.4b[2]\n" - ".word 0x4fa2e837 // sdot v23.4s, v1.16b, v2.4b[3]\n" - ".word 0x4f83e039 // sdot v25.4s, v1.16b, v3.4b[0]\n" - ".word 0x4fa3e03b // sdot v27.4s, v1.16b, v3.4b[1]\n" - ".word 0x4f83e83d // sdot v29.4s, v1.16b, v3.4b[2]\n" - ".word 0x4fa3e83f // sdot v31.4s, v1.16b, v3.4b[3]\n" - - // End of accumulation. The registers v16 -- v31 contain the final - // int32 accumulator values of the current 8x8 destination block. - // We now have to compute the final 8-bit values from these int32 - // accumulators, and advance to the next 8x8 block. We intertwine - // these two aspects whenever possible for optimal pipelining, both - // at the data flow level (prefetch data for next block as early as - // possible) and instruction pipelining level (some of the next-block - // work can dual-issue with some of the final work on the current - // block). - - // Logic to advance to the next block in preparation for the next - // iteration of the main loop. For now, we only want to compute - // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are - // not yet ready to update the values of row and col, as we still need - // the current values for the rest of the work on the current block. - - "bge 4f\n" // If finished last row, go to 4 - // Not finished last row: then advance to next row. - "add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #3\n" - "b 5f\n" - "4:\n" // Finished last row... - "mov %[lhs_col_ptr], x5\n" // Go back to first row - // Now we need to advance to the next column. If we already - // finished the last column, then in principle we are done, however - // we can't just return here, as we need to allow the end work of the - // current block to complete. The good news is that at this point it - // doesn't matter what data we load for the next column, since - // we will exit from the main loop below before actually storing - // anything computed from that data. - "cmp %w[col], w8\n" // Have we finished the last column? - "bge 5f\n" // If yes, just carry on without updating the column pointer. - // Not finished last column: then advance to next column. - "add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #3\n" - "5:\n" - - // Set the LHS and RHS data pointers to the start of the columns just - // computed. - "mov %[lhs_ptr], %[lhs_col_ptr]\n" - // Load some parameters needed for the end work on current block. - RUY_MAKE_ZERO(v8) - "mov %[rhs_ptr], %[rhs_col_ptr]\n" - "ldr w4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n" - "ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_PROD_ZP_DEPTH) "]\n" - "ins v13.h[4], w4\n" // dst_zero_point - "ldr x4, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n" - "ldrb w6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" - "dup v9.4s, w3\n" // create prod_zp_depth_vec - "add x5, x4, %x[row], lsl #2\n" - "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n" - "csel x4, x4, x5, eq\n" - - "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n" - "add x5, x1, %x[row], lsl #2\n" - "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n" - "csel x1, x1, x5, eq\n" - - // Load 8 bias values. 
- "ld1 {v14.2s}, [x1], #8\n" - "ldr x5, [x1], #8\n" - "ins v14.d[1], x5\n" - "ld1 {v15.2s}, [x1], #8\n" - "ldr x5, [x1], #8\n" - "ins v15.d[1], x5\n" - - // Add to the bias values the product (depth * lhs_zero_point * rhs_zero_point), - // See the term NZ1Z2 in equation (7) in https://arxiv.org/pdf/1712.05877.pdf - "add v14.4s, v14.4s, v9.4s\n" - "add v15.4s, v15.4s, v9.4s\n" - // Perform the bias-addition (per the above, we have just folded into - // the bias the (depth * lhs_zero_point * rhs_zero_point) term.) - "add v16.4s, v16.4s, v14.4s\n" - "add v17.4s, v17.4s, v15.4s\n" - "add v18.4s, v18.4s, v14.4s\n" - "add v19.4s, v19.4s, v15.4s\n" - "add v20.4s, v20.4s, v14.4s\n" - "add v21.4s, v21.4s, v15.4s\n" - "add v22.4s, v22.4s, v14.4s\n" - "add v23.4s, v23.4s, v15.4s\n" - "add v24.4s, v24.4s, v14.4s\n" - "add v25.4s, v25.4s, v15.4s\n" - "add v26.4s, v26.4s, v14.4s\n" - "add v27.4s, v27.4s, v15.4s\n" - "add v28.4s, v28.4s, v14.4s\n" - "add v29.4s, v29.4s, v15.4s\n" - "add v30.4s, v30.4s, v14.4s\n" - "add v31.4s, v31.4s, v15.4s\n" - - "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_RHS_SUMS) "\n" - "beq 401f\n" - "ldr x3, [%[params], #" RUY_STR(RUY_OFFSET_RHS_SUMS) "]\n" - "add x3, x3, %x[col], lsl #2\n" - "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_ZERO_POINT) "]\n" - "dup v10.4s, w5\n" // create lhs_zero_point_vec - // Load 8 rhs_sums values. - "ld1 {v14.2s}, [x3], #8\n" - "ldr x7, [x3], #8\n" - "ld1 {v15.2s}, [x3], #8\n" - "ins v14.d[1], x7\n" - "ldr x7, [x3], #8\n" - "ins v15.d[1], x7\n" - // Subtract rhs_sums * lhs_zero_point, per - // equation (7) in https://arxiv.org/pdf/1712.05877.pdf - "mls v16.4s, v10.4s, v14.s[0]\n" - "mls v17.4s, v10.4s, v14.s[0]\n" - "mls v18.4s, v10.4s, v14.s[1]\n" - "mls v19.4s, v10.4s, v14.s[1]\n" - "mls v20.4s, v10.4s, v14.s[2]\n" - "mls v21.4s, v10.4s, v14.s[2]\n" - "mls v22.4s, v10.4s, v14.s[3]\n" - "mls v23.4s, v10.4s, v14.s[3]\n" - "mls v24.4s, v10.4s, v15.s[0]\n" - "mls v25.4s, v10.4s, v15.s[0]\n" - "mls v26.4s, v10.4s, v15.s[1]\n" - "mls v27.4s, v10.4s, v15.s[1]\n" - "mls v28.4s, v10.4s, v15.s[2]\n" - "mls v29.4s, v10.4s, v15.s[2]\n" - "mls v30.4s, v10.4s, v15.s[3]\n" - "mls v31.4s, v10.4s, v15.s[3]\n" - "401:\n" - - "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_LHS_SUMS) "\n" - "beq 402f\n" - "ldr x2, [%[params], #" RUY_STR(RUY_OFFSET_LHS_SUMS) "]\n" - "add x2, x2, %x[row], lsl #2\n" - "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_RHS_ZERO_POINT) "]\n" - "ins v13.s[1], w5\n" // rhs_zero_point - // Load 8 lhs_sums values. - "ld1 {v11.2s}, [x2], #8\n" - "ldr x6, [x2], #8\n" - "ins v11.d[1], x6\n" - "ld1 {v12.2s}, [x2], #8\n" - "ldr x6, [x2], #8\n" - "ins v12.d[1], x6\n" - // Compute lhs_sums * rhs_zero_point. 
- "mul v11.4s, v11.4s, v13.s[1]\n" - "mul v12.4s, v12.4s, v13.s[1]\n" - // Subtract lhs_sums * rhs_zero_point, per - // equation (7) in https://arxiv.org/pdf/1712.05877.pdf - "sub v16.4s, v16.4s, v11.4s\n" - "sub v17.4s, v17.4s, v12.4s\n" - "sub v18.4s, v18.4s, v11.4s\n" - "sub v19.4s, v19.4s, v12.4s\n" - "sub v20.4s, v20.4s, v11.4s\n" - "sub v21.4s, v21.4s, v12.4s\n" - "sub v22.4s, v22.4s, v11.4s\n" - "sub v23.4s, v23.4s, v12.4s\n" - "sub v24.4s, v24.4s, v11.4s\n" - "sub v25.4s, v25.4s, v12.4s\n" - "sub v26.4s, v26.4s, v11.4s\n" - "sub v27.4s, v27.4s, v12.4s\n" - "sub v28.4s, v28.4s, v11.4s\n" - "sub v29.4s, v29.4s, v12.4s\n" - "sub v30.4s, v30.4s, v11.4s\n" - "sub v31.4s, v31.4s, v12.4s\n" - - "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT32) "\n" - "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT32) "f\n" - - "402:\n" - - // At this point we have computed the final int32 values. Now we - // start down-quantizing them to obtain the final 8bit values from them. - - // As part of this down-quantization, our int32 values will be - // multiplied by a multiplier that has a fixed-point component and an - // exponent component. - - //Load the exponent part of the multiplier. - "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_EXPONENT) "]\n" - "ldrb w6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" - "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n" - "add x5, x1, %x[row], lsl #2\n" - "csel x1, x1, x5, eq\n" - - "ldr q9, [x1]\n" - "ldr q10, [x1, #16]\n" - - "tst w6, #" RUY_STR(RUY_ASM_FLAG_NEEDS_LEFT_SHIFT) "\n" - "beq 403f\n" - "smax v11.4s, v9.4s, v8.4s\n" - "smax v12.4s, v10.4s, v8.4s\n" - "sshl v16.4s, v16.4s, v11.4s\n" - "sshl v17.4s, v17.4s, v12.4s\n" - "sshl v18.4s, v18.4s, v11.4s\n" - "sshl v19.4s, v19.4s, v12.4s\n" - "sshl v20.4s, v20.4s, v11.4s\n" - "sshl v21.4s, v21.4s, v12.4s\n" - "sshl v22.4s, v22.4s, v11.4s\n" - "sshl v23.4s, v23.4s, v12.4s\n" - "sshl v24.4s, v24.4s, v11.4s\n" - "sshl v25.4s, v25.4s, v12.4s\n" - "sshl v26.4s, v26.4s, v11.4s\n" - "sshl v27.4s, v27.4s, v12.4s\n" - "sshl v28.4s, v28.4s, v11.4s\n" - "sshl v29.4s, v29.4s, v12.4s\n" - "sshl v30.4s, v30.4s, v11.4s\n" - "sshl v31.4s, v31.4s, v12.4s\n" - "403:\n" - - "ldr q14, [x4]\n" // multiplier_fixedpoint - "ldr q15, [x4, #16]\n" // multiplier_fixedpoint - - "smin v11.4s, v9.4s, v8.4s\n" - "smin v12.4s, v10.4s, v8.4s\n" - - // Apply the fixed-point part of the multiplier. - // - // ... and, interleaved into that: - // Now that we know what LHS and RHS data the next iteration of the - // main loop will need to load, we start loading the first 32 bytes of - // each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore - // in the rest of the work on the current block. 
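The additions and subtractions applied to the accumulators above implement the zero-point corrections of equation (7) in https://arxiv.org/pdf/1712.05877.pdf, as the kernel comments state. In scalar form (illustrative names only, not ruy's API), the corrected accumulator for a given row and column is:

    #include <cstdint>

    // raw_dot     = sum over k of lhs[r][k] * rhs[k][c]  (the sdot accumulators)
    // bias        = per-row bias, or 0 if RUY_ASM_FLAG_HAS_BIAS is not set
    // lhs_sum_row = sum over k of lhs[r][k]              (lhs_sums)
    // rhs_sum_col = sum over k of rhs[k][c]              (rhs_sums)
    std::int32_t CorrectedAccumulator(std::int32_t raw_dot, std::int32_t bias,
                                      int depth, std::int32_t lhs_zero_point,
                                      std::int32_t rhs_zero_point,
                                      std::int32_t lhs_sum_row,
                                      std::int32_t rhs_sum_col) {
      std::int32_t acc = raw_dot + bias;
      acc += depth * lhs_zero_point * rhs_zero_point;  // prod_zp_depth, folded into the bias
      acc -= lhs_zero_point * rhs_sum_col;             // the "mls ..., rhs_sums" step
      acc -= rhs_zero_point * lhs_sum_row;             // the "sub ..., lhs_sums * rhs_zp" step
      return acc;
    }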
- "ld1 {v0.8b}, [%[lhs_ptr]], #8\n" - "sqrdmulh v16.4s, v16.4s, v14.4s\n" - "ldr x1, [%[lhs_ptr]], #8\n" - "sqrdmulh v17.4s, v17.4s, v15.4s\n" - "ld1 {v1.8b}, [%[lhs_ptr]], #8\n" - "sqrdmulh v18.4s, v18.4s, v14.4s\n" - "ldr x2, [%[lhs_ptr]], #8\n" - "sqrdmulh v19.4s, v19.4s, v15.4s\n" - "ld1 {v2.8b}, [%[rhs_ptr]], #8\n" - "sqrdmulh v20.4s, v20.4s, v14.4s\n" - "ldr x5, [%[rhs_ptr]], #8\n" - "sqrdmulh v21.4s, v21.4s, v15.4s\n" - "ld1 {v3.8b}, [%[rhs_ptr]], #8\n" - "sqrdmulh v22.4s, v22.4s, v14.4s\n" - "ldr x6, [%[rhs_ptr]], #8\n" - "sqrdmulh v23.4s, v23.4s, v15.4s\n" - "sqrdmulh v24.4s, v24.4s, v14.4s\n" - "sqrdmulh v25.4s, v25.4s, v15.4s\n" - "sqrdmulh v26.4s, v26.4s, v14.4s\n" - "sqrdmulh v27.4s, v27.4s, v15.4s\n" - "sqrdmulh v28.4s, v28.4s, v14.4s\n" - "sqrdmulh v29.4s, v29.4s, v15.4s\n" - "sqrdmulh v30.4s, v30.4s, v14.4s\n" - "sqrdmulh v31.4s, v31.4s, v15.4s\n" - - // We have some rounding division-by-power-of-two to do. This should - // always use "round to nearest". We allow for some - // freedom in how ties are broken, to strike a good compromise of - // performance on given hardware vs. perfect agreement of results - // across hardware. - // - // When RUY_OPT_NATIVE_ROUNDING is enabled, we allow for implementation - // defined tie-breaks to help performance. On NEON, this means that we - // can just use the NEON rounding instructions, such as srshl. They - // happen to be breaking ties upward. - // - // When RUY_OPT_NATIVE_ROUNDING is disabled, we implement strict - // break-ties-away-from zero, as described in Appendix B of - // https://arxiv.org/pdf/1712.05877.pdf - // When we wrote that, we thought that that would be better unbiased - // than the NEON upwards tie-breaks, and we had observed some - // improvement on some model. However, that is only more unbiased for - // data centered at zero, which was likely the case in that model, - // but is not always the case. If we wanted something more consistently - // unbiased then we should try breaking ties toward-nearest-even. -#if !RUY_OPT_ENABLED(RUY_OPT_NATIVE_ROUNDING) - // Fix up values to be right-shifted, so that the (round to nearest, - // break ties upward) behavior of srshl applied to these fixed-up - // values, produces the same result as the desired (round to nearest, - // break ties away from zero) behavior on the original values. 
- "and v8.16b, v16.16b, v11.16b\n" - "and v9.16b, v17.16b, v12.16b\n" - "and v14.16b, v18.16b, v11.16b\n" - "and v15.16b, v19.16b, v12.16b\n" - "sshr v8.4s, v8.4s, #31\n" - "sshr v9.4s, v9.4s, #31\n" - "sshr v14.4s, v14.4s, #31\n" - "sshr v15.4s, v15.4s, #31\n" - "sqadd v16.4s, v16.4s, v8.4s\n" - "sqadd v17.4s, v17.4s, v9.4s\n" - "sqadd v18.4s, v18.4s, v14.4s\n" - "sqadd v19.4s, v19.4s, v15.4s\n" - "and v8.16b, v20.16b, v11.16b\n" - "and v9.16b, v21.16b, v12.16b\n" - "and v14.16b, v22.16b, v11.16b\n" - "and v15.16b, v23.16b, v12.16b\n" - "sshr v8.4s, v8.4s, #31\n" - "sshr v9.4s, v9.4s, #31\n" - "sshr v14.4s, v14.4s, #31\n" - "sshr v15.4s, v15.4s, #31\n" - "sqadd v20.4s, v20.4s, v8.4s\n" - "sqadd v21.4s, v21.4s, v9.4s\n" - "sqadd v22.4s, v22.4s, v14.4s\n" - "sqadd v23.4s, v23.4s, v15.4s\n" - "and v8.16b, v24.16b, v11.16b\n" - "and v9.16b, v25.16b, v12.16b\n" - "and v14.16b, v26.16b, v11.16b\n" - "and v15.16b, v27.16b, v12.16b\n" - "sshr v8.4s, v8.4s, #31\n" - "sshr v9.4s, v9.4s, #31\n" - "sshr v14.4s, v14.4s, #31\n" - "sshr v15.4s, v15.4s, #31\n" - "sqadd v24.4s, v24.4s, v8.4s\n" - "sqadd v25.4s, v25.4s, v9.4s\n" - "sqadd v26.4s, v26.4s, v14.4s\n" - "sqadd v27.4s, v27.4s, v15.4s\n" - "and v8.16b, v28.16b, v11.16b\n" - "and v9.16b, v29.16b, v12.16b\n" - "and v14.16b, v30.16b, v11.16b\n" - "and v15.16b, v31.16b, v12.16b\n" - "sshr v8.4s, v8.4s, #31\n" - "sshr v9.4s, v9.4s, #31\n" - "sshr v14.4s, v14.4s, #31\n" - "sshr v15.4s, v15.4s, #31\n" - "sqadd v28.4s, v28.4s, v8.4s\n" - "sqadd v29.4s, v29.4s, v9.4s\n" - "sqadd v30.4s, v30.4s, v14.4s\n" - "sqadd v31.4s, v31.4s, v15.4s\n" -#endif - // At this point we have reduced the problem of correctly implementing - // rounding divide-by-power-of-two, to what the SRSHL instruction can - // do. - "srshl v16.4s, v16.4s, v11.4s\n" - "srshl v17.4s, v17.4s, v12.4s\n" - "srshl v18.4s, v18.4s, v11.4s\n" - "srshl v19.4s, v19.4s, v12.4s\n" - "srshl v20.4s, v20.4s, v11.4s\n" - "srshl v21.4s, v21.4s, v12.4s\n" - "srshl v22.4s, v22.4s, v11.4s\n" - "srshl v23.4s, v23.4s, v12.4s\n" - "srshl v24.4s, v24.4s, v11.4s\n" - "srshl v25.4s, v25.4s, v12.4s\n" - "srshl v26.4s, v26.4s, v11.4s\n" - "srshl v27.4s, v27.4s, v12.4s\n" - "ins v0.d[1], x1\n" - "srshl v28.4s, v28.4s, v11.4s\n" - "ins v1.d[1], x2\n" - "srshl v29.4s, v29.4s, v12.4s\n" - "ins v2.d[1], x5\n" - "srshl v30.4s, v30.4s, v11.4s\n" - "ins v3.d[1], x6\n" - "srshl v31.4s, v31.4s, v12.4s\n" - - "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT16) "\n" - "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT16) "f\n" - "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT8) "\n" - "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT8) "f\n" - - RUY_STR(RUY_ASM_LABEL_STORE_UINT8) ":\n" - - // Cast-and-saturate from int32 to int16 - "sqxtn v16.4h, v16.4s\n" - "sqxtn2 v16.8h, v17.4s\n" - "sqxtn v17.4h, v18.4s\n" - "sqxtn2 v17.8h, v19.4s\n" - "sqxtn v18.4h, v20.4s\n" - "sqxtn2 v18.8h, v21.4s\n" - "sqxtn v19.4h, v22.4s\n" - "sqxtn2 v19.8h, v23.4s\n" - "sqxtn v20.4h, v24.4s\n" - "sqxtn2 v20.8h, v25.4s\n" - "sqxtn v21.4h, v26.4s\n" - "sqxtn2 v21.8h, v27.4s\n" - "sqxtn v22.4h, v28.4s\n" - "sqxtn2 v22.8h, v29.4s\n" - "sqxtn v23.4h, v30.4s\n" - "sqxtn2 v23.8h, v31.4s\n" - - // Destination zero_point - "dup v14.8h, v13.h[4]\n" - // At this point, v24 -- v31 aren't used anymore for the current block, - // so we can start clearing these accumulators for the next block - // (next iteration of the main loop). 
- RUY_MAKE_ZERO(v24) - RUY_MAKE_ZERO(v25) - RUY_MAKE_ZERO(v26) - RUY_MAKE_ZERO(v27) - RUY_MAKE_ZERO(v28) - RUY_MAKE_ZERO(v29) - RUY_MAKE_ZERO(v30) - RUY_MAKE_ZERO(v31) - - // Add the destination zero point - "add v16.8h, v16.8h, v14.8h\n" - "add v17.8h, v17.8h, v14.8h\n" - "add v18.8h, v18.8h, v14.8h\n" - "add v19.8h, v19.8h, v14.8h\n" - "add v20.8h, v20.8h, v14.8h\n" - "add v21.8h, v21.8h, v14.8h\n" - "add v22.8h, v22.8h, v14.8h\n" - "add v23.8h, v23.8h, v14.8h\n" - - // Load the clamp_min, clamp_max bounds - "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" - // Cast-and-saturate from int16 to uint8 - "sqxtun v16.8b, v16.8h\n" - "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" - "sqxtun2 v16.16b, v17.8h\n" - "sqxtun v17.8b, v18.8h\n" - "sqxtun2 v17.16b, v19.8h\n" - "sqxtun v18.8b, v20.8h\n" - "sqxtun2 v18.16b, v21.8h\n" - "sqxtun v19.8b, v22.8h\n" - "sqxtun2 v19.16b, v23.8h\n" - - "dup v14.16b, w2\n" // clamp_min - "dup v15.16b, w3\n" // clamp_max - - // Compute how much of the 8x8 block of destination 8bit values that - // we have computed, fit in the destination matrix. Typically, all of - // it fits, but when the destination matrix shape is not a multiple - // of 8x8, there are some 8x8 blocks along the boundaries that do - // not fit entirely. - "sub w1, %w[dst_rows], %w[row]\n" - // Apply the clamp_min bound - "umax v16.16b, v16.16b, v14.16b\n" - "sub w2, %w[dst_cols], %w[col]\n" - "umax v17.16b, v17.16b, v14.16b\n" - "mov w3, #8\n" - "umax v18.16b, v18.16b, v14.16b\n" - "cmp w1, #8\n" - "umax v19.16b, v19.16b, v14.16b\n" - // Compute w1 = how many rows of the 8x8 block fit - "csel w1, w1, w3, le\n" - // Apply the clamp_max bound - "umin v16.16b, v16.16b, v15.16b\n" - "cmp w2, #8\n" - "umin v17.16b, v17.16b, v15.16b\n" - // Compute w2 = how many cols of the 8x8 block fit - "csel w2, w2, w3, le\n" - "umin v18.16b, v18.16b, v15.16b\n" - "umin v19.16b, v19.16b, v15.16b\n" - - // Make it so that all of the final 8bit values are stored in the - // first 64bits of 128bit NEON registers, so they can be stored - // by 64bit st1 store instructions with byte alignment. - "dup d20, v16.d[1]\n" - "dup d21, v17.d[1]\n" - "dup d22, v18.d[1]\n" - "dup d23, v19.d[1]\n" - - // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits. - "cmp w1, w3\n" - "ccmp w2, w3, 0, eq\n" - // Yes, all of the 8x8 block fits, go to fast path. - "beq 30f\n" - // Not all of the 8x8 block fits. - // Set (x3 address, x4 stride) to write to dst_tmp_buf - "mov x3, %[dst_tmp_buf]\n" - "mov x4, #8\n" - "b 31f\n" - "30:\n" - // Yes, all of the 8x8 block fits. - // Set (x3 address, x4 stride) to write directly to destination matrix. - "mov x3, %[dst_ptr]\n" - "mov x4, x11\n" - "31:\n" - - // Write our 8bit values to the destination described by - // (x3 address, x4 stride). - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v16.8b}, [x3], x4\n" - RUY_MAKE_ZERO(v16) - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v20.8b}, [x3], x4\n" - RUY_MAKE_ZERO(v20) - // For the next block: perform the first few multiply-adds on the data - // that we have already loaded. 
- ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v17.8b}, [x3], x4\n" - RUY_MAKE_ZERO(v17) - ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v21.8b}, [x3], x4\n" - RUY_MAKE_ZERO(v21) - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v18.8b}, [x3], x4\n" - RUY_MAKE_ZERO(v18) - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v22.8b}, [x3], x4\n" - RUY_MAKE_ZERO(v22) - ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v19.8b}, [x3], x4\n" - RUY_MAKE_ZERO(v19) - ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v23.8b}, [x3], x4\n" - RUY_MAKE_ZERO(v23) - - // If all of the 8x8 block fits, we just finished writing it to the - // destination, so we skip the next part. - "beq 41f\n" - // Not all of the 8x8 block fits in the destination matrix. We just - // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over - // it to copy into the destination matrix the part that fits. - "mov x3, %[dst_tmp_buf]\n" - "mov x4, %[dst_ptr]\n" - "mov w6, #0\n" - "50:\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov w5, #0\n" - "51:\n" - "ldrb w7, [x3, w5, uxtw]\n" - "strb w7, [x4, w5, uxtw]\n" - "add w5, w5, #1\n" - "cmp w5, w1\n" - "blt 51b\n" - "add w6, w6, #1\n" - "add x3, x3, #8\n" - "add x4, x4, x11\n" - "cmp w6, w2\n" - "blt 50b\n" - "41:\n" - "add %[dst_ptr], %[dst_ptr], #8\n" - - // At this point we have completely finished writing values to the - // destination matrix for the current block. - - "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" - - RUY_STR(RUY_ASM_LABEL_STORE_INT8) ":\n" - - // Cast-and-saturate from int32 to int16 - "sqxtn v16.4h, v16.4s\n" - "sqxtn2 v16.8h, v17.4s\n" - "sqxtn v17.4h, v18.4s\n" - "sqxtn2 v17.8h, v19.4s\n" - "sqxtn v18.4h, v20.4s\n" - "sqxtn2 v18.8h, v21.4s\n" - "sqxtn v19.4h, v22.4s\n" - "sqxtn2 v19.8h, v23.4s\n" - "sqxtn v20.4h, v24.4s\n" - "sqxtn2 v20.8h, v25.4s\n" - "sqxtn v21.4h, v26.4s\n" - "sqxtn2 v21.8h, v27.4s\n" - "sqxtn v22.4h, v28.4s\n" - "sqxtn2 v22.8h, v29.4s\n" - "sqxtn v23.4h, v30.4s\n" - "sqxtn2 v23.8h, v31.4s\n" - - // Destination zero_point - "dup v14.8h, v13.h[4]\n" - // At this point, v24 -- v31 aren't used anymore for the current block, - // so we can start clearing these accumulators for the next block - // (next iteration of the main loop). 
- RUY_MAKE_ZERO(v24) - RUY_MAKE_ZERO(v25) - RUY_MAKE_ZERO(v26) - RUY_MAKE_ZERO(v27) - RUY_MAKE_ZERO(v28) - RUY_MAKE_ZERO(v29) - RUY_MAKE_ZERO(v30) - RUY_MAKE_ZERO(v31) - - // Add the destination zero point - "add v16.8h, v16.8h, v14.8h\n" - "add v17.8h, v17.8h, v14.8h\n" - "add v18.8h, v18.8h, v14.8h\n" - "add v19.8h, v19.8h, v14.8h\n" - "add v20.8h, v20.8h, v14.8h\n" - "add v21.8h, v21.8h, v14.8h\n" - "add v22.8h, v22.8h, v14.8h\n" - "add v23.8h, v23.8h, v14.8h\n" - - // Load the clamp_min, clamp_max bounds - "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" - // Cast-and-saturate from int16 to uint8 - "sqxtn v16.8b, v16.8h\n" - "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" - "sqxtn2 v16.16b, v17.8h\n" - "sqxtn v17.8b, v18.8h\n" - "sqxtn2 v17.16b, v19.8h\n" - "sqxtn v18.8b, v20.8h\n" - "sqxtn2 v18.16b, v21.8h\n" - "sqxtn v19.8b, v22.8h\n" - "sqxtn2 v19.16b, v23.8h\n" - - "dup v14.16b, w2\n" // clamp_min - "dup v15.16b, w3\n" // clamp_max - - // Compute how much of the 8x8 block of destination 8bit values that - // we have computed, fit in the destination matrix. Typically, all of - // it fits, but when the destination matrix shape is not a multiple - // of 8x8, there are some 8x8 blocks along the boundaries that do - // not fit entirely. - "sub w1, %w[dst_rows], %w[row]\n" - // Apply the clamp_min bound - "smax v16.16b, v16.16b, v14.16b\n" - "sub w2, %w[dst_cols], %w[col]\n" - "smax v17.16b, v17.16b, v14.16b\n" - "mov w3, #8\n" - "smax v18.16b, v18.16b, v14.16b\n" - "cmp w1, #8\n" - "smax v19.16b, v19.16b, v14.16b\n" - // Compute w1 = how many rows of the 8x8 block fit - "csel w1, w1, w3, le\n" - // Apply the clamp_max bound - "smin v16.16b, v16.16b, v15.16b\n" - "cmp w2, #8\n" - "smin v17.16b, v17.16b, v15.16b\n" - // Compute w2 = how many cols of the 8x8 block fit - "csel w2, w2, w3, le\n" - "smin v18.16b, v18.16b, v15.16b\n" - "smin v19.16b, v19.16b, v15.16b\n" - - // Make it so that all of the final 8bit values are stored in the - // first 64bits of 128bit NEON registers, so they can be stored - // by 64bit st1 store instructions with byte alignment. - "dup d20, v16.d[1]\n" - "dup d21, v17.d[1]\n" - "dup d22, v18.d[1]\n" - "dup d23, v19.d[1]\n" - - // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits. - "cmp w1, w3\n" - "ccmp w2, w3, 0, eq\n" - // Yes, all of the 8x8 block fits, go to fast path. - "beq 130f\n" - // Not all of the 8x8 block fits. - // Set (x3 address, x4 stride) to write to dst_tmp_buf - "mov x3, %[dst_tmp_buf]\n" - "mov x4, #8\n" - "b 131f\n" - "130:\n" - // Yes, all of the 8x8 block fits. - // Set (x3 address, x4 stride) to write directly to destination matrix. - "mov x3, %[dst_ptr]\n" - "mov x4, x11\n" - "131:\n" - - // Write our 8bit values to the destination described by - // (x3 address, x4 stride). - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v16.8b}, [x3], x4\n" - RUY_MAKE_ZERO(v16) - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v20.8b}, [x3], x4\n" - RUY_MAKE_ZERO(v20) - // For the next block: perform the first few multiply-adds on the data - // that we have already loaded. 
- ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v17.8b}, [x3], x4\n" - RUY_MAKE_ZERO(v17) - ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v21.8b}, [x3], x4\n" - RUY_MAKE_ZERO(v21) - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v18.8b}, [x3], x4\n" - RUY_MAKE_ZERO(v18) - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v22.8b}, [x3], x4\n" - RUY_MAKE_ZERO(v22) - ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v19.8b}, [x3], x4\n" - RUY_MAKE_ZERO(v19) - ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v23.8b}, [x3], x4\n" - RUY_MAKE_ZERO(v23) - - // If all of the 8x8 block fits, we just finished writing it to the - // destination, so we skip the next part. - "beq 141f\n" - // Not all of the 8x8 block fits in the destination matrix. We just - // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over - // it to copy into the destination matrix the part that fits. - "mov x3, %[dst_tmp_buf]\n" - "mov x4, %[dst_ptr]\n" - "mov w6, #0\n" - "150:\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov w5, #0\n" - "151:\n" - "ldrb w7, [x3, w5, uxtw]\n" - "strb w7, [x4, w5, uxtw]\n" - "add w5, w5, #1\n" - "cmp w5, w1\n" - "blt 151b\n" - "add w6, w6, #1\n" - "add x3, x3, #8\n" - "add x4, x4, x11\n" - "cmp w6, w2\n" - "blt 150b\n" - "141:\n" - "add %[dst_ptr], %[dst_ptr], #8\n" - - // At this point we have completely finished writing values to the - // destination matrix for the current block. - - "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" - - RUY_STR(RUY_ASM_LABEL_STORE_INT16) ":\n" - - // Add the destination zero point - "dup v14.8h, v13.h[4]\n" - "saddw v16.4s, v16.4s, v14.4h\n" - "saddw v17.4s, v17.4s, v14.4h\n" - "saddw v18.4s, v18.4s, v14.4h\n" - "saddw v19.4s, v19.4s, v14.4h\n" - "saddw v20.4s, v20.4s, v14.4h\n" - "saddw v21.4s, v21.4s, v14.4h\n" - "saddw v22.4s, v22.4s, v14.4h\n" - "saddw v23.4s, v23.4s, v14.4h\n" - "saddw v24.4s, v24.4s, v14.4h\n" - "saddw v25.4s, v25.4s, v14.4h\n" - "saddw v26.4s, v26.4s, v14.4h\n" - "saddw v27.4s, v27.4s, v14.4h\n" - "saddw v28.4s, v28.4s, v14.4h\n" - "saddw v29.4s, v29.4s, v14.4h\n" - "saddw v30.4s, v30.4s, v14.4h\n" - "saddw v31.4s, v31.4s, v14.4h\n" - - // Cast-and-saturate from int32 to int16 - "sqxtn v16.4h, v16.4s\n" - "sqxtn2 v16.8h, v17.4s\n" - "sqxtn v17.4h, v18.4s\n" - "sqxtn2 v17.8h, v19.4s\n" - "sqxtn v18.4h, v20.4s\n" - "sqxtn2 v18.8h, v21.4s\n" - "sqxtn v19.4h, v22.4s\n" - "sqxtn2 v19.8h, v23.4s\n" - "sqxtn v20.4h, v24.4s\n" - "sqxtn2 v20.8h, v25.4s\n" - "sqxtn v21.4h, v26.4s\n" - "sqxtn2 v21.8h, v27.4s\n" - "sqxtn v22.4h, v28.4s\n" - "sqxtn2 v22.8h, v29.4s\n" - "sqxtn v23.4h, v30.4s\n" - "sqxtn2 v23.8h, v31.4s\n" - - // At this point, v24 -- v31 aren't used anymore for the current block, - // so we can start clearing these accumulators for the next block - // (next iteration of the main loop). 
- RUY_MAKE_ZERO(v24) - RUY_MAKE_ZERO(v25) - RUY_MAKE_ZERO(v26) - RUY_MAKE_ZERO(v27) - RUY_MAKE_ZERO(v28) - RUY_MAKE_ZERO(v29) - RUY_MAKE_ZERO(v30) - RUY_MAKE_ZERO(v31) - - // Load the clamp_min, clamp_max bounds - "ldrsh w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" - "ldrsh w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" - "dup v14.8h, w2\n" // clamp_min - "dup v15.8h, w3\n" // clamp_max - - // Apply the clamp_min bound - "smax v16.8h, v16.8h, v14.8h\n" - "smax v17.8h, v17.8h, v14.8h\n" - "smax v18.8h, v18.8h, v14.8h\n" - "smax v19.8h, v19.8h, v14.8h\n" - "smax v20.8h, v20.8h, v14.8h\n" - "smax v21.8h, v21.8h, v14.8h\n" - "smax v22.8h, v22.8h, v14.8h\n" - "smax v23.8h, v23.8h, v14.8h\n" - // Apply the clamp_max bound - "smin v16.8h, v16.8h, v15.8h\n" - "smin v17.8h, v17.8h, v15.8h\n" - "smin v18.8h, v18.8h, v15.8h\n" - "smin v19.8h, v19.8h, v15.8h\n" - "smin v20.8h, v20.8h, v15.8h\n" - "smin v21.8h, v21.8h, v15.8h\n" - "smin v22.8h, v22.8h, v15.8h\n" - "smin v23.8h, v23.8h, v15.8h\n" - - // Compute how much of the 8x8 block of destination 16bit values that - // we have computed, fit in the destination matrix. Typically, all of - // it fits, but when the destination matrix shape is not a multiple - // of 8x8, there are some 8x8 blocks along the boundaries that do - // not fit entirely. - "sub w1, %w[dst_rows], %w[row]\n" - "sub w2, %w[dst_cols], %w[col]\n" - "mov w3, #8\n" - "cmp w1, #8\n" - // Compute w1 = how many rows of the 8x8 block fit - "csel w1, w1, w3, le\n" - "cmp w2, #8\n" - // Compute w1 = how many rows of the 8x8 block fit - "csel w2, w2, w3, le\n" - - // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits. - "cmp w1, w3\n" - "ccmp w2, w3, 0, eq\n" - // Yes, all of the 8x8 block fits, go to fast path. - "beq 230f\n" - // Not all of the 8x8 block fits. - // Set (x3 address, x4 stride) to write to dst_tmp_buf - "mov x3, %[dst_tmp_buf]\n" - "mov x4, #16\n" - "b 231f\n" - "230:\n" - // Yes, all of the 8x8 block fits. - // Set (x3 address, x4 stride) to write directly to destination matrix. - "mov x3, %[dst_ptr]\n" - "mov x4, x11\n" - "231:\n" - - // Write our 8bit values to the destination described by - // (x3 address, x4 stride). - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v16.8h}, [x3], x4\n" - RUY_MAKE_ZERO(v16) - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v17.8h}, [x3], x4\n" - RUY_MAKE_ZERO(v17) - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v18.8h}, [x3], x4\n" - RUY_MAKE_ZERO(v18) - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v19.8h}, [x3], x4\n" - RUY_MAKE_ZERO(v19) - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v20.8h}, [x3], x4\n" - RUY_MAKE_ZERO(v20) - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v21.8h}, [x3], x4\n" - RUY_MAKE_ZERO(v21) - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v22.8h}, [x3], x4\n" - RUY_MAKE_ZERO(v22) - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v23.8h}, [x3], x4\n" - RUY_MAKE_ZERO(v23) - - // For the next block: perform the first few multiply-adds on the data - // that we have already loaded. - ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" - ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n" - ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n" - ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n" - - // If all of the 8x8 block fits, we just finished writing it to the - // destination, so we skip the next part. - "beq 241f\n" - // Not all of the 8x8 block fits in the destination matrix. We just - // wrote it to dst_tmp_buf. 
Now we perform the slow scalar loop over - // it to copy into the destination matrix the part that fits. - "mov x3, %[dst_tmp_buf]\n" - "mov x4, %[dst_ptr]\n" - "mov w6, #0\n" - "250:\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov w5, #0\n" - "251:\n" - "ldrsh w7, [x3, x5, lsl #1]\n" - "strh w7, [x4, x5, lsl #1]\n" - "add w5, w5, #1\n" - "cmp w5, w1\n" - "blt 251b\n" - "add w6, w6, #1\n" - "add x3, x3, #16\n" - "add x4, x4, x11\n" - "cmp w6, w2\n" - "blt 250b\n" - "241:\n" - "add %[dst_ptr], %[dst_ptr], #16\n" - // At this point we have completely finished writing values to the - // destination matrix for the current block. - - "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" - - RUY_STR(RUY_ASM_LABEL_STORE_INT32) ":\n" - - "ld1 {v0.8b}, [%[lhs_ptr]], #8\n" - "ldr x1, [%[lhs_ptr]], #8\n" - "ld1 {v1.8b}, [%[lhs_ptr]], #8\n" - "ldr x2, [%[lhs_ptr]], #8\n" - "ld1 {v2.8b}, [%[rhs_ptr]], #8\n" - "ldr x5, [%[rhs_ptr]], #8\n" - "ld1 {v3.8b}, [%[rhs_ptr]], #8\n" - "ldr x6, [%[rhs_ptr]], #8\n" - "ins v0.d[1], x1\n" - "ins v1.d[1], x2\n" - "ins v2.d[1], x5\n" - "ins v3.d[1], x6\n" - - // Since the store type is the same as the accum type, no need for - // downcast. There's also no need for clamp by min/max. - - // Compute how much of the 8x8 block of destination 32it values that - // we have computed, fit in the destination matrix. Typically, all of - // it fits, but when the destination matrix shape is not a multiple - // of 8x8, there are some 8x8 blocks along the boundaries that do - // not fit entirely. - "sub w1, %w[dst_rows], %w[row]\n" - "sub w2, %w[dst_cols], %w[col]\n" - "mov w3, #8\n" - "cmp w1, #8\n" - // Compute w1 = how many rows of the 8x8 block fit - "csel w1, w1, w3, le\n" - "cmp w2, #8\n" - // Compute w1 = how many rows of the 8x8 block fit - "csel w2, w2, w3, le\n" - - // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits. - "cmp w1, w3\n" - "ccmp w2, w3, 0, eq\n" - // Yes, all of the 8x8 block fits, go to fast path. - "beq 330f\n" - // Not all of the 8x8 block fits. - // Write to dst_tmp_buf - "mov x3, %[dst_tmp_buf]\n" - "st1 {v16.4s}, [x3], #16\n" - RUY_MAKE_ZERO(v16) - "st1 {v17.4s}, [x3], #16\n" - RUY_MAKE_ZERO(v17) - "st1 {v18.4s}, [x3], #16\n" - RUY_MAKE_ZERO(v18) - "st1 {v19.4s}, [x3], #16\n" - RUY_MAKE_ZERO(v19) - "st1 {v20.4s}, [x3], #16\n" - RUY_MAKE_ZERO(v20) - "st1 {v21.4s}, [x3], #16\n" - RUY_MAKE_ZERO(v21) - "st1 {v22.4s}, [x3], #16\n" - RUY_MAKE_ZERO(v22) - "st1 {v23.4s}, [x3], #16\n" - RUY_MAKE_ZERO(v23) - "st1 {v24.4s}, [x3], #16\n" - RUY_MAKE_ZERO(v24) - "st1 {v25.4s}, [x3], #16\n" - RUY_MAKE_ZERO(v25) - "st1 {v26.4s}, [x3], #16\n" - RUY_MAKE_ZERO(v26) - "st1 {v27.4s}, [x3], #16\n" - RUY_MAKE_ZERO(v27) - "st1 {v28.4s}, [x3], #16\n" - RUY_MAKE_ZERO(v28) - "st1 {v29.4s}, [x3], #16\n" - RUY_MAKE_ZERO(v29) - "st1 {v30.4s}, [x3], #16\n" - RUY_MAKE_ZERO(v30) - "st1 {v31.4s}, [x3], #16\n" - RUY_MAKE_ZERO(v31) - - "b 331f\n" - - "330:\n" - // Yes, all of the 8x8 block fits. 
- "mov x4, %[dst_ptr]\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "st1 {v16.4s, v17.4s}, [x4], x11\n" - RUY_MAKE_ZERO(v16) - RUY_MAKE_ZERO(v17) - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "st1 {v18.4s, v19.4s}, [x4], x11\n" - RUY_MAKE_ZERO(v18) - RUY_MAKE_ZERO(v19) - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "st1 {v20.4s, v21.4s}, [x4], x11\n" - RUY_MAKE_ZERO(v20) - RUY_MAKE_ZERO(v21) - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "st1 {v22.4s, v23.4s}, [x4], x11\n" - RUY_MAKE_ZERO(v22) - RUY_MAKE_ZERO(v23) - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "st1 {v24.4s, v25.4s}, [x4], x11\n" - RUY_MAKE_ZERO(v24) - RUY_MAKE_ZERO(v25) - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "st1 {v26.4s, v27.4s}, [x4], x11\n" - RUY_MAKE_ZERO(v26) - RUY_MAKE_ZERO(v27) - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "st1 {v28.4s, v29.4s}, [x4], x11\n" - RUY_MAKE_ZERO(v28) - RUY_MAKE_ZERO(v29) - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "st1 {v30.4s, v31.4s}, [x4], x11\n" - RUY_MAKE_ZERO(v30) - RUY_MAKE_ZERO(v31) - - "331:\n" - - // For the next block: perform the first few multiply-adds on the data - // that we have already loaded. - ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" - ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n" - ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n" - ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n" - - // If all of the 8x8 block fits, we just finished writing it to the - // destination, so we skip the next part. - "beq 341f\n" - - // Not all of the 8x8 block fits in the destination matrix. We just - // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over - // it to copy into the destination matrix the part that fits. - "mov x3, %[dst_tmp_buf]\n" - "mov x4, %[dst_ptr]\n" - "mov w6, #0\n" - "350:\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov w5, #0\n" - "351:\n" - "ldr w7, [x3, x5, lsl #2]\n" - "str w7, [x4, x5, lsl #2]\n" - "add w5, w5, #1\n" - "cmp w5, w1\n" - "blt 351b\n" - "add w6, w6, #1\n" - "add x3, x3, #32\n" - "add x4, x4, x11\n" - "cmp w6, w2\n" - "blt 350b\n" - "341:\n" - "add %[dst_ptr], %[dst_ptr], #32\n" - // At this point we have completely finished writing values to the - // destination matrix for the current block. - - RUY_STR(RUY_ASM_LABEL_AFTER_STORE) ":\n" - - // Reload some params --- we had used x5 -- x7 for a few other things - // since the last time we had loaded them. - "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" - "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" - - // Move to the next block of the destination matrix, for the next iter - // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already - // been updated earlier. - // Have we reached the end row? - "cmp %w[row], w7\n" - "beq 20f\n" // yes, end row. - // Not end row. Move to the next row. - "add %w[row], %w[row], #8\n" - "b 21f\n" - "20:\n" - // Was already at end row. - "mov %w[row], w6\n" // Move back to first row. - "add %w[col], %w[col], #8\n" // Move to the next column. - "add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #3\n" - "mov %[dst_ptr], %[dst_col_ptr]\n" - "21:\n" - - // Main loop exit condition: have we hit the end column? 
- "cmp %w[col], w8\n" - "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" - "ble 1b\n" - - // clang-format on - - : [ lhs_col_ptr ] "+r"(lhs_col_ptr), [rhs_col_ptr] "+r"(rhs_col_ptr), - [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), - [dst_col_ptr] "+r"(dst_col_ptr), [dst_ptr] "+r"(dst_ptr), [row] "+r"(row), [col] "+r"(col) - : [ params ] "r"(¶ms), [dst_rows] "r"(params.dst_rows), - [dst_cols] "r"(params.dst_cols), [dst_tmp_buf] "r"(params.dst_tmp_buf), - [dst_type_id] "r"(params.dst_type_id) - : "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc", - "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", - "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", - "v26", "v27", "v28", "v29", "v30", "v31"); -} -#undef RUY_OFFSET_BIAS -#undef RUY_OFFSET_LHS_SUMS -#undef RUY_OFFSET_RHS_SUMS -#undef RUY_OFFSET_LHS_BASE_PTR -#undef RUY_OFFSET_MULTIPLIER_FIXEDPOINT -#undef RUY_OFFSET_MULTIPLIER_EXPONENT -#undef RUY_OFFSET_RHS_BASE_PTR -#undef RUY_OFFSET_DST_BASE_PTR -#undef RUY_OFFSET_LHS_ZERO_POINT -#undef RUY_OFFSET_RHS_ZERO_POINT -#undef RUY_OFFSET_DST_ZERO_POINT -#undef RUY_OFFSET_PROD_ZP_DEPTH -#undef RUY_OFFSET_START_ROW -#undef RUY_OFFSET_START_COL -#undef RUY_OFFSET_LAST_ROW -#undef RUY_OFFSET_LAST_COL -#undef RUY_OFFSET_DST_ROWS -#undef RUY_OFFSET_DST_COLS -#undef RUY_OFFSET_LHS_STRIDE -#undef RUY_OFFSET_RHS_STRIDE -#undef RUY_OFFSET_DST_STRIDE -#undef RUY_OFFSET_DEPTH -#undef RUY_OFFSET_CLAMP_MIN -#undef RUY_OFFSET_CLAMP_MAX -#undef RUY_OFFSET_FLAGS - -#define RUY_OFFSET_LHS_BASE_PTR 0 -#define RUY_OFFSET_RHS_BASE_PTR 8 -#define RUY_OFFSET_DST_BASE_PTR 16 -#define RUY_OFFSET_BIAS 24 -#define RUY_OFFSET_START_ROW 32 -#define RUY_OFFSET_START_COL 36 -#define RUY_OFFSET_LAST_ROW 40 -#define RUY_OFFSET_LAST_COL 44 -#define RUY_OFFSET_LHS_STRIDE 56 -#define RUY_OFFSET_RHS_STRIDE 60 -#define RUY_OFFSET_DST_STRIDE 64 -#define RUY_OFFSET_DEPTH 68 -#define RUY_OFFSET_CLAMP_MIN 72 -#define RUY_OFFSET_CLAMP_MAX 76 -#define RUY_OFFSET_FLAGS 80 - -template -void CheckOffsetsInKernelParamsFloat(const Params&) { - static_assert(offsetof(Params, lhs_base_ptr) == RUY_OFFSET_LHS_BASE_PTR, ""); - static_assert(offsetof(Params, rhs_base_ptr) == RUY_OFFSET_RHS_BASE_PTR, ""); - static_assert(offsetof(Params, dst_base_ptr) == RUY_OFFSET_DST_BASE_PTR, ""); - static_assert(offsetof(Params, bias) == RUY_OFFSET_BIAS, ""); - static_assert(offsetof(Params, start_row) == RUY_OFFSET_START_ROW, ""); - static_assert(offsetof(Params, start_col) == RUY_OFFSET_START_COL, ""); - static_assert(offsetof(Params, last_row) == RUY_OFFSET_LAST_ROW, ""); - static_assert(offsetof(Params, last_col) == RUY_OFFSET_LAST_COL, ""); - static_assert(offsetof(Params, lhs_stride) == RUY_OFFSET_LHS_STRIDE, ""); - static_assert(offsetof(Params, rhs_stride) == RUY_OFFSET_RHS_STRIDE, ""); - static_assert(offsetof(Params, dst_stride) == RUY_OFFSET_DST_STRIDE, ""); - static_assert(offsetof(Params, depth) == RUY_OFFSET_DEPTH, ""); - static_assert(offsetof(Params, clamp_min) == RUY_OFFSET_CLAMP_MIN, ""); - static_assert(offsetof(Params, clamp_max) == RUY_OFFSET_CLAMP_MAX, ""); - static_assert(offsetof(Params, flags) == RUY_OFFSET_FLAGS, ""); -} - -// Just a plain float kernel; good enough for out-of-order cores. 
-// The closest to it in the gemmlowp collection would be -// NEON_64bit_GEMM_Float32_WithScalar, -// https://github.com/google/gemmlowp/blob/36212ad3651871bc3e9a599f1a6d5324778aea25/standalone/neon-gemm-kernel-benchmark.cc#L3925 -// -// Besides ruy-ification, the main nuance here is that we stick to a 8x8 -// width instead of the wider 12x8 that the register space permits and that -// the aforementioned gemmlowp kernel uses. Ruy likes powers of two for now -// and we don't have evidence that going beyond 8x8 is needed. -void KernelFloatNeonOutOfOrder(const KernelParamsFloat<8, 8>& params) { - CheckOffsetsInKernelParamsFloat(params); - profiler::ScopeLabel label( - "Kernel (kNeon, optimized for out-of-order cores)"); - - const float* lhs_col_ptr = params.lhs_base_ptr; - const float* rhs_col_ptr = params.rhs_base_ptr; - const float* lhs_ptr = lhs_col_ptr; - const float* rhs_ptr = rhs_col_ptr; - float* dst_col_ptr = params.dst_base_ptr; - float* dst_ptr = dst_col_ptr; - int row = params.start_row; - int col = params.start_col; - - // The asm kernel below has the following NEON register allocation: - // - // v16 -- v31 are accumulators. - // During accumulation, v0 -- v15 are used to load data from LHS and RHS. - // At least v0 and v1 are used to load a 8x1 block of LHS, and v2 and - // v3 are used to load a 1x8 block of RHS, like this: - // - // RHS 1x8 block - // /-----------------------------------------\ - // |v2.s[0] ... v2.s[3] v3.s[0] ... v3.s[3]| - // \-----------------------------------------/ - // LHS 8x1 block - // /---------------------\ /-----------------------------------------\ - // | v0.s[0] | |v16.s[0] ... v30.s[0]| - // | ... | | ... ... | - // | v0.s[3] | |v16.s[3] ... v30.s[3]| - // | v1.s[0] | |v17.s[0] ... v31.s[0]| - // | ... | | ... ... | - // | v1.s[3] | |v17.s[3] ... v31.s[3]| - // \---------------------/ \-----------------------------------------/ - // accumulators 8x8 block - // - // In the RUY_OPT_MAX_STREAMING part of the kernel, this elementary step - // is repeated 4 times, using 4x more registers for LHS and RHS, so that - // is where instead of using v0 -- v3 for LHS and RHS, we use v0 -- v15. - // - // Outside of the RUY_OPT_MAX_STREAMING part of the kernel, v4 -- v7 are - // unused, and v8 -- v15 are used for floading parameters used for the - // post-accumulation part of the kernel. - asm volatile( -#define RUY_MAKE_ZERO(reg) "dup " #reg ".4s, wzr\n" - - // clang-format off - - // Load some parameters into registers. - "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" - "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" - "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" - "ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n" - "ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n" - "ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n" - "ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" - "ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n" - - // Load the first 32 bytes of LHS and RHS data. - "ld1 {v0.4s}, [%[lhs_ptr]], #16\n" - "ld1 {v1.4s}, [%[lhs_ptr]], #16\n" - "ld1 {v2.4s}, [%[rhs_ptr]], #16\n" - "ld1 {v3.4s}, [%[rhs_ptr]], #16\n" - - // Clear accumulators. 
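Before the accumulators are cleared and the main loop starts below, it may help to state what each level of depth contributes, per the register layout above: a rank-1 update of the 8x8 block, in which every destination column receives the 8x1 LHS slice scaled by one RHS lane. A scalar sketch (names ours):

// Sketch of the elementary accumulation step. lhs is the 8x1 LHS slice
// (v0, v1), rhs the 1x8 RHS slice (v2, v3), acc the 8x8 accumulator block
// (v16 -- v31, two registers per destination column).
void AccumulateOneDepthLevel(const float lhs[8], const float rhs[8],
                             float acc[8][8] /* [col][row] */) {
  for (int c = 0; c < 8; ++c) {
    for (int r = 0; r < 8; ++r) {
      acc[c][r] += lhs[r] * rhs[c];
    }
  }
}

Each fmla in the asm performs four of these multiply-adds at once: one accumulator register (half of a destination column) times one RHS scalar lane.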
- RUY_MAKE_ZERO(v16) - RUY_MAKE_ZERO(v17) - RUY_MAKE_ZERO(v18) - RUY_MAKE_ZERO(v19) - RUY_MAKE_ZERO(v20) - RUY_MAKE_ZERO(v21) - RUY_MAKE_ZERO(v22) - RUY_MAKE_ZERO(v23) - RUY_MAKE_ZERO(v24) - RUY_MAKE_ZERO(v25) - RUY_MAKE_ZERO(v26) - RUY_MAKE_ZERO(v27) - RUY_MAKE_ZERO(v28) - RUY_MAKE_ZERO(v29) - RUY_MAKE_ZERO(v30) - RUY_MAKE_ZERO(v31) - - // w1 is the number of levels of depth that we have already loaded - // LHS and RHS data for. Corresponding to the initial ld1 instructions - // above, this is currently 1. - "mov w1, #1\n" - - // Main loop of the whole GEMM, over rows and columns of the - // destination matrix. - "1:\n" - - "fmla v16.4s, v0.4s, v2.s[0]\n" - "fmla v18.4s, v0.4s, v2.s[1]\n" - "fmla v20.4s, v0.4s, v2.s[2]\n" - "fmla v22.4s, v0.4s, v2.s[3]\n" - -#if RUY_OPT_ENABLED(RUY_OPT_MAX_STREAMING) - "cmp w12, #8\n" - "blt 78f\n" - "and w2, w12, #-4\n" - - "ld1 {v4.4s}, [%[lhs_ptr]], #16\n" - "ld1 {v5.4s}, [%[lhs_ptr]], #16\n" - "ld1 {v6.4s}, [%[rhs_ptr]], #16\n" - "ld1 {v7.4s}, [%[rhs_ptr]], #16\n" - - "ld1 {v8.4s}, [%[lhs_ptr]], #16\n" - "ld1 {v9.4s}, [%[lhs_ptr]], #16\n" - "ld1 {v10.4s}, [%[rhs_ptr]], #16\n" - "ld1 {v11.4s}, [%[rhs_ptr]], #16\n" - - "ld1 {v12.4s}, [%[lhs_ptr]], #16\n" - "ld1 {v13.4s}, [%[lhs_ptr]], #16\n" - "ld1 {v14.4s}, [%[rhs_ptr]], #16\n" - "ld1 {v15.4s}, [%[rhs_ptr]], #16\n" - "mov w1, #4\n" - - "80:\n" - - "add %[lhs_ptr], %[lhs_ptr], #128\n" - "add %[rhs_ptr], %[rhs_ptr], #128\n" - - "fmla v24.4s, v0.4s, v3.s[0]\n" - "fmla v26.4s, v0.4s, v3.s[1]\n" - "fmla v28.4s, v0.4s, v3.s[2]\n" - "fmla v30.4s, v0.4s, v3.s[3]\n" - "ldr q0, [%[lhs_ptr], #-128]\n" - "fmla v25.4s, v1.4s, v3.s[0]\n" - "fmla v27.4s, v1.4s, v3.s[1]\n" - "fmla v29.4s, v1.4s, v3.s[2]\n" - "fmla v31.4s, v1.4s, v3.s[3]\n" - "ldr q3, [%[rhs_ptr], #-112]\n" - "fmla v17.4s, v1.4s, v2.s[0]\n" - "fmla v19.4s, v1.4s, v2.s[1]\n" - "fmla v21.4s, v1.4s, v2.s[2]\n" - "fmla v23.4s, v1.4s, v2.s[3]\n" - "ldr q1, [%[lhs_ptr], #-112]\n" - "fmla v16.4s, v4.4s, v6.s[0]\n" - "fmla v18.4s, v4.4s, v6.s[1]\n" - "ldr q2, [%[rhs_ptr], #-128]\n" - "fmla v20.4s, v4.4s, v6.s[2]\n" - "fmla v22.4s, v4.4s, v6.s[3]\n" - - "fmla v24.4s, v4.4s, v7.s[0]\n" - "fmla v26.4s, v4.4s, v7.s[1]\n" - "fmla v28.4s, v4.4s, v7.s[2]\n" - "fmla v30.4s, v4.4s, v7.s[3]\n" - "ldr q4, [%[lhs_ptr], #-96]\n" - "fmla v25.4s, v5.4s, v7.s[0]\n" - "fmla v27.4s, v5.4s, v7.s[1]\n" - "fmla v29.4s, v5.4s, v7.s[2]\n" - "fmla v31.4s, v5.4s, v7.s[3]\n" - "ldr q7, [%[rhs_ptr], #-80]\n" - "fmla v17.4s, v5.4s, v6.s[0]\n" - "fmla v19.4s, v5.4s, v6.s[1]\n" - "fmla v21.4s, v5.4s, v6.s[2]\n" - "fmla v23.4s, v5.4s, v6.s[3]\n" - "ldr q5, [%[lhs_ptr], #-80]\n" - "fmla v16.4s, v8.4s, v10.s[0]\n" - "fmla v18.4s, v8.4s, v10.s[1]\n" - "ldr q6, [%[rhs_ptr], #-96]\n" - "fmla v20.4s, v8.4s, v10.s[2]\n" - "fmla v22.4s, v8.4s, v10.s[3]\n" - - "fmla v24.4s, v8.4s, v11.s[0]\n" - "fmla v26.4s, v8.4s, v11.s[1]\n" - "fmla v28.4s, v8.4s, v11.s[2]\n" - "fmla v30.4s, v8.4s, v11.s[3]\n" - "ldr q8, [%[lhs_ptr], #-64]\n" - "fmla v25.4s, v9.4s, v11.s[0]\n" - "fmla v27.4s, v9.4s, v11.s[1]\n" - "fmla v29.4s, v9.4s, v11.s[2]\n" - "fmla v31.4s, v9.4s, v11.s[3]\n" - "ldr q11, [%[rhs_ptr], #-48]\n" - "fmla v17.4s, v9.4s, v10.s[0]\n" - "fmla v19.4s, v9.4s, v10.s[1]\n" - "fmla v21.4s, v9.4s, v10.s[2]\n" - "fmla v23.4s, v9.4s, v10.s[3]\n" - "ldr q9, [%[lhs_ptr], #-48]\n" - "fmla v16.4s, v12.4s, v14.s[0]\n" - "fmla v18.4s, v12.4s, v14.s[1]\n" - "ldr q10, [%[rhs_ptr], #-64]\n" - "fmla v20.4s, v12.4s, v14.s[2]\n" - "fmla v22.4s, v12.4s, v14.s[3]\n" - - "fmla v24.4s, v12.4s, v15.s[0]\n" - "fmla v26.4s, 
v12.4s, v15.s[1]\n" - "fmla v28.4s, v12.4s, v15.s[2]\n" - "fmla v30.4s, v12.4s, v15.s[3]\n" - "ldr q12, [%[lhs_ptr], #-32]\n" - "fmla v25.4s, v13.4s, v15.s[0]\n" - "fmla v27.4s, v13.4s, v15.s[1]\n" - "fmla v29.4s, v13.4s, v15.s[2]\n" - "fmla v31.4s, v13.4s, v15.s[3]\n" - "ldr q15, [%[rhs_ptr], #-16]\n" - "fmla v17.4s, v13.4s, v14.s[0]\n" - "fmla v19.4s, v13.4s, v14.s[1]\n" - "fmla v21.4s, v13.4s, v14.s[2]\n" - "fmla v23.4s, v13.4s, v14.s[3]\n" - "ldr q13, [%[lhs_ptr], #-16]\n" - "fmla v16.4s, v0.4s, v2.s[0]\n" - "fmla v18.4s, v0.4s, v2.s[1]\n" - "ldr q14, [%[rhs_ptr], #-32]\n" - "fmla v20.4s, v0.4s, v2.s[2]\n" - "fmla v22.4s, v0.4s, v2.s[3]\n" - - "add w1, w1, #4\n" - "cmp w1, w2\n" - "blt 80b\n" - - "fmla v16.4s, v4.4s, v6.s[0]\n" - "fmla v18.4s, v4.4s, v6.s[1]\n" - "fmla v20.4s, v4.4s, v6.s[2]\n" - "fmla v22.4s, v4.4s, v6.s[3]\n" - "fmla v24.4s, v4.4s, v7.s[0]\n" - "fmla v26.4s, v4.4s, v7.s[1]\n" - "fmla v28.4s, v4.4s, v7.s[2]\n" - "fmla v30.4s, v4.4s, v7.s[3]\n" - "fmla v25.4s, v5.4s, v7.s[0]\n" - "fmla v27.4s, v5.4s, v7.s[1]\n" - "fmla v29.4s, v5.4s, v7.s[2]\n" - "fmla v31.4s, v5.4s, v7.s[3]\n" - "fmla v17.4s, v5.4s, v6.s[0]\n" - "fmla v19.4s, v5.4s, v6.s[1]\n" - "fmla v21.4s, v5.4s, v6.s[2]\n" - "fmla v23.4s, v5.4s, v6.s[3]\n" - - "fmla v16.4s, v8.4s, v10.s[0]\n" - "fmla v18.4s, v8.4s, v10.s[1]\n" - "fmla v20.4s, v8.4s, v10.s[2]\n" - "fmla v22.4s, v8.4s, v10.s[3]\n" - "fmla v24.4s, v8.4s, v11.s[0]\n" - "fmla v26.4s, v8.4s, v11.s[1]\n" - "fmla v28.4s, v8.4s, v11.s[2]\n" - "fmla v30.4s, v8.4s, v11.s[3]\n" - "fmla v25.4s, v9.4s, v11.s[0]\n" - "fmla v27.4s, v9.4s, v11.s[1]\n" - "fmla v29.4s, v9.4s, v11.s[2]\n" - "fmla v31.4s, v9.4s, v11.s[3]\n" - "fmla v17.4s, v9.4s, v10.s[0]\n" - "fmla v19.4s, v9.4s, v10.s[1]\n" - "fmla v21.4s, v9.4s, v10.s[2]\n" - "fmla v23.4s, v9.4s, v10.s[3]\n" - - "fmla v16.4s, v12.4s, v14.s[0]\n" - "fmla v18.4s, v12.4s, v14.s[1]\n" - "fmla v20.4s, v12.4s, v14.s[2]\n" - "fmla v22.4s, v12.4s, v14.s[3]\n" - "fmla v24.4s, v12.4s, v15.s[0]\n" - "fmla v26.4s, v12.4s, v15.s[1]\n" - "fmla v28.4s, v12.4s, v15.s[2]\n" - "fmla v30.4s, v12.4s, v15.s[3]\n" - "fmla v25.4s, v13.4s, v15.s[0]\n" - "fmla v27.4s, v13.4s, v15.s[1]\n" - "fmla v29.4s, v13.4s, v15.s[2]\n" - "fmla v31.4s, v13.4s, v15.s[3]\n" - "fmla v17.4s, v13.4s, v14.s[0]\n" - "fmla v19.4s, v13.4s, v14.s[1]\n" - "fmla v21.4s, v13.4s, v14.s[2]\n" - "fmla v23.4s, v13.4s, v14.s[3]\n" - - "78:\n" -#endif - - // Accumulation loop - "cmp w1, w12\n" - "beq 79f\n" - - "2:\n" - "fmla v24.4s, v0.4s, v3.s[0]\n" - "fmla v26.4s, v0.4s, v3.s[1]\n" - "ld1 {v4.4s}, [%[rhs_ptr]], #16\n" - "fmla v28.4s, v0.4s, v3.s[2]\n" - "fmla v30.4s, v0.4s, v3.s[3]\n" - "ld1 {v0.4s}, [%[lhs_ptr]], #16\n" - "fmla v25.4s, v1.4s, v3.s[0]\n" - "fmla v27.4s, v1.4s, v3.s[1]\n" - "add w1, w1, #1\n" - "fmla v29.4s, v1.4s, v3.s[2]\n" - "fmla v31.4s, v1.4s, v3.s[3]\n" - "ld1 {v3.4s}, [%[rhs_ptr]], #16\n" - "fmla v17.4s, v1.4s, v2.s[0]\n" - "fmla v19.4s, v1.4s, v2.s[1]\n" - "cmp w1, w12\n" - "fmla v21.4s, v1.4s, v2.s[2]\n" - "fmla v23.4s, v1.4s, v2.s[3]\n" - "ld1 {v1.4s}, [%[lhs_ptr]], #16\n" - "fmla v16.4s, v0.4s, v4.s[0]\n" - "fmla v18.4s, v0.4s, v4.s[1]\n" - "mov v2.16b, v4.16b\n" - "fmla v20.4s, v0.4s, v4.s[2]\n" - "fmla v22.4s, v0.4s, v4.s[3]\n" - "blt 2b\n" - - "79:\n" - - // End of the inner loop on depth. Now perform the remaining - // multiply-adds of the last level of depth, for which the LHS - // and RHS data is already loaded. 
- - "fmla v24.4s, v0.4s, v3.s[0]\n" - "fmla v26.4s, v0.4s, v3.s[1]\n" - "fmla v28.4s, v0.4s, v3.s[2]\n" - "fmla v30.4s, v0.4s, v3.s[3]\n" - "fmla v25.4s, v1.4s, v3.s[0]\n" - "fmla v27.4s, v1.4s, v3.s[1]\n" - "fmla v29.4s, v1.4s, v3.s[2]\n" - "fmla v31.4s, v1.4s, v3.s[3]\n" - "fmla v17.4s, v1.4s, v2.s[0]\n" - "fmla v19.4s, v1.4s, v2.s[1]\n" - "fmla v21.4s, v1.4s, v2.s[2]\n" - "fmla v23.4s, v1.4s, v2.s[3]\n" - - // End of accumulation. The registers v16 -- v31 contain the final - // int32 accumulator values of the current 8x8 destination block. - // We now have to compute the final 8-bit values from these int32 - // accumulators, and advance to the next 8x8 block. We intertwine - // these two aspects whenever possible for optimal pipelining, both - // at the data flow level (prefetch data for next block as early as - // possible) and instruction pipelining level (some of the next-block - // work can dual-issue with some of the final work on the current - // block). - - // Logic to advance to the next block in preparation for the next - // iteration of the main loop. For now, we only want to compute - // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are - // not yet ready to update the values of row and col, as we still need - // the current values for the rest of the work on the current block. - - "cmp %w[row], w7\n" // Have we finished the last row? - "bge 4f\n" // If finished last row, go to 4 - // Not finished last row: then advance to next row. - "add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #3\n" - "b 5f\n" - "4:\n" // Finished last row... - "mov %[lhs_col_ptr], x5\n" // Go back to first row - // Now we need to advance to the next column. If we already - // finished the last column, then in principle we are done, however - // we can't just return here, as we need to allow the end work of the - // current block to complete. The good news is that at this point it - // doesn't matter what data we load for the next column, since - // we will exit from the main loop below before actually storing - // anything computed from that data. - "cmp %w[col], w8\n" // Have we finished the last column? - "bge 5f\n" // If yes, just carry on without updating the column pointer. - // Not finished last column: then advance to next column. - "add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #3\n" - "5:\n" - - // Set the LHS and RHS data pointers to the start of the columns just - // computed. - "mov %[lhs_ptr], %[lhs_col_ptr]\n" - "mov %[rhs_ptr], %[rhs_col_ptr]\n" - - // Load some parameters needed for the end work on current block. - "ldrb w4, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" - "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n" - - // Offset these base pointers as needed given the current row, col. - "add x5, x1, %x[row], lsl #2\n" - - "tst w4, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n" - "csel x1, x1, x5, eq\n" - - // Load 8 bias values. - "ld1 {v14.4s}, [x1], #16\n" - "ld1 {v15.4s}, [x1]\n" - - // Now that we know what LHS and RHS data the next iteration of the - // main loop will need to load, we start loading the first 32 bytes of - // each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore - // in the rest of the work on the current block. - "ld1 {v0.4s}, [%[lhs_ptr]], #16\n" - "ld1 {v1.4s}, [%[lhs_ptr]], #16\n" - "ld1 {v2.4s}, [%[rhs_ptr]], #16\n" - "ld1 {v3.4s}, [%[rhs_ptr]], #16\n" - - // Perform the bias-addition (per the above, we have just folded into - // the bias the (depth * lhs_zero_point * rhs_zero_point) term.) 
- "fadd v16.4s, v16.4s, v14.4s\n" - "fadd v17.4s, v17.4s, v15.4s\n" - "fadd v18.4s, v18.4s, v14.4s\n" - "fadd v19.4s, v19.4s, v15.4s\n" - "fadd v20.4s, v20.4s, v14.4s\n" - "fadd v21.4s, v21.4s, v15.4s\n" - "fadd v22.4s, v22.4s, v14.4s\n" - "fadd v23.4s, v23.4s, v15.4s\n" - "fadd v24.4s, v24.4s, v14.4s\n" - "fadd v25.4s, v25.4s, v15.4s\n" - "fadd v26.4s, v26.4s, v14.4s\n" - "fadd v27.4s, v27.4s, v15.4s\n" - "fadd v28.4s, v28.4s, v14.4s\n" - "fadd v29.4s, v29.4s, v15.4s\n" - "fadd v30.4s, v30.4s, v14.4s\n" - "fadd v31.4s, v31.4s, v15.4s\n" - - // Load the clamp_min, clamp_max bounds - "ldr w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" - "ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" - "dup v14.4s, w2\n" // clamp_min - "dup v15.4s, w3\n" // clamp_max - - // Apply the clamp_min bound - "fmax v16.4s, v16.4s, v14.4s\n" - "fmax v17.4s, v17.4s, v14.4s\n" - "fmax v18.4s, v18.4s, v14.4s\n" - "fmax v19.4s, v19.4s, v14.4s\n" - "fmax v20.4s, v20.4s, v14.4s\n" - "fmax v21.4s, v21.4s, v14.4s\n" - "fmax v22.4s, v22.4s, v14.4s\n" - "fmax v23.4s, v23.4s, v14.4s\n" - "fmax v24.4s, v24.4s, v14.4s\n" - "fmax v25.4s, v25.4s, v14.4s\n" - "fmax v26.4s, v26.4s, v14.4s\n" - "fmax v27.4s, v27.4s, v14.4s\n" - "fmax v28.4s, v28.4s, v14.4s\n" - "fmax v29.4s, v29.4s, v14.4s\n" - "fmax v30.4s, v30.4s, v14.4s\n" - "fmax v31.4s, v31.4s, v14.4s\n" - - // Apply the clamp_max bound - "fmin v16.4s, v16.4s, v15.4s\n" - "fmin v17.4s, v17.4s, v15.4s\n" - "fmin v18.4s, v18.4s, v15.4s\n" - "fmin v19.4s, v19.4s, v15.4s\n" - "fmin v20.4s, v20.4s, v15.4s\n" - "fmin v21.4s, v21.4s, v15.4s\n" - "fmin v22.4s, v22.4s, v15.4s\n" - "fmin v23.4s, v23.4s, v15.4s\n" - "fmin v24.4s, v24.4s, v15.4s\n" - "fmin v25.4s, v25.4s, v15.4s\n" - "fmin v26.4s, v26.4s, v15.4s\n" - "fmin v27.4s, v27.4s, v15.4s\n" - "fmin v28.4s, v28.4s, v15.4s\n" - "fmin v29.4s, v29.4s, v15.4s\n" - "fmin v30.4s, v30.4s, v15.4s\n" - "fmin v31.4s, v31.4s, v15.4s\n" - - // Compute how much of the 8x8 block of destination 8bit values that - // we have computed, fit in the destination matrix. Typically, all of - // it fits, but when the destination matrix shape is not a multiple - // of 8x8, there are some 8x8 blocks along the boundaries that do - // not fit entirely. - "sub w1, %w[dst_rows], %w[row]\n" - "sub w2, %w[dst_cols], %w[col]\n" - "mov w3, #8\n" - "cmp w1, #8\n" - // Compute w1 = how many rows of the 8x8 block fit - "csel w1, w1, w3, le\n" - "cmp w2, #8\n" - // Compute w2 = how many cols of the 8x8 block fit - "csel w2, w2, w3, le\n" - - // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits. - "cmp w1, w3\n" - "ccmp w2, w3, 0, eq\n" - // Yes, all of the 8x8 block fits, go to fast path. - "beq 30f\n" - // Not all of the 8x8 block fits. - // Set (x3 address, x4 stride) to write to dst_tmp_buf - "mov x3, %[dst_tmp_buf]\n" - "mov x4, #32\n" - "b 31f\n" - "30:\n" - // Yes, all of the 8x8 block fits. - // Set (x3 address, x4 stride) to write directly to destination matrix. - "mov x3, %[dst_ptr]\n" - "mov x4, x11\n" - "31:\n" - - // Write our 8bit values to the destination described by - // (x3 address, x4 stride). 
- RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "str q16, [x3, #0]\n" - "str q17, [x3, #16]\n" - "add x3, x3, x4\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - RUY_MAKE_ZERO(v16) - RUY_MAKE_ZERO(v17) - "str q18, [x3, #0]\n" - "str q19, [x3, #16]\n" - "add x3, x3, x4\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - RUY_MAKE_ZERO(v18) - RUY_MAKE_ZERO(v19) - "str q20, [x3, #0]\n" - "str q21, [x3, #16]\n" - "add x3, x3, x4\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - RUY_MAKE_ZERO(v20) - RUY_MAKE_ZERO(v21) - "str q22, [x3, #0]\n" - "str q23, [x3, #16]\n" - "add x3, x3, x4\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - RUY_MAKE_ZERO(v22) - RUY_MAKE_ZERO(v23) - "str q24, [x3, #0]\n" - "str q25, [x3, #16]\n" - "add x3, x3, x4\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - RUY_MAKE_ZERO(v24) - RUY_MAKE_ZERO(v25) - "str q26, [x3, #0]\n" - "str q27, [x3, #16]\n" - "add x3, x3, x4\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - RUY_MAKE_ZERO(v26) - RUY_MAKE_ZERO(v27) - "str q28, [x3, #0]\n" - "str q29, [x3, #16]\n" - "add x3, x3, x4\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - RUY_MAKE_ZERO(v28) - RUY_MAKE_ZERO(v29) - "str q30, [x3, #0]\n" - "str q31, [x3, #16]\n" - RUY_MAKE_ZERO(v30) - RUY_MAKE_ZERO(v31) - - // If all of the 8x8 block fits, we just finished writing it to the - // destination, so we skip the next part. - "beq 41f\n" - // Not all of the 8x8 block fits in the destination matrix. We just - // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over - // it to copy into the destination matrix the part that fits. - "mov x3, %[dst_tmp_buf]\n" - "mov x4, %[dst_ptr]\n" - "mov w6, #0\n" - "50:\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov w5, #0\n" - "51:\n" - "ldr w7, [x3, x5, lsl #2]\n" - "str w7, [x4, x5, lsl #2]\n" - "add w5, w5, #1\n" - "cmp w5, w1\n" - "blt 51b\n" - "add w6, w6, #1\n" - "add x3, x3, #32\n" - "add x4, x4, x11\n" - "cmp w6, w2\n" - "blt 50b\n" - "41:\n" - "add %[dst_ptr], %[dst_ptr], #32\n" - // At this point we have completely finished writing values to the - // destination matrix for the current block. - - // Reload some params --- we had used x5 -- x7 for a few other things - // since the last time we had loaded them. - "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" - "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" - "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" - - // Move to the next block of the destination matrix, for the next iter - // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already - // been updated earlier. - // Have we reached the end row? - "cmp %w[row], w7\n" - "beq 20f\n" // yes, end row. - // Not end row. Move to the next row. - "add %w[row], %w[row], #8\n" - "b 21f\n" - "20:\n" - // Was already at end row. - "mov %w[row], w6\n" // Move back to first row. - "add %w[col], %w[col], #8\n" // Move to the next column. - "add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #3\n" - "mov %[dst_ptr], %[dst_col_ptr]\n" - "21:\n" - - // Main loop exit condition: have we hit the end column? - "cmp %w[col], w8\n" - - // w1 is the number of levels of depth that we have already loaded - // LHS and RHS data for. Corresponding to the initial ld1 instructions - // above, this is currently 1. 
- "mov w1, #1\n" - - "ble 1b\n" - - // clang-format on - - : [ lhs_col_ptr ] "+r"(lhs_col_ptr), [rhs_col_ptr] "+r"(rhs_col_ptr), - [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), - [dst_col_ptr] "+r"(dst_col_ptr), [dst_ptr] "+r"(dst_ptr), [row] "+r"(row), [col] "+r"(col) - : [ params ] "r"(¶ms), [dst_rows] "r"(params.dst_rows), - [dst_cols] "r"(params.dst_cols), [dst_tmp_buf] "r"(params.dst_tmp_buf) - : "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc", - "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", - "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", - "v26", "v27", "v28", "v29", "v30", "v31"); -} - -// Variant of KernelFloatNeonOutOfOrder tuned for in-order CPUs that do not -// support dotprod (while dotprod by itself is not relevant to floating-point, -// this additional bit of information that we have about the target happens to -// be useful here). -// -// So a typical target CPU here would be ARM Cortex-A53 or the original -// Cortex-A55. -// -// This kernel is similar to and inspired by gemmlowp's -// NEON_64bit_GEMM_Float32_WithScalar_A53. -// which was contributed by David Mansell with very helpful -// comments. Specifically, see this comment about tuning for Cortex-A53: -// https://github.com/google/gemmlowp/blob/36212ad3651871bc3e9a599f1a6d5324778aea25/standalone/neon-gemm-kernel-benchmark.cc#L4215 -void KernelFloatNeonInOrder(const KernelParamsFloat<8, 8>& params) { - profiler::ScopeLabel label("Kernel (kNeon, optimized for in-order cores)"); - - CheckOffsetsInKernelParamsFloat(params); - - const float* lhs_col_ptr = params.lhs_base_ptr; - const float* rhs_col_ptr = params.rhs_base_ptr; - const float* lhs_ptr = lhs_col_ptr; - const float* rhs_ptr = rhs_col_ptr; - float* dst_col_ptr = params.dst_base_ptr; - float* dst_ptr = dst_col_ptr; - int row = params.start_row; - int col = params.start_col; - - // The asm kernel below has the following NEON register allocation: - // - // v16 -- v31 are accumulators. - // During accumulation, v0 -- v3 are used to load data from LHS and RHS. - // - // RHS 1x8 block - // /-----------------------------------------\ - // |v2.s[0] ... v2.s[3] v3.s[0] ... v3.s[3]| - // \-----------------------------------------/ - // LHS 8x1 block - // /---------------------\ /-----------------------------------------\ - // | v0.s[0] | |v16.s[0] ... v30.s[0]| - // | ... | | ... ... | - // | v0.s[3] | |v16.s[3] ... v30.s[3]| - // | v1.s[0] | |v17.s[0] ... v31.s[0]| - // | ... | | ... ... | - // | v1.s[3] | |v17.s[3] ... v31.s[3]| - // \---------------------/ \-----------------------------------------/ - // accumulators 8x8 block - // - // There is no RUY_OPT_MAX_STREAMING 4x-unrolled part in this kernel because - // we did not observe a benefit of such partial unrolling on in-order CPUs. - // - // v4 -- v7 are unused, and v8 -- v15 are used for floading parameters used - // for the post-accumulation part of the kernel. - asm volatile( -#define RUY_MAKE_ZERO(reg) "dup " #reg ".4s, wzr\n" - - // clang-format off - - // Load some parameters into registers. 
- "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" - "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" - "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" - "ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n" - "ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n" - "ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n" - "ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" - "ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n" - - - // Clear accumulators. - RUY_MAKE_ZERO(v16) - // Load the first 32 bytes of LHS and RHS data. - "ld1 {v0.4s}, [%[lhs_ptr]], #16\n" - RUY_MAKE_ZERO(v17) - "ld1 {v1.4s}, [%[lhs_ptr]], #16\n" - RUY_MAKE_ZERO(v18) - "ld1 {v2.4s}, [%[rhs_ptr]], #16\n" - RUY_MAKE_ZERO(v19) - "ld1 {v3.4s}, [%[rhs_ptr]], #16\n" - RUY_MAKE_ZERO(v20) - RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #64]\n") - RUY_MAKE_ZERO(v21) - RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #64]\n") - RUY_MAKE_ZERO(v22) - RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #128]\n") - RUY_MAKE_ZERO(v23) - RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #128]\n") - RUY_MAKE_ZERO(v24) - RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #192]\n") - RUY_MAKE_ZERO(v25) - RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #192]\n") - RUY_MAKE_ZERO(v26) - RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #256]\n") - RUY_MAKE_ZERO(v27) - RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #256]\n") - RUY_MAKE_ZERO(v28) - RUY_MAKE_ZERO(v29) - RUY_MAKE_ZERO(v30) - RUY_MAKE_ZERO(v31) - - // w1 is the number of levels of depth that remain to load - // LHS and RHS data for. Corresponding to the initial ld1 instructions - // above, this is currently depth - 1. - "sub w1, w12, #1\n" - - // Main loop of the whole GEMM, over rows and columns of the - // destination matrix. - "1:\n" - - "cmp w1, #0\n" - "fmla v16.4s, v0.4s, v2.s[0]\n" - "fmla v18.4s, v0.4s, v2.s[1]\n" - "fmla v20.4s, v0.4s, v2.s[2]\n" - "fmla v22.4s, v0.4s, v2.s[3]\n" - - // Accumulation loop - "beq 79f\n" - - "2:\n" - - "fmla v24.4s, v0.4s, v3.s[0]\n" - "ldr x2, [%[lhs_ptr], #8]\n" - "fmla v26.4s, v0.4s, v3.s[1]\n" - "ldr x3, [%[lhs_ptr], #24]\n" - "fmla v28.4s, v0.4s, v3.s[2]\n" - "ldr x5, [%[rhs_ptr], #24]\n" - "fmla v30.4s, v0.4s, v3.s[3]\n" - "ldr x4, [%[rhs_ptr], #8]\n" - "fmla v25.4s, v1.4s, v3.s[0]\n" - "subs w1, w1, #1\n" - "ldr d0, [%[lhs_ptr]], #32\n" - "fmla v27.4s, v1.4s, v3.s[1]\n" - "fmla v29.4s, v1.4s, v3.s[2]\n" - "fmla v31.4s, v1.4s, v3.s[3]\n" - "ins v0.d[1], x2\n" - "ldr d3, [%[rhs_ptr], #16]\n" - "fmla v17.4s, v1.4s, v2.s[0]\n" - "fmla v19.4s, v1.4s, v2.s[1]\n" - "ins v3.d[1], x5\n" - "ldr d4, [%[rhs_ptr]], #32\n" - "fmla v21.4s, v1.4s, v2.s[2]\n" - "fmla v23.4s, v1.4s, v2.s[3]\n" - "fmla v16.4s, v0.4s, v4.s[0]\n" - "ins v4.d[1], x4\n" - "ldr d1, [%[lhs_ptr], #-16]\n" - "fmla v18.4s, v0.4s, v4.s[1]\n" - "fmla v20.4s, v0.4s, v4.s[2]\n" - "ins v1.d[1], x3\n" - RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #256]\n") - "mov v2.16b, v4.16b\n" - RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #256]\n") - "fmla v22.4s, v0.4s, v4.s[3]\n" - "bne 2b\n" - - "79:\n" - - // End of the inner loop on depth. Now perform the remaining - // multiply-adds of the last level of depth, for which the LHS - // and RHS data is already loaded. 
- - "fmla v24.4s, v0.4s, v3.s[0]\n" - "fmla v26.4s, v0.4s, v3.s[1]\n" - "fmla v28.4s, v0.4s, v3.s[2]\n" - "fmla v30.4s, v0.4s, v3.s[3]\n" - "fmla v25.4s, v1.4s, v3.s[0]\n" - "fmla v27.4s, v1.4s, v3.s[1]\n" - "fmla v29.4s, v1.4s, v3.s[2]\n" - "fmla v31.4s, v1.4s, v3.s[3]\n" - "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" - "fmla v17.4s, v1.4s, v2.s[0]\n" - "fmla v19.4s, v1.4s, v2.s[1]\n" - "fmla v21.4s, v1.4s, v2.s[2]\n" - "fmla v23.4s, v1.4s, v2.s[3]\n" - - // End of accumulation. The registers v16 -- v31 contain the final - // int32 accumulator values of the current 8x8 destination block. - // We now have to compute the final 8-bit values from these int32 - // accumulators, and advance to the next 8x8 block. We intertwine - // these two aspects whenever possible for optimal pipelining, both - // at the data flow level (prefetch data for next block as early as - // possible) and instruction pipelining level (some of the next-block - // work can dual-issue with some of the final work on the current - // block). - - // Logic to advance to the next block in preparation for the next - // iteration of the main loop. For now, we only want to compute - // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are - // not yet ready to update the values of row and col, as we still need - // the current values for the rest of the work on the current block. - - "cmp %w[row], w7\n" // Have we finished the last row? - "bge 4f\n" // If finished last row, go to 4 - // Not finished last row: then advance to next row. - "add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #3\n" - "b 5f\n" - "4:\n" // Finished last row... - "mov %[lhs_col_ptr], x5\n" // Go back to first row - // Now we need to advance to the next column. If we already - // finished the last column, then in principle we are done, however - // we can't just return here, as we need to allow the end work of the - // current block to complete. The good news is that at this point it - // doesn't matter what data we load for the next column, since - // we will exit from the main loop below before actually storing - // anything computed from that data. - "cmp %w[col], w8\n" // Have we finished the last column? - "bge 5f\n" // If yes, just carry on without updating the column pointer. - // Not finished last column: then advance to next column. - "add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #3\n" - "5:\n" - - // Set the LHS and RHS data pointers to the start of the columns just - // computed. - "mov %[lhs_ptr], %[lhs_col_ptr]\n" - "mov %[rhs_ptr], %[rhs_col_ptr]\n" - - // Load some parameters needed for the end work on current block. - "ldrb w4, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" - "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n" - - // Offset these base pointers as needed given the current row, col. - "add x5, x1, %x[row], lsl #2\n" - - "tst w4, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n" - "csel x1, x1, x5, eq\n" - - // Load 8 bias values. - "ld1 {v14.4s}, [x1], #16\n" - "ld1 {v15.4s}, [x1]\n" - - // Now that we know what LHS and RHS data the next iteration of the - // main loop will need to load, we start loading the first 32 bytes of - // each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore - // in the rest of the work on the current block. 
- "ld1 {v0.4s}, [%[lhs_ptr]], #16\n" - "ld1 {v1.4s}, [%[lhs_ptr]], #16\n" - "ld1 {v2.4s}, [%[rhs_ptr]], #16\n" - "ld1 {v3.4s}, [%[rhs_ptr]], #16\n" - - // Perform the bias-addition (per the above, we have just folded into - // the bias the (depth * lhs_zero_point * rhs_zero_point) term.) - "fadd v16.4s, v16.4s, v14.4s\n" - "fadd v17.4s, v17.4s, v15.4s\n" - "fadd v18.4s, v18.4s, v14.4s\n" - "fadd v19.4s, v19.4s, v15.4s\n" - "fadd v20.4s, v20.4s, v14.4s\n" - "fadd v21.4s, v21.4s, v15.4s\n" - "fadd v22.4s, v22.4s, v14.4s\n" - "fadd v23.4s, v23.4s, v15.4s\n" - "fadd v24.4s, v24.4s, v14.4s\n" - "fadd v25.4s, v25.4s, v15.4s\n" - "fadd v26.4s, v26.4s, v14.4s\n" - "fadd v27.4s, v27.4s, v15.4s\n" - "fadd v28.4s, v28.4s, v14.4s\n" - "fadd v29.4s, v29.4s, v15.4s\n" - "fadd v30.4s, v30.4s, v14.4s\n" - "fadd v31.4s, v31.4s, v15.4s\n" - - // Load the clamp_min, clamp_max bounds - "ldr w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" - "ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" - "dup v14.4s, w2\n" // clamp_min - "dup v15.4s, w3\n" // clamp_max - - // Apply the clamp_min bound - "fmax v16.4s, v16.4s, v14.4s\n" - "fmax v17.4s, v17.4s, v14.4s\n" - "fmax v18.4s, v18.4s, v14.4s\n" - "fmax v19.4s, v19.4s, v14.4s\n" - "fmax v20.4s, v20.4s, v14.4s\n" - "fmax v21.4s, v21.4s, v14.4s\n" - "fmax v22.4s, v22.4s, v14.4s\n" - "fmax v23.4s, v23.4s, v14.4s\n" - "fmax v24.4s, v24.4s, v14.4s\n" - "fmax v25.4s, v25.4s, v14.4s\n" - "fmax v26.4s, v26.4s, v14.4s\n" - "fmax v27.4s, v27.4s, v14.4s\n" - "fmax v28.4s, v28.4s, v14.4s\n" - "fmax v29.4s, v29.4s, v14.4s\n" - "fmax v30.4s, v30.4s, v14.4s\n" - "fmax v31.4s, v31.4s, v14.4s\n" - - // Apply the clamp_max bound - "fmin v16.4s, v16.4s, v15.4s\n" - "fmin v17.4s, v17.4s, v15.4s\n" - "fmin v18.4s, v18.4s, v15.4s\n" - "fmin v19.4s, v19.4s, v15.4s\n" - "fmin v20.4s, v20.4s, v15.4s\n" - "fmin v21.4s, v21.4s, v15.4s\n" - "fmin v22.4s, v22.4s, v15.4s\n" - "fmin v23.4s, v23.4s, v15.4s\n" - "fmin v24.4s, v24.4s, v15.4s\n" - "fmin v25.4s, v25.4s, v15.4s\n" - "fmin v26.4s, v26.4s, v15.4s\n" - "fmin v27.4s, v27.4s, v15.4s\n" - "fmin v28.4s, v28.4s, v15.4s\n" - "fmin v29.4s, v29.4s, v15.4s\n" - "fmin v30.4s, v30.4s, v15.4s\n" - "fmin v31.4s, v31.4s, v15.4s\n" - - // Compute how much of the 8x8 block of destination 8bit values that - // we have computed, fit in the destination matrix. Typically, all of - // it fits, but when the destination matrix shape is not a multiple - // of 8x8, there are some 8x8 blocks along the boundaries that do - // not fit entirely. - "sub w1, %w[dst_rows], %w[row]\n" - "sub w2, %w[dst_cols], %w[col]\n" - "mov w3, #8\n" - "cmp w1, #8\n" - // Compute w1 = how many rows of the 8x8 block fit - "csel w1, w1, w3, le\n" - "cmp w2, #8\n" - // Compute w2 = how many cols of the 8x8 block fit - "csel w2, w2, w3, le\n" - - // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits. - "cmp w1, w3\n" - "ccmp w2, w3, 0, eq\n" - // Yes, all of the 8x8 block fits, go to fast path. - "beq 30f\n" - // Not all of the 8x8 block fits. - // Set (x3 address, x4 stride) to write to dst_tmp_buf - "mov x3, %[dst_tmp_buf]\n" - "mov x4, #32\n" - "b 31f\n" - "30:\n" - // Yes, all of the 8x8 block fits. - // Set (x3 address, x4 stride) to write directly to destination matrix. - "mov x3, %[dst_ptr]\n" - "mov x4, x11\n" - "31:\n" - - // Write our 8bit values to the destination described by - // (x3 address, x4 stride). 
- RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "str q16, [x3, #0]\n" - "str q17, [x3, #16]\n" - "add x3, x3, x4\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - RUY_MAKE_ZERO(v16) - RUY_MAKE_ZERO(v17) - "str q18, [x3, #0]\n" - "str q19, [x3, #16]\n" - "add x3, x3, x4\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - RUY_MAKE_ZERO(v18) - RUY_MAKE_ZERO(v19) - "str q20, [x3, #0]\n" - "str q21, [x3, #16]\n" - "add x3, x3, x4\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - RUY_MAKE_ZERO(v20) - RUY_MAKE_ZERO(v21) - "str q22, [x3, #0]\n" - "str q23, [x3, #16]\n" - "add x3, x3, x4\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - RUY_MAKE_ZERO(v22) - RUY_MAKE_ZERO(v23) - "str q24, [x3, #0]\n" - "str q25, [x3, #16]\n" - "add x3, x3, x4\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - RUY_MAKE_ZERO(v24) - RUY_MAKE_ZERO(v25) - "str q26, [x3, #0]\n" - "str q27, [x3, #16]\n" - "add x3, x3, x4\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - RUY_MAKE_ZERO(v26) - RUY_MAKE_ZERO(v27) - "str q28, [x3, #0]\n" - "str q29, [x3, #16]\n" - "add x3, x3, x4\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - RUY_MAKE_ZERO(v28) - RUY_MAKE_ZERO(v29) - "str q30, [x3, #0]\n" - "str q31, [x3, #16]\n" - RUY_MAKE_ZERO(v30) - RUY_MAKE_ZERO(v31) - - // If all of the 8x8 block fits, we just finished writing it to the - // destination, so we skip the next part. - "beq 41f\n" - // Not all of the 8x8 block fits in the destination matrix. We just - // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over - // it to copy into the destination matrix the part that fits. - "mov x3, %[dst_tmp_buf]\n" - "mov x4, %[dst_ptr]\n" - "mov w6, #0\n" - "50:\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov w5, #0\n" - "51:\n" - "ldr w7, [x3, x5, lsl #2]\n" - "str w7, [x4, x5, lsl #2]\n" - "add w5, w5, #1\n" - "cmp w5, w1\n" - "blt 51b\n" - "add w6, w6, #1\n" - "add x3, x3, #32\n" - "add x4, x4, x11\n" - "cmp w6, w2\n" - "blt 50b\n" - "41:\n" - "add %[dst_ptr], %[dst_ptr], #32\n" - // At this point we have completely finished writing values to the - // destination matrix for the current block. - - // Reload some params --- we had used x5 -- x7 for a few other things - // since the last time we had loaded them. - "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" - "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" - "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" - - // Move to the next block of the destination matrix, for the next iter - // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already - // been updated earlier. - // Have we reached the end row? - "cmp %w[row], w7\n" - "beq 20f\n" // yes, end row. - // Not end row. Move to the next row. - "add %w[row], %w[row], #8\n" - "b 21f\n" - "20:\n" - // Was already at end row. - "mov %w[row], w6\n" // Move back to first row. - "add %w[col], %w[col], #8\n" // Move to the next column. - "add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #3\n" - "mov %[dst_ptr], %[dst_col_ptr]\n" - "21:\n" - - // Main loop exit condition: have we hit the end column? - "cmp %w[col], w8\n" - - // w1 is the number of levels of depth that remain to load - // LHS and RHS data for. Corresponding to the initial ld1 instructions - // above, this is currently depth - 1. 
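// The scalar fix-up loop above (labels 50:/51:) copies the part of the 8x8
// block that actually fits back out of dst_tmp_buf into the destination. A C++
// sketch of the same logic, with assumed names (dst_stride in bytes, as in
// register x11); not ruy code.

inline void CopyResidualBlock(const float* dst_tmp_buf, char* dst_ptr,
                              int residual_rows, int residual_cols,
                              int dst_stride) {
  for (int c = 0; c < residual_cols; ++c) {
    float* dst_col = reinterpret_cast<float*>(dst_ptr + c * dst_stride);
    for (int r = 0; r < residual_rows; ++r) {
      dst_col[r] = dst_tmp_buf[8 * c + r];  // tmp buffer holds 8 floats/column
    }
  }
}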
- "sub w1, w12, #1\n" - - "ble 1b\n" - - // clang-format on - - : [ lhs_col_ptr ] "+r"(lhs_col_ptr), [rhs_col_ptr] "+r"(rhs_col_ptr), - [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), - [dst_col_ptr] "+r"(dst_col_ptr), [dst_ptr] "+r"(dst_ptr), [row] "+r"(row), [col] "+r"(col) - : [ params ] "r"(¶ms), [dst_rows] "r"(params.dst_rows), - [dst_cols] "r"(params.dst_cols), [dst_tmp_buf] "r"(params.dst_tmp_buf) - : "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc", - "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", - "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", - "v26", "v27", "v28", "v29", "v30", "v31"); -} - -// Variant of KernelFloatNeonInOrder tuned for in-order CPUs that do -// support dotprod (while dotprod by itself is not relevant to floating-point, -// this additional bit of information that we have about the target happens to -// be useful here). -// -// So a typical target CPU here would be ARM Cortex-A55r1. -// -// This kernel is similar to and inspired by gemmlowp's -// NEON_64bit_GEMM_Float32_WithScalar_A55r1. -// which was contributed by David Mansell with very helpful -// comments. Specifically, see this comment about tuning for Cortex-A55r1: -// https://github.com/google/gemmlowp/blob/36212ad3651871bc3e9a599f1a6d5324778aea25/standalone/neon-gemm-kernel-benchmark.cc#L4412 -void KernelFloatNeonDotprodInOrder(const KernelParamsFloat<8, 8>& params) { - profiler::ScopeLabel label( - "Kernel (kNeonDotprod, optimized for in-order cores)"); - - CheckOffsetsInKernelParamsFloat(params); - - const float* lhs_col_ptr = params.lhs_base_ptr; - const float* rhs_col_ptr = params.rhs_base_ptr; - const float* lhs_ptr = lhs_col_ptr; - const float* rhs_ptr = rhs_col_ptr; - float* dst_col_ptr = params.dst_base_ptr; - float* dst_ptr = dst_col_ptr; - int row = params.start_row; - int col = params.start_col; - - // The asm kernel below has the following NEON register allocation: - // - // v16 -- v31 are accumulators. - // During accumulation, v0 -- v3 are used to load data from LHS and RHS. - // - // RHS 1x8 block - // /-----------------------------------------\ - // |v2.s[0] ... v2.s[3] v3.s[0] ... v3.s[3]| - // \-----------------------------------------/ - // LHS 8x1 block - // /---------------------\ /-----------------------------------------\ - // | v0.s[0] | |v16.s[0] ... v30.s[0]| - // | ... | | ... ... | - // | v0.s[3] | |v16.s[3] ... v30.s[3]| - // | v1.s[0] | |v17.s[0] ... v31.s[0]| - // | ... | | ... ... | - // | v1.s[3] | |v17.s[3] ... v31.s[3]| - // \---------------------/ \-----------------------------------------/ - // accumulators 8x8 block - // - // There is no RUY_OPT_MAX_STREAMING 4x-unrolled part in this kernel because - // we did not observe a benefit of such partial unrolling on in-order CPUs. - // - // v4 -- v7 are unused, and v8 -- v15 are used for floading parameters used - // for the post-accumulation part of the kernel. - asm volatile( -#define RUY_MAKE_ZERO(reg) "dup " #reg ".4s, wzr\n" - - // clang-format off - - // Load some parameters into registers. 
- "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" - "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" - "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" - "ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n" - "ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n" - "ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n" - "ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" - "ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n" - - - // Clear accumulators. - RUY_MAKE_ZERO(v16) - // Load the first 32 bytes of LHS and RHS data. - "ld1 {v0.4s}, [%[lhs_ptr]], #16\n" - RUY_MAKE_ZERO(v17) - "ld1 {v1.4s}, [%[lhs_ptr]], #16\n" - RUY_MAKE_ZERO(v18) - "ld1 {v2.4s}, [%[rhs_ptr]], #16\n" - RUY_MAKE_ZERO(v19) - "ld1 {v3.4s}, [%[rhs_ptr]], #16\n" - RUY_MAKE_ZERO(v20) - RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #64]\n") - RUY_MAKE_ZERO(v21) - RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #64]\n") - RUY_MAKE_ZERO(v22) - RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #128]\n") - RUY_MAKE_ZERO(v23) - RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #128]\n") - RUY_MAKE_ZERO(v24) - RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #192]\n") - RUY_MAKE_ZERO(v25) - RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #192]\n") - RUY_MAKE_ZERO(v26) - RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #256]\n") - RUY_MAKE_ZERO(v27) - RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #256]\n") - RUY_MAKE_ZERO(v28) - RUY_MAKE_ZERO(v29) - RUY_MAKE_ZERO(v30) - RUY_MAKE_ZERO(v31) - - // w1 is the number of levels of depth that remain to load - // LHS and RHS data for. Corresponding to the initial ld1 instructions - // above, this is currently depth - 1. - "sub w1, w12, #1\n" - - // Main loop of the whole GEMM, over rows and columns of the - // destination matrix. - "1:\n" - - "cmp w1, #0\n" - "fmla v16.4s, v0.4s, v2.s[0]\n" - "fmla v18.4s, v0.4s, v2.s[1]\n" - "fmla v20.4s, v0.4s, v2.s[2]\n" - "fmla v22.4s, v0.4s, v2.s[3]\n" - - // Accumulation loop - "beq 79f\n" - - "2:\n" - - RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #256]\n") - "fmla v24.4s, v0.4s, v3.s[0]\n" - "ldr x2, [%[lhs_ptr], #8]\n" - "fmla v26.4s, v0.4s, v3.s[1]\n" - "ldr x3, [%[lhs_ptr], #24]\n" - "fmla v28.4s, v0.4s, v3.s[2]\n" - "ldr x5, [%[rhs_ptr], #24]\n" - "fmla v30.4s, v0.4s, v3.s[3]\n" - "ldr d0, [%[lhs_ptr]], #32\n" - "fmla v25.4s, v1.4s, v3.s[0]\n" - "ldr x4, [%[rhs_ptr], #8]\n" - "fmla v27.4s, v1.4s, v3.s[1]\n" - "subs w1, w1, #1\n" - "fmla v29.4s, v1.4s, v3.s[2]\n" - "ins v0.d[1], x2\n" - "fmla v31.4s, v1.4s, v3.s[3]\n" - "ldr d3, [%[rhs_ptr], #16]\n" - "fmla v17.4s, v1.4s, v2.s[0]\n" - "ins v3.d[1], x5\n" - "fmla v19.4s, v1.4s, v2.s[1]\n" - "ldr d4, [%[rhs_ptr]], #32\n" - "fmla v21.4s, v1.4s, v2.s[2]\n" - "ins v4.d[1], x4\n" - "fmla v23.4s, v1.4s, v2.s[3]\n" - RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #256]\n") - "fmla v16.4s, v0.4s, v4.s[0]\n" - "ldr d1, [%[lhs_ptr], #-16]\n" - "fmla v18.4s, v0.4s, v4.s[1]\n" - "ins v1.d[1], x3\n" - "fmla v20.4s, v0.4s, v4.s[2]\n" - "mov v2.16b, v4.16b\n" - "fmla v22.4s, v0.4s, v4.s[3]\n" - "bne 2b\n" - - "79:\n" - - // End of the inner loop on depth. Now perform the remaining - // multiply-adds of the last level of depth, for which the LHS - // and RHS data is already loaded. 
- - "fmla v24.4s, v0.4s, v3.s[0]\n" - "fmla v26.4s, v0.4s, v3.s[1]\n" - "fmla v28.4s, v0.4s, v3.s[2]\n" - "fmla v30.4s, v0.4s, v3.s[3]\n" - "fmla v25.4s, v1.4s, v3.s[0]\n" - "fmla v27.4s, v1.4s, v3.s[1]\n" - "fmla v29.4s, v1.4s, v3.s[2]\n" - "fmla v31.4s, v1.4s, v3.s[3]\n" - "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" - "fmla v17.4s, v1.4s, v2.s[0]\n" - "fmla v19.4s, v1.4s, v2.s[1]\n" - "fmla v21.4s, v1.4s, v2.s[2]\n" - "fmla v23.4s, v1.4s, v2.s[3]\n" - - // End of accumulation. The registers v16 -- v31 contain the final - // int32 accumulator values of the current 8x8 destination block. - // We now have to compute the final 8-bit values from these int32 - // accumulators, and advance to the next 8x8 block. We intertwine - // these two aspects whenever possible for optimal pipelining, both - // at the data flow level (prefetch data for next block as early as - // possible) and instruction pipelining level (some of the next-block - // work can dual-issue with some of the final work on the current - // block). - - // Logic to advance to the next block in preparation for the next - // iteration of the main loop. For now, we only want to compute - // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are - // not yet ready to update the values of row and col, as we still need - // the current values for the rest of the work on the current block. - - "cmp %w[row], w7\n" // Have we finished the last row? - "bge 4f\n" // If finished last row, go to 4 - // Not finished last row: then advance to next row. - "add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #3\n" - "b 5f\n" - "4:\n" // Finished last row... - "mov %[lhs_col_ptr], x5\n" // Go back to first row - // Now we need to advance to the next column. If we already - // finished the last column, then in principle we are done, however - // we can't just return here, as we need to allow the end work of the - // current block to complete. The good news is that at this point it - // doesn't matter what data we load for the next column, since - // we will exit from the main loop below before actually storing - // anything computed from that data. - "cmp %w[col], w8\n" // Have we finished the last column? - "bge 5f\n" // If yes, just carry on without updating the column pointer. - // Not finished last column: then advance to next column. - "add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #3\n" - "5:\n" - - // Set the LHS and RHS data pointers to the start of the columns just - // computed. - "mov %[lhs_ptr], %[lhs_col_ptr]\n" - "mov %[rhs_ptr], %[rhs_col_ptr]\n" - - // Load some parameters needed for the end work on current block. - "ldrb w4, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" - "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n" - - // Offset these base pointers as needed given the current row, col. - "add x5, x1, %x[row], lsl #2\n" - - "tst w4, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n" - "csel x1, x1, x5, eq\n" - - // Load 8 bias values. - "ld1 {v14.4s}, [x1], #16\n" - "ld1 {v15.4s}, [x1]\n" - - // Now that we know what LHS and RHS data the next iteration of the - // main loop will need to load, we start loading the first 32 bytes of - // each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore - // in the rest of the work on the current block. 
- "ld1 {v0.4s}, [%[lhs_ptr]], #16\n" - "ld1 {v1.4s}, [%[lhs_ptr]], #16\n" - "ld1 {v2.4s}, [%[rhs_ptr]], #16\n" - "ld1 {v3.4s}, [%[rhs_ptr]], #16\n" - - // Perform the bias-addition (per the above, we have just folded into - // the bias the (depth * lhs_zero_point * rhs_zero_point) term.) - "fadd v16.4s, v16.4s, v14.4s\n" - "fadd v17.4s, v17.4s, v15.4s\n" - "fadd v18.4s, v18.4s, v14.4s\n" - "fadd v19.4s, v19.4s, v15.4s\n" - "fadd v20.4s, v20.4s, v14.4s\n" - "fadd v21.4s, v21.4s, v15.4s\n" - "fadd v22.4s, v22.4s, v14.4s\n" - "fadd v23.4s, v23.4s, v15.4s\n" - "fadd v24.4s, v24.4s, v14.4s\n" - "fadd v25.4s, v25.4s, v15.4s\n" - "fadd v26.4s, v26.4s, v14.4s\n" - "fadd v27.4s, v27.4s, v15.4s\n" - "fadd v28.4s, v28.4s, v14.4s\n" - "fadd v29.4s, v29.4s, v15.4s\n" - "fadd v30.4s, v30.4s, v14.4s\n" - "fadd v31.4s, v31.4s, v15.4s\n" - - // Load the clamp_min, clamp_max bounds - "ldr w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" - "ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" - "dup v14.4s, w2\n" // clamp_min - "dup v15.4s, w3\n" // clamp_max - - // Apply the clamp_min bound - "fmax v16.4s, v16.4s, v14.4s\n" - "fmax v17.4s, v17.4s, v14.4s\n" - "fmax v18.4s, v18.4s, v14.4s\n" - "fmax v19.4s, v19.4s, v14.4s\n" - "fmax v20.4s, v20.4s, v14.4s\n" - "fmax v21.4s, v21.4s, v14.4s\n" - "fmax v22.4s, v22.4s, v14.4s\n" - "fmax v23.4s, v23.4s, v14.4s\n" - "fmax v24.4s, v24.4s, v14.4s\n" - "fmax v25.4s, v25.4s, v14.4s\n" - "fmax v26.4s, v26.4s, v14.4s\n" - "fmax v27.4s, v27.4s, v14.4s\n" - "fmax v28.4s, v28.4s, v14.4s\n" - "fmax v29.4s, v29.4s, v14.4s\n" - "fmax v30.4s, v30.4s, v14.4s\n" - "fmax v31.4s, v31.4s, v14.4s\n" - - // Apply the clamp_max bound - "fmin v16.4s, v16.4s, v15.4s\n" - "fmin v17.4s, v17.4s, v15.4s\n" - "fmin v18.4s, v18.4s, v15.4s\n" - "fmin v19.4s, v19.4s, v15.4s\n" - "fmin v20.4s, v20.4s, v15.4s\n" - "fmin v21.4s, v21.4s, v15.4s\n" - "fmin v22.4s, v22.4s, v15.4s\n" - "fmin v23.4s, v23.4s, v15.4s\n" - "fmin v24.4s, v24.4s, v15.4s\n" - "fmin v25.4s, v25.4s, v15.4s\n" - "fmin v26.4s, v26.4s, v15.4s\n" - "fmin v27.4s, v27.4s, v15.4s\n" - "fmin v28.4s, v28.4s, v15.4s\n" - "fmin v29.4s, v29.4s, v15.4s\n" - "fmin v30.4s, v30.4s, v15.4s\n" - "fmin v31.4s, v31.4s, v15.4s\n" - - // Compute how much of the 8x8 block of destination 8bit values that - // we have computed, fit in the destination matrix. Typically, all of - // it fits, but when the destination matrix shape is not a multiple - // of 8x8, there are some 8x8 blocks along the boundaries that do - // not fit entirely. - "sub w1, %w[dst_rows], %w[row]\n" - "sub w2, %w[dst_cols], %w[col]\n" - "mov w3, #8\n" - "cmp w1, #8\n" - // Compute w1 = how many rows of the 8x8 block fit - "csel w1, w1, w3, le\n" - "cmp w2, #8\n" - // Compute w2 = how many cols of the 8x8 block fit - "csel w2, w2, w3, le\n" - - // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits. - "cmp w1, w3\n" - "ccmp w2, w3, 0, eq\n" - // Yes, all of the 8x8 block fits, go to fast path. - "beq 30f\n" - // Not all of the 8x8 block fits. - // Set (x3 address, x4 stride) to write to dst_tmp_buf - "mov x3, %[dst_tmp_buf]\n" - "mov x4, #32\n" - "b 31f\n" - "30:\n" - // Yes, all of the 8x8 block fits. - // Set (x3 address, x4 stride) to write directly to destination matrix. - "mov x3, %[dst_ptr]\n" - "mov x4, x11\n" - "31:\n" - - // Write our 8bit values to the destination described by - // (x3 address, x4 stride). 
- RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "str q16, [x3, #0]\n" - "str q17, [x3, #16]\n" - "add x3, x3, x4\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - RUY_MAKE_ZERO(v16) - RUY_MAKE_ZERO(v17) - "str q18, [x3, #0]\n" - "str q19, [x3, #16]\n" - "add x3, x3, x4\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - RUY_MAKE_ZERO(v18) - RUY_MAKE_ZERO(v19) - "str q20, [x3, #0]\n" - "str q21, [x3, #16]\n" - "add x3, x3, x4\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - RUY_MAKE_ZERO(v20) - RUY_MAKE_ZERO(v21) - "str q22, [x3, #0]\n" - "str q23, [x3, #16]\n" - "add x3, x3, x4\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - RUY_MAKE_ZERO(v22) - RUY_MAKE_ZERO(v23) - "str q24, [x3, #0]\n" - "str q25, [x3, #16]\n" - "add x3, x3, x4\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - RUY_MAKE_ZERO(v24) - RUY_MAKE_ZERO(v25) - "str q26, [x3, #0]\n" - "str q27, [x3, #16]\n" - "add x3, x3, x4\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - RUY_MAKE_ZERO(v26) - RUY_MAKE_ZERO(v27) - "str q28, [x3, #0]\n" - "str q29, [x3, #16]\n" - "add x3, x3, x4\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - RUY_MAKE_ZERO(v28) - RUY_MAKE_ZERO(v29) - "str q30, [x3, #0]\n" - "str q31, [x3, #16]\n" - RUY_MAKE_ZERO(v30) - RUY_MAKE_ZERO(v31) - - // If all of the 8x8 block fits, we just finished writing it to the - // destination, so we skip the next part. - "beq 41f\n" - // Not all of the 8x8 block fits in the destination matrix. We just - // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over - // it to copy into the destination matrix the part that fits. - "mov x3, %[dst_tmp_buf]\n" - "mov x4, %[dst_ptr]\n" - "mov w6, #0\n" - "50:\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov w5, #0\n" - "51:\n" - "ldr w7, [x3, x5, lsl #2]\n" - "str w7, [x4, x5, lsl #2]\n" - "add w5, w5, #1\n" - "cmp w5, w1\n" - "blt 51b\n" - "add w6, w6, #1\n" - "add x3, x3, #32\n" - "add x4, x4, x11\n" - "cmp w6, w2\n" - "blt 50b\n" - "41:\n" - "add %[dst_ptr], %[dst_ptr], #32\n" - // At this point we have completely finished writing values to the - // destination matrix for the current block. - - // Reload some params --- we had used x5 -- x7 for a few other things - // since the last time we had loaded them. - "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" - "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" - "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" - - // Move to the next block of the destination matrix, for the next iter - // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already - // been updated earlier. - // Have we reached the end row? - "cmp %w[row], w7\n" - "beq 20f\n" // yes, end row. - // Not end row. Move to the next row. - "add %w[row], %w[row], #8\n" - "b 21f\n" - "20:\n" - // Was already at end row. - "mov %w[row], w6\n" // Move back to first row. - "add %w[col], %w[col], #8\n" // Move to the next column. - "add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #3\n" - "mov %[dst_ptr], %[dst_col_ptr]\n" - "21:\n" - - // Main loop exit condition: have we hit the end column? - "cmp %w[col], w8\n" - - // w1 is the number of levels of depth that remain to load - // LHS and RHS data for. Corresponding to the initial ld1 instructions - // above, this is currently depth - 1. 
- "sub w1, w12, #1\n" - - "ble 1b\n" - - // clang-format on - - : [ lhs_col_ptr ] "+r"(lhs_col_ptr), [rhs_col_ptr] "+r"(rhs_col_ptr), - [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), - [dst_col_ptr] "+r"(dst_col_ptr), [dst_ptr] "+r"(dst_ptr), [row] "+r"(row), [col] "+r"(col) - : [ params ] "r"(¶ms), [dst_rows] "r"(params.dst_rows), - [dst_cols] "r"(params.dst_cols), [dst_tmp_buf] "r"(params.dst_tmp_buf) - : "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc", - "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", - "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", - "v26", "v27", "v28", "v29", "v30", "v31"); -} -#undef RUY_OFFSET_BIAS -#undef RUY_OFFSET_FLAGS -#undef RUY_OFFSET_LHS_BASE_PTR -#undef RUY_OFFSET_CLAMP_MIN -#undef RUY_OFFSET_CLAMP_MAX -#undef RUY_OFFSET_START_ROW -#undef RUY_OFFSET_LAST_ROW -#undef RUY_OFFSET_LAST_COL -#undef RUY_OFFSET_LHS_STRIDE -#undef RUY_OFFSET_RHS_STRIDE -#undef RUY_OFFSET_DST_STRIDE -#undef RUY_OFFSET_DEPTH -#undef RUY_OFFSET_START_COL -#undef RUY_OFFSET_RHS_BASE_PTR -#undef RUY_OFFSET_DST_BASE_PTR - -#endif // RUY_PLATFORM(NEON_64) && RUY_OPT_ENABLED(RUY_OPT_ASM) - -} // namespace ruy diff --git a/tensorflow/lite/experimental/ruy/ruy/kernel_avx2.cc b/tensorflow/lite/experimental/ruy/ruy/kernel_avx2.cc deleted file mode 100644 index 1113469fd28..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/kernel_avx2.cc +++ /dev/null @@ -1,1664 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include - -#include "tensorflow/lite/experimental/ruy/ruy/check_macros.h" -#include "tensorflow/lite/experimental/ruy/ruy/kernel.h" -#include "tensorflow/lite/experimental/ruy/ruy/opt_set.h" -#include "tensorflow/lite/experimental/ruy/ruy/platform.h" -#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" - -#if RUY_PLATFORM(AVX2) && RUY_OPT_ENABLED(RUY_OPT_ASM) -#include // IWYU pragma: keep -#endif - -namespace ruy { - -#if !(RUY_PLATFORM(AVX2) && RUY_OPT_ENABLED(RUY_OPT_ASM)) - -void Kernel8bitAvx2(const KernelParams8bit<8, 8>& params) { - // CPU-ID-based checks should disable the path that would reach this point. - RUY_DCHECK(false); -} - -void Kernel8bitAvx2SingleCol(const KernelParams8bit<8, 8>& params) { - // CPU-ID-based checks should disable the path that would reach this point. - RUY_DCHECK(false); -} - -void KernelFloatAvx2(const KernelParamsFloat<8, 8>& params) { - // CPU-ID-based checks should disable the path that would reach this point. - RUY_DCHECK(false); -} - -void KernelFloatAvx2SingleCol(const KernelParamsFloat<8, 8>& params) { - // CPU-ID-based checks should disable the path that would reach this point. 
- RUY_DCHECK(false); -} - -#else // RUY_PLATFORM(AVX2) && RUY_OPT_ENABLED(RUY_OPT_ASM) - -static constexpr int kAvx8bitBlockSize = 8; -static constexpr int kAvx8bitInnerSize = 4; - -namespace { -namespace intrin_utils { - -inline __m256 mm256_n_loadu_epi32(int n, const std::int32_t* src) { - switch (n) { - case 0: - return _mm256_setzero_si256(); - case 1: - return _mm256_setr_m128(_mm_setr_epi32(src[0], 0, 0, 0), - _mm_setzero_si128()); - case 2: - return _mm256_setr_m128(_mm_setr_epi32(src[0], src[1], 0, 0), - _mm_setzero_si128()); - case 3: - return _mm256_setr_m128(_mm_setr_epi32(src[0], src[1], src[2], 0), - _mm_setzero_si128()); - case 4: - return _mm256_castsi128_si256( - _mm_loadu_si128(reinterpret_cast<__m128i const*>(src))); - case 5: - return _mm256_setr_epi32(src[0], src[1], src[2], src[3], src[4], 0, 0, 0); - case 6: - return _mm256_setr_epi32(src[0], src[1], src[2], src[3], src[4], src[5], - 0, 0); - case 7: - return _mm256_setr_epi32(src[0], src[1], src[2], src[3], src[4], src[5], - src[6], 0); - case 8: - return _mm256_loadu_si256(reinterpret_cast<__m256i const*>(src)); - default: - RUY_DCHECK_LT(n, 9); - return _mm256_setzero_si256(); - } -} - -inline void mm256_n_storeu_cvtepi32_epi8(std::uint8_t* dst, int residual_rows, - const __m256 v) { - // Select bytes 0, 4, 8, 12 within each lane, effectively truncating. - const __m256i repack_perm = _mm256_set1_epi32(0x0c080400); - __m256i shuffled_v; - if (residual_rows > 1) { - // This selects 0, 4, 8, 12, 0, 4, 8, 12, ..., but we only use the first 4 - // in each 128-bit lane. - shuffled_v = _mm256_shuffle_epi8(v, repack_perm); - } - switch (residual_rows) { - case 0: - break; - case 1: - dst[0] = _mm256_extract_epi8(v, 0); - break; - case 2: - _mm_storeu_si16(dst, _mm256_extracti128_si256(shuffled_v, 0)); - break; - case 3: { - __m128i trailing_packed = _mm256_extracti128_si256(shuffled_v, 0); - _mm_storeu_si16(dst, trailing_packed); - dst[2] = _mm_extract_epi8(trailing_packed, 2); - break; - } - case 4: - _mm_storeu_si32(dst, _mm256_extracti128_si256(shuffled_v, 0)); - break; - case 5: - _mm_storeu_si32(dst, _mm256_extracti128_si256(shuffled_v, 0)); - dst[4] = _mm256_extract_epi8(shuffled_v, 16); - break; - case 6: - _mm_storeu_si32(dst, _mm256_extracti128_si256(shuffled_v, 0)); - _mm_storeu_si16(dst + 4, _mm256_extracti128_si256(shuffled_v, 1)); - break; - case 7: { - _mm_storeu_si32(dst, _mm256_extracti128_si256(shuffled_v, 0)); - __m128i trailing_packed = _mm256_extracti128_si256(shuffled_v, 1); - _mm_storeu_si16(dst + 4, trailing_packed); - dst[6] = _mm_extract_epi8(trailing_packed, 2); - break; - } - case 8: - _mm_storeu_si32(dst, _mm256_extracti128_si256(shuffled_v, 0)); - _mm_storeu_si32(dst + 4, _mm256_extracti128_si256(shuffled_v, 1)); - break; - default: - RUY_DCHECK_LE(residual_rows, 8); - break; - } -} - -inline void mm256_storeu_cvtepi32_epi8(std::uint8_t* dst, const __m256 v) { - // Select bytes 0, 4, 8, 12 within each lane, effectively truncating. 
- const __m256i repack_perm = _mm256_set1_epi32(0x0c080400); - const __m256i shuffled_v = _mm256_shuffle_epi8(v, repack_perm); - _mm_storeu_si32(dst, _mm256_extracti128_si256(shuffled_v, 0)); - _mm_storeu_si32(dst + 4, _mm256_extracti128_si256(shuffled_v, 1)); -} - -inline void mm256_n_storeu_cvtepi32_epi8(std::int8_t* dst, int residual_rows, - const __m256 v) { - intrin_utils::mm256_n_storeu_cvtepi32_epi8( - reinterpret_cast(dst), residual_rows, v); -} - -inline void mm256_storeu_cvtepi32_epi8(std::int8_t* dst, const __m256 v) { - // Select bytes 0, 4, 8, 12 within each lane, effectively truncating. - const __m256i repack_perm = _mm256_set1_epi32(0x0c080400); - const __m256i shuffled_v = _mm256_shuffle_epi8(v, repack_perm); - _mm_storeu_si32(dst, _mm256_extracti128_si256(shuffled_v, 0)); - _mm_storeu_si32(dst + 4, _mm256_extracti128_si256(shuffled_v, 1)); -} - -inline void mm256_n_storeu_cvtepi32_epi16(std::int16_t* dst, int residual_rows, - const __m256 v) { - // Select bytes 0, 1, 4, 5, 8, 9, 12, 13 within each lane, effectively - // truncating each 16-bit integer. - const __m256i repack_perm = _mm256_set1_epi64x(0x0d0c090805040100); - __m256i shuffled_v; - __m128i shuffled_v_low; - if (residual_rows > 1) { - shuffled_v = _mm256_shuffle_epi8(v, repack_perm); - shuffled_v_low = _mm256_extracti128_si256(shuffled_v, 0); - } else { - shuffled_v_low = _mm256_extracti128_si256(v, 0); - } - switch (residual_rows) { - case 0: - break; - case 1: - _mm_storeu_si16(dst, shuffled_v_low); - break; - case 2: - _mm_storeu_si32(dst, shuffled_v_low); - break; - case 3: { - _mm_storeu_si32(dst, shuffled_v_low); - dst[2] = _mm_extract_epi16(shuffled_v_low, 2); - break; - } - case 4: - _mm_storeu_si64(dst, shuffled_v_low); - break; - case 5: - _mm_storeu_si64(dst, shuffled_v_low); - dst[4] = _mm256_extract_epi16(shuffled_v, 8); - break; - case 6: - _mm_storeu_si64(dst, shuffled_v_low); - _mm_storeu_si32(dst + 4, _mm256_extracti128_si256(shuffled_v, 1)); - break; - case 7: { - _mm_storeu_si64(dst, shuffled_v_low); - __m128i trailing_packed = _mm256_extracti128_si256(shuffled_v, 1); - _mm_storeu_si32(dst + 4, trailing_packed); - dst[6] = _mm_extract_epi16(trailing_packed, 2); - break; - } - case 8: - _mm_storeu_si64(dst, _mm256_extracti128_si256(shuffled_v, 0)); - _mm_storeu_si64(dst + 4, _mm256_extracti128_si256(shuffled_v, 1)); - break; - default: - RUY_DCHECK_LE(residual_rows, 8); - break; - } -} - -inline void mm256_storeu_cvtepi32_epi16(std::int16_t* dst, const __m256 v) { - // Select bytes 0, 1, 4, 5, 8, 9, 12, 13 within each lane, effectively - // truncating each 16-bit integer. 
- const __m256i repack_perm = _mm256_set1_epi64x(0x0d0c090805040100); - const __m256i shuffled_v = _mm256_shuffle_epi8(v, repack_perm); - _mm_storeu_si64(dst, _mm256_extracti128_si256(shuffled_v, 0)); - _mm_storeu_si64(dst + 4, _mm256_extracti128_si256(shuffled_v, 1)); -} - -inline void mm256_n_storeu_epi32(std::int32_t* dst, int residual_rows, - const __m256 v) { - const __m128i v_low = _mm256_extracti128_si256(v, 0); - switch (residual_rows) { - case 0: - break; - case 1: - _mm_storeu_si32(dst, v_low); - break; - case 2: - _mm_storeu_si64(dst, v_low); - break; - case 3: { - __m128i trailing_packed = v_low; - _mm_storeu_si64(dst, trailing_packed); - dst[2] = _mm_extract_epi32(trailing_packed, 2); - break; - } - case 4: - _mm_storeu_si128(reinterpret_cast<__m128i*>(dst), v_low); - break; - case 5: - _mm_storeu_si128(reinterpret_cast<__m128i*>(dst), v_low); - dst[4] = _mm256_extract_epi32(v, 4); - break; - case 6: - _mm_storeu_si128(reinterpret_cast<__m128i*>(dst), v_low); - _mm_storeu_si64(dst + 4, _mm256_extracti128_si256(v, 1)); - break; - case 7: { - _mm_storeu_si128(reinterpret_cast<__m128i*>(dst), v_low); - __m128i trailing_packed = _mm256_extracti128_si256(v, 1); - _mm_storeu_si64(dst + 4, trailing_packed); - dst[6] = _mm_extract_epi32(trailing_packed, 2); - break; - } - case 8: - _mm256_storeu_si256(reinterpret_cast<__m256i*>(dst), v); - break; - default: - RUY_DCHECK_LE(residual_rows, 8); - break; - } -} - -inline void mm256_storeu_epi32(std::int32_t* dst, const __m256 v) { - _mm256_storeu_si256(reinterpret_cast<__m256i*>(dst), v); -} - -inline float mm256_get1_ps(const __m256 a, int i) { - __m256i ai = _mm256_castps_si256(a); - int float_val_as_int; - switch (i) { - case 0: - float_val_as_int = _mm256_extract_epi32(ai, 0); - break; - case 1: - float_val_as_int = _mm256_extract_epi32(ai, 1); - break; - case 2: - float_val_as_int = _mm256_extract_epi32(ai, 2); - break; - case 3: - float_val_as_int = _mm256_extract_epi32(ai, 3); - break; - case 4: - float_val_as_int = _mm256_extract_epi32(ai, 4); - break; - case 5: - float_val_as_int = _mm256_extract_epi32(ai, 5); - break; - case 6: - float_val_as_int = _mm256_extract_epi32(ai, 6); - break; - case 7: - float_val_as_int = _mm256_extract_epi32(ai, 7); - break; - default: - RUY_DCHECK_LT(i, 8); - return .0f; - } - return reinterpret_cast(float_val_as_int); -} - -inline __m256 mm256_n_loadu_ps(int i, const float* src) { - switch (i) { - case 0: - return _mm256_setzero_ps(); - case 1: - return _mm256_setr_m128(_mm_setr_ps(src[0], .0f, .0f, .0f), - _mm_setzero_ps()); - case 2: - return _mm256_setr_m128(_mm_setr_ps(src[0], src[1], .0f, .0f), - _mm_setzero_ps()); - case 3: - return _mm256_setr_m128(_mm_setr_ps(src[0], src[1], src[2], .0f), - _mm_setzero_ps()); - case 4: - return _mm256_setr_m128(_mm_setr_ps(src[0], src[1], src[2], src[3]), - _mm_setzero_ps()); - case 5: - return _mm256_setr_ps(src[0], src[1], src[2], src[3], src[4], .0f, .0f, - .0f); - case 6: - return _mm256_setr_ps(src[0], src[1], src[2], src[3], src[4], src[5], .0f, - .0f); - case 7: - return _mm256_setr_ps(src[0], src[1], src[2], src[3], src[4], src[5], - src[6], .0f); - case 8: - return _mm256_loadu_ps(src); - default: - RUY_DCHECK_LT(i, 9); - return _mm256_setzero_ps(); - } -} - -inline void mm256_n_storeu_ps(float* dst, int residual_rows, const __m256 v) { - for (int i = 0; i < residual_rows; ++i) { - dst[i] = intrin_utils::mm256_get1_ps(v, i); - } -} -} // namespace intrin_utils -} // namespace - -void Kernel8bitAvx2(const KernelParams8bit<8, 8>& params) { - 
profiler::ScopeLabel label("Kernel kAvx2 8-bit"); - const std::int8_t splitter_idx_data[32] = { - 0, 1, 4, 5, 8, 9, 12, 13, // - 2, 3, 6, 7, 10, 11, 14, 15, // - 0, 1, 4, 5, 8, 9, 12, 13, // - 2, 3, 6, 7, 10, 11, 14, 15 // - }; - - std::int32_t dst_stride; - if ((params.dst_type_id == DstTypeId::kValue) || - (params.dst_type_id == DstTypeId::kValue)) { - dst_stride = params.dst_stride; - } else if (params.dst_type_id == DstTypeId::kValue) { - dst_stride = params.dst_stride / sizeof(std::int16_t); - } else if (params.dst_type_id == DstTypeId::kValue) { - dst_stride = params.dst_stride / sizeof(std::int32_t); - } else { - RUY_DCHECK(false); - } - - int bias_ptr_block_increment = - params.flags & RUY_ASM_FLAG_HAS_BIAS ? kAvx8bitBlockSize : 0; - - const std::int8_t* rhs_col_ptr = params.rhs_base_ptr; - void* dst_col_ptr = params.dst_base_ptr; - const std::int32_t* bias_col_ptr = params.bias; - if (params.flags & RUY_ASM_FLAG_HAS_BIAS) { - bias_col_ptr += params.start_row; - } - - for (int col = params.start_col; col <= params.last_col; - col += kAvx8bitBlockSize) { - const std::int8_t* lhs_col_ptr = params.lhs_base_ptr; - void* dst_ptr = dst_col_ptr; - const std::int32_t* bias_ptr = bias_col_ptr; - - const std::int32_t lhs_zero_point = params.lhs_zero_point; - const bool has_rhs_sums_offsets = - (params.flags & RUY_ASM_FLAG_HAS_RHS_SUMS) && lhs_zero_point; - std::int32_t rhs_sums_offsets[8]; - if (has_rhs_sums_offsets) { - const __m256i rhs_sums_offset_v = _mm256_mullo_epi32( - _mm256_set1_epi32(lhs_zero_point), - _mm256_loadu_si256( - reinterpret_cast<__m256i const*>(¶ms.rhs_sums[col]))); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(rhs_sums_offsets), - rhs_sums_offset_v); - } - - for (int row = params.start_row; row <= params.last_row; - row += kAvx8bitBlockSize) { - const int residual_rows = - std::min(params.dst_rows - row, kAvx8bitBlockSize); - const int residual_cols = - std::min(params.dst_cols - col, kAvx8bitBlockSize); - - const __m256i splitter_idx = _mm256_loadu_si256( - reinterpret_cast<__m256i const*>(splitter_idx_data)); - - __m256i accum_data_v0; - __m256i accum_data_v1; - __m256i accum_data_v2; - __m256i accum_data_v3; - __m256i accum_data_v4; - __m256i accum_data_v5; - __m256i accum_data_v6; - __m256i accum_data_v7; - - // Initialize with bias. - __m256i initial_accum_data = - intrin_utils::mm256_n_loadu_epi32(residual_rows, bias_ptr); - bias_ptr += bias_ptr_block_increment; - - // Adjustments common across columns. - const std::int32_t rhs_zero_point = params.rhs_zero_point; - if ((params.flags & RUY_ASM_FLAG_HAS_LHS_SUMS) && rhs_zero_point) { - const __m256i lhs_sums_offset = _mm256_mullo_epi32( - _mm256_set1_epi32(rhs_zero_point), - _mm256_loadu_si256( - reinterpret_cast<__m256i const*>(¶ms.lhs_sums[row]))); - initial_accum_data = - _mm256_sub_epi32(initial_accum_data, lhs_sums_offset); - } - const std::int32_t prod_zp_depth = params.prod_zp_depth; - if (prod_zp_depth) { - initial_accum_data = _mm256_add_epi32(initial_accum_data, - _mm256_set1_epi32(prod_zp_depth)); - } - - // Adjustments differing across columns. 
- if (has_rhs_sums_offsets) { - accum_data_v0 = _mm256_sub_epi32( - initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[0])); - accum_data_v1 = _mm256_sub_epi32( - initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[1])); - accum_data_v2 = _mm256_sub_epi32( - initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[2])); - accum_data_v3 = _mm256_sub_epi32( - initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[3])); - accum_data_v4 = _mm256_sub_epi32( - initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[4])); - accum_data_v5 = _mm256_sub_epi32( - initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[5])); - accum_data_v6 = _mm256_sub_epi32( - initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[6])); - accum_data_v7 = _mm256_sub_epi32( - initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[7])); - } else { - accum_data_v0 = initial_accum_data; - accum_data_v1 = initial_accum_data; - accum_data_v2 = initial_accum_data; - accum_data_v3 = initial_accum_data; - accum_data_v4 = initial_accum_data; - accum_data_v5 = initial_accum_data; - accum_data_v6 = initial_accum_data; - accum_data_v7 = initial_accum_data; - } - - const std::int8_t* lhs_ptr = lhs_col_ptr; - const std::int8_t* rhs_ptr = rhs_col_ptr; - for (int d = 0; d < params.depth; d += kAvx8bitInnerSize) { - const __m256i lhs_data = - _mm256_load_si256(reinterpret_cast(lhs_ptr)); - const __m256i rhs_data_8bit = - _mm256_load_si256(reinterpret_cast(rhs_ptr)); - - // Each "int32" is two 16-bit RHS values, sign extended from 8-bit. - std::int32_t rhs_data[16]; - const __m128i rhs_data_bottom_lane = - _mm256_castsi256_si128(rhs_data_8bit); - const __m128i rhs_data_top_lane = - _mm256_extracti128_si256(rhs_data_8bit, 1); - const __m256i rhs_16_bit_dup_low = - _mm256_cvtepi8_epi16(rhs_data_bottom_lane); - const __m256i rhs_16_bit_dup_high = - _mm256_cvtepi8_epi16(rhs_data_top_lane); - // Now that we have cast the RHS data, we store it so that each value - // can be separately loaded in the accumulation loop. - _mm256_storeu_si256(reinterpret_cast<__m256i*>(rhs_data), - rhs_16_bit_dup_low); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(rhs_data + 8), - rhs_16_bit_dup_high); - - // NOTE: There may be opportunities for permuting the data in the - // packing code instead of here. - const __m256i lhs_data_split = - _mm256_shuffle_epi8(lhs_data, splitter_idx); - const __m256i lhs_data_split_expand_bottom = - _mm256_cvtepi8_epi16(_mm256_extracti128_si256(lhs_data_split, 0)); - const __m256i lhs_data_split_expand_top = - _mm256_cvtepi8_epi16(_mm256_extracti128_si256(lhs_data_split, 1)); - - // Take bytes 0, 1, 4, 5, 8, 9, ... expanded to 16-bit. - const __m256i lhs_16_bit_low = _mm256_permute2x128_si256( - lhs_data_split_expand_bottom, lhs_data_split_expand_top, 0x20); - // Take bytes 2, 3, 6, 7, 10, 11, ... expanded to 16-bit. - const __m256i lhs_16_bit_high = _mm256_permute2x128_si256( - lhs_data_split_expand_bottom, lhs_data_split_expand_top, 0x31); - // Accumulate for column 0. - { - const std::int32_t low_rhs_value = rhs_data[0]; - const std::int32_t high_rhs_value = rhs_data[1]; - - const __m256i rhs_16_bit_dup_low = _mm256_set1_epi32(low_rhs_value); - const __m256i rhs_16_bit_dup_high = _mm256_set1_epi32(high_rhs_value); - - accum_data_v0 = _mm256_add_epi32( - accum_data_v0, - _mm256_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); - accum_data_v0 = _mm256_add_epi32( - accum_data_v0, - _mm256_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); - } - // Accumulate for column 1. 
- { - const std::int32_t low_rhs_value = rhs_data[2]; - const std::int32_t high_rhs_value = rhs_data[3]; - - const __m256i rhs_16_bit_dup_low = _mm256_set1_epi32(low_rhs_value); - const __m256i rhs_16_bit_dup_high = _mm256_set1_epi32(high_rhs_value); - - accum_data_v1 = _mm256_add_epi32( - accum_data_v1, - _mm256_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); - accum_data_v1 = _mm256_add_epi32( - accum_data_v1, - _mm256_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); - } - // Accumulate for column 2. - { - const std::int32_t low_rhs_value = rhs_data[4]; - const std::int32_t high_rhs_value = rhs_data[5]; - - const __m256i rhs_16_bit_dup_low = _mm256_set1_epi32(low_rhs_value); - const __m256i rhs_16_bit_dup_high = _mm256_set1_epi32(high_rhs_value); - - accum_data_v2 = _mm256_add_epi32( - accum_data_v2, - _mm256_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); - accum_data_v2 = _mm256_add_epi32( - accum_data_v2, - _mm256_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); - } - // Accumulate for column 3. - { - const std::int32_t low_rhs_value = rhs_data[6]; - const std::int32_t high_rhs_value = rhs_data[7]; - - const __m256i rhs_16_bit_dup_low = _mm256_set1_epi32(low_rhs_value); - const __m256i rhs_16_bit_dup_high = _mm256_set1_epi32(high_rhs_value); - - accum_data_v3 = _mm256_add_epi32( - accum_data_v3, - _mm256_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); - accum_data_v3 = _mm256_add_epi32( - accum_data_v3, - _mm256_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); - } - // Accumulate for column 4. - { - const std::int32_t low_rhs_value = rhs_data[8]; - const std::int32_t high_rhs_value = rhs_data[9]; - - const __m256i rhs_16_bit_dup_low = _mm256_set1_epi32(low_rhs_value); - const __m256i rhs_16_bit_dup_high = _mm256_set1_epi32(high_rhs_value); - - accum_data_v4 = _mm256_add_epi32( - accum_data_v4, - _mm256_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); - accum_data_v4 = _mm256_add_epi32( - accum_data_v4, - _mm256_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); - } - // Accumulate for column 5. - { - const std::int32_t low_rhs_value = rhs_data[10]; - const std::int32_t high_rhs_value = rhs_data[11]; - - const __m256i rhs_16_bit_dup_low = _mm256_set1_epi32(low_rhs_value); - const __m256i rhs_16_bit_dup_high = _mm256_set1_epi32(high_rhs_value); - - accum_data_v5 = _mm256_add_epi32( - accum_data_v5, - _mm256_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); - accum_data_v5 = _mm256_add_epi32( - accum_data_v5, - _mm256_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); - } - // Accumulate for column 6. - { - const std::int32_t low_rhs_value = rhs_data[12]; - const std::int32_t high_rhs_value = rhs_data[13]; - - const __m256i rhs_16_bit_dup_low = _mm256_set1_epi32(low_rhs_value); - const __m256i rhs_16_bit_dup_high = _mm256_set1_epi32(high_rhs_value); - - accum_data_v6 = _mm256_add_epi32( - accum_data_v6, - _mm256_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); - accum_data_v6 = _mm256_add_epi32( - accum_data_v6, - _mm256_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); - } - // Accumulate for column 7. 
- { - const std::int32_t low_rhs_value = rhs_data[14]; - const std::int32_t high_rhs_value = rhs_data[15]; - - const __m256i rhs_16_bit_dup_low = _mm256_set1_epi32(low_rhs_value); - const __m256i rhs_16_bit_dup_high = _mm256_set1_epi32(high_rhs_value); - - accum_data_v7 = _mm256_add_epi32( - accum_data_v7, - _mm256_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); - accum_data_v7 = _mm256_add_epi32( - accum_data_v7, - _mm256_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); - } - - lhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize; - rhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize; - } - - if (params.dst_type_id != DstTypeId::kValue) { - __m256i m_vector; - __m256i e_vector; - // Does not make use of RUY_ASM_FLAG_NEEDS_LEFT_SHIFT. - if (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) { - m_vector = intrin_utils::mm256_n_loadu_epi32( - residual_rows, ¶ms.multiplier_fixedpoint[row]); - e_vector = intrin_utils::mm256_n_loadu_epi32( - residual_rows, ¶ms.multiplier_exponent[row]); - } else { - // These arrays have size LhsCols, and are pre-filled. - m_vector = _mm256_set1_epi32(params.multiplier_fixedpoint[0]); - e_vector = _mm256_set1_epi32(params.multiplier_exponent[0]); - } - - const __m256i m_64bit_low = - _mm256_cvtepi32_epi64(_mm256_extracti128_si256(m_vector, 0)); - const __m256i m_64bit_high = - _mm256_cvtepi32_epi64(_mm256_extracti128_si256(m_vector, 1)); - - const __m256i zero_vector = _mm256_setzero_si256(); - const __m256i left_shift = _mm256_max_epi32(e_vector, zero_vector); - const __m256i neg_e_vector = _mm256_sub_epi32(zero_vector, e_vector); - const __m256i right_shift = _mm256_max_epi32(neg_e_vector, zero_vector); - const __m256i final_right_shift = - _mm256_add_epi32(right_shift, _mm256_set1_epi32(31)); - const __m256i final_right_shift_low = _mm256_cvtepi32_epi64( - _mm256_extracti128_si256(final_right_shift, 0)); - const __m256i final_right_shift_high = _mm256_cvtepi32_epi64( - _mm256_extracti128_si256(final_right_shift, 1)); - // Really we want 0x100000000, but use half to avoid overflowing. - const __m256i convert_to_signed_halved = - _mm256_srlv_epi32(_mm256_set1_epi32(0x80000000), right_shift); - const __m256i convert_to_unsigned_64 = - _mm256_set1_epi64x(0x8000000000000000); - - __m256i post_scaling_offset = _mm256_add_epi32( - convert_to_signed_halved, convert_to_signed_halved); - - const __m256i offset_vector = - _mm256_slli_epi64(_mm256_set1_epi64x(1), 30); - // Really these should be shifted by neg_e_vector, but tests pass when - // using right_shift. - const __m256i offset_vector_low = _mm256_add_epi64( - _mm256_sllv_epi64(offset_vector, - _mm256_cvtepi32_epi64( - _mm256_extracti128_si256(right_shift, 0))), - convert_to_unsigned_64); - const __m256i offset_vector_high = _mm256_add_epi64( - _mm256_sllv_epi64(offset_vector, - _mm256_cvtepi32_epi64( - _mm256_extracti128_si256(right_shift, 1))), - convert_to_unsigned_64); - - if (params.dst_zero_point) { - const __m256i dst_zero_point = - _mm256_set1_epi32(params.dst_zero_point); - // The post-scaling offset is subtracted later, so this has the effect - // of adding the zero point. 
- post_scaling_offset = - _mm256_sub_epi32(post_scaling_offset, dst_zero_point); - } - -#if !RUY_OPT_ENABLED(RUY_OPT_NATIVE_ROUNDING) - RUY_DCHECK(false); -#endif - const __m256i repack_perm = _mm256_setr_epi32(0, 2, 4, 6, 1, 3, 5, 7); - - // We cannot do - // - // scaled_v_low = - // _mm256_srav_epi64(scaled_v_low, final_right_shift_low); - // scaled_v_high = - // _mm256_srav_epi64(scaled_v_high, final_right_shift_high); - // - // since this instruction is not in AVX2. Instead we use - // _mm256_srlv_epi64, but this is an unsigned shift, so we applied - // offsets before (convert_to_unsigned_64) and after - // (convert_to_signed_halved). - // - // The overall process is, for 64-bit scaled accumulator: - // unsigned_accum = signed_accum + 1 << 63; - // unsigned_accum = (unsigned_accum >> right_shift) >> 31; - // signed_accum = unsigned_accum - ((1 << 32) >> right_shift) / 2 * 2; - - // There are various ways to repack the results, in the absence of - // _mm256_cvtepi64_epi32() or anything like it. - // A. - // accum_data_v[j] = - // _mm256_set_epi32(_mm256_extract_epi32(scaled_v_high, 6), - // _mm256_extract_epi32(scaled_v_high, 4), - // _mm256_extract_epi32(scaled_v_high, 2), - // _mm256_extract_epi32(scaled_v_high, 0), - // _mm256_extract_epi32(scaled_v_low, 6), - // _mm256_extract_epi32(scaled_v_low, 4), - // _mm256_extract_epi32(scaled_v_low, 2), - // _mm256_extract_epi32(scaled_v_low, 0)); - // B. - // scaled_v_low = _mm256_shuffle_epi32(scaled_v_low, 0xd8); - // scaled_v_high = _mm256_shuffle_epi32(scaled_v_high, 0xd8); - // accum_data_v[j] = - // _mm256_set_epi64x(_mm256_extract_epi64(scaled_v_high, 2), - // _mm256_extract_epi64(scaled_v_high, 0), - // _mm256_extract_epi64(scaled_v_low, 2), - // _mm256_extract_epi64(scaled_v_low, 0)); - // C. - // scaled_v_low = - // _mm256_permutevar8x32_epi32(scaled_v_low, repack_perm); - // scaled_v_high = - // _mm256_permutevar8x32_epi32(scaled_v_high, repack_perm); - // accum_data_v[j] = - // _mm256_permute2x128_si256(scaled_v_low, scaled_v_high, 0x20); - // - // However, we choose the following because it uses two lighter - // instructions. The permutation does have a longer latency, but this - // loop can be unrolled. - // D. - // scaled_v_high = _mm256_slli_epi64(scaled_v_high, 32); - // __m256i results = - // _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa); - // results = _mm256_permutevar8x32_epi32(results, repack_perm); - // accum_data_v[j] = _mm256_sub_epi32(results, post_scaling_offset); - { - __m256i shifted_accum = _mm256_sllv_epi32(accum_data_v0, left_shift); - // Apply the fixed-point part of the multiplier. 
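        // In plain scalar terms, the per-lane computation assembled below is
        // roughly the following (a sketch of the process summarized in the
        // comment above, with assumed scalar names; not ruy code):
        //
        //   int64_t prod = static_cast<int64_t>(acc << left_shift) * multiplier_fixedpoint;
        //   uint64_t u = static_cast<uint64_t>(prod)
        //                + (static_cast<uint64_t>(1) << (30 + right_shift))  // rounding term
        //                + 0x8000000000000000ull;                            // bias into unsigned range
        //   int32_t scaled = static_cast<int32_t>(u >> (31 + right_shift));
        //   acc = scaled - post_scaling_offset;  // post_scaling_offset also folds in dst_zero_point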
- __m256i scaled_v_low = _mm256_mul_epi32( - _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 0)), - m_64bit_low); - __m256i scaled_v_high = _mm256_mul_epi32( - _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 1)), - m_64bit_high); - - scaled_v_low = _mm256_add_epi64(scaled_v_low, offset_vector_low); - scaled_v_high = _mm256_add_epi64(scaled_v_high, offset_vector_high); - - scaled_v_low = _mm256_srlv_epi64(scaled_v_low, final_right_shift_low); - scaled_v_high = - _mm256_srlv_epi64(scaled_v_high, final_right_shift_high); - - scaled_v_high = _mm256_slli_epi64(scaled_v_high, 32); - __m256i results = - _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa); - results = _mm256_permutevar8x32_epi32(results, repack_perm); - - accum_data_v0 = _mm256_sub_epi32(results, post_scaling_offset); - } - { - __m256i shifted_accum = _mm256_sllv_epi32(accum_data_v1, left_shift); - // Apply the fixed-point part of the multiplier. - __m256i scaled_v_low = _mm256_mul_epi32( - _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 0)), - m_64bit_low); - __m256i scaled_v_high = _mm256_mul_epi32( - _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 1)), - m_64bit_high); - - scaled_v_low = _mm256_add_epi64(scaled_v_low, offset_vector_low); - scaled_v_high = _mm256_add_epi64(scaled_v_high, offset_vector_high); - - scaled_v_low = _mm256_srlv_epi64(scaled_v_low, final_right_shift_low); - scaled_v_high = - _mm256_srlv_epi64(scaled_v_high, final_right_shift_high); - - scaled_v_high = _mm256_slli_epi64(scaled_v_high, 32); - __m256i results = - _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa); - results = _mm256_permutevar8x32_epi32(results, repack_perm); - - accum_data_v1 = _mm256_sub_epi32(results, post_scaling_offset); - } - { - __m256i shifted_accum = _mm256_sllv_epi32(accum_data_v2, left_shift); - // Apply the fixed-point part of the multiplier. - __m256i scaled_v_low = _mm256_mul_epi32( - _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 0)), - m_64bit_low); - __m256i scaled_v_high = _mm256_mul_epi32( - _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 1)), - m_64bit_high); - - scaled_v_low = _mm256_add_epi64(scaled_v_low, offset_vector_low); - scaled_v_high = _mm256_add_epi64(scaled_v_high, offset_vector_high); - - scaled_v_low = _mm256_srlv_epi64(scaled_v_low, final_right_shift_low); - scaled_v_high = - _mm256_srlv_epi64(scaled_v_high, final_right_shift_high); - - scaled_v_high = _mm256_slli_epi64(scaled_v_high, 32); - __m256i results = - _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa); - results = _mm256_permutevar8x32_epi32(results, repack_perm); - - accum_data_v2 = _mm256_sub_epi32(results, post_scaling_offset); - } - { - __m256i shifted_accum = _mm256_sllv_epi32(accum_data_v3, left_shift); - // Apply the fixed-point part of the multiplier. 
- __m256i scaled_v_low = _mm256_mul_epi32( - _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 0)), - m_64bit_low); - __m256i scaled_v_high = _mm256_mul_epi32( - _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 1)), - m_64bit_high); - - scaled_v_low = _mm256_add_epi64(scaled_v_low, offset_vector_low); - scaled_v_high = _mm256_add_epi64(scaled_v_high, offset_vector_high); - - scaled_v_low = _mm256_srlv_epi64(scaled_v_low, final_right_shift_low); - scaled_v_high = - _mm256_srlv_epi64(scaled_v_high, final_right_shift_high); - - scaled_v_high = _mm256_slli_epi64(scaled_v_high, 32); - __m256i results = - _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa); - results = _mm256_permutevar8x32_epi32(results, repack_perm); - - accum_data_v3 = _mm256_sub_epi32(results, post_scaling_offset); - } - { - __m256i shifted_accum = _mm256_sllv_epi32(accum_data_v4, left_shift); - // Apply the fixed-point part of the multiplier. - __m256i scaled_v_low = _mm256_mul_epi32( - _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 0)), - m_64bit_low); - __m256i scaled_v_high = _mm256_mul_epi32( - _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 1)), - m_64bit_high); - - scaled_v_low = _mm256_add_epi64(scaled_v_low, offset_vector_low); - scaled_v_high = _mm256_add_epi64(scaled_v_high, offset_vector_high); - - scaled_v_low = _mm256_srlv_epi64(scaled_v_low, final_right_shift_low); - scaled_v_high = - _mm256_srlv_epi64(scaled_v_high, final_right_shift_high); - - scaled_v_high = _mm256_slli_epi64(scaled_v_high, 32); - __m256i results = - _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa); - results = _mm256_permutevar8x32_epi32(results, repack_perm); - - accum_data_v4 = _mm256_sub_epi32(results, post_scaling_offset); - } - { - __m256i shifted_accum = _mm256_sllv_epi32(accum_data_v5, left_shift); - // Apply the fixed-point part of the multiplier. - __m256i scaled_v_low = _mm256_mul_epi32( - _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 0)), - m_64bit_low); - __m256i scaled_v_high = _mm256_mul_epi32( - _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 1)), - m_64bit_high); - - scaled_v_low = _mm256_add_epi64(scaled_v_low, offset_vector_low); - scaled_v_high = _mm256_add_epi64(scaled_v_high, offset_vector_high); - - scaled_v_low = _mm256_srlv_epi64(scaled_v_low, final_right_shift_low); - scaled_v_high = - _mm256_srlv_epi64(scaled_v_high, final_right_shift_high); - - scaled_v_high = _mm256_slli_epi64(scaled_v_high, 32); - __m256i results = - _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa); - results = _mm256_permutevar8x32_epi32(results, repack_perm); - - accum_data_v5 = _mm256_sub_epi32(results, post_scaling_offset); - } - { - __m256i shifted_accum = _mm256_sllv_epi32(accum_data_v6, left_shift); - // Apply the fixed-point part of the multiplier. 
- __m256i scaled_v_low = _mm256_mul_epi32( - _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 0)), - m_64bit_low); - __m256i scaled_v_high = _mm256_mul_epi32( - _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 1)), - m_64bit_high); - - scaled_v_low = _mm256_add_epi64(scaled_v_low, offset_vector_low); - scaled_v_high = _mm256_add_epi64(scaled_v_high, offset_vector_high); - - scaled_v_low = _mm256_srlv_epi64(scaled_v_low, final_right_shift_low); - scaled_v_high = - _mm256_srlv_epi64(scaled_v_high, final_right_shift_high); - - scaled_v_high = _mm256_slli_epi64(scaled_v_high, 32); - __m256i results = - _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa); - results = _mm256_permutevar8x32_epi32(results, repack_perm); - - accum_data_v6 = _mm256_sub_epi32(results, post_scaling_offset); - } - { - __m256i shifted_accum = _mm256_sllv_epi32(accum_data_v7, left_shift); - // Apply the fixed-point part of the multiplier. - __m256i scaled_v_low = _mm256_mul_epi32( - _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 0)), - m_64bit_low); - __m256i scaled_v_high = _mm256_mul_epi32( - _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 1)), - m_64bit_high); - - scaled_v_low = _mm256_add_epi64(scaled_v_low, offset_vector_low); - scaled_v_high = _mm256_add_epi64(scaled_v_high, offset_vector_high); - - scaled_v_low = _mm256_srlv_epi64(scaled_v_low, final_right_shift_low); - scaled_v_high = - _mm256_srlv_epi64(scaled_v_high, final_right_shift_high); - - scaled_v_high = _mm256_slli_epi64(scaled_v_high, 32); - __m256i results = - _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa); - results = _mm256_permutevar8x32_epi32(results, repack_perm); - - accum_data_v7 = _mm256_sub_epi32(results, post_scaling_offset); - } - } - const __m256i clamp_max_v = _mm256_set1_epi32(params.clamp_max); - const __m256i clamp_min_v = _mm256_set1_epi32(params.clamp_min); - const bool store_full_block = (residual_rows == kAvx8bitBlockSize) && - (residual_cols == kAvx8bitBlockSize); - - __m256i accum_data_v[kAvx8bitBlockSize]; - if (!store_full_block) { - accum_data_v[0] = accum_data_v0; - accum_data_v[1] = accum_data_v1; - accum_data_v[2] = accum_data_v2; - accum_data_v[3] = accum_data_v3; - accum_data_v[4] = accum_data_v4; - accum_data_v[5] = accum_data_v5; - accum_data_v[6] = accum_data_v6; - accum_data_v[7] = accum_data_v7; - } - - if (params.dst_type_id == DstTypeId::kValue) { - std::int8_t* tmp_ptr = static_cast(dst_ptr); - if (store_full_block) { - accum_data_v0 = _mm256_min_epi32(accum_data_v0, clamp_max_v); - accum_data_v0 = _mm256_max_epi32(accum_data_v0, clamp_min_v); - accum_data_v1 = _mm256_min_epi32(accum_data_v1, clamp_max_v); - accum_data_v1 = _mm256_max_epi32(accum_data_v1, clamp_min_v); - accum_data_v2 = _mm256_min_epi32(accum_data_v2, clamp_max_v); - accum_data_v2 = _mm256_max_epi32(accum_data_v2, clamp_min_v); - accum_data_v3 = _mm256_min_epi32(accum_data_v3, clamp_max_v); - accum_data_v3 = _mm256_max_epi32(accum_data_v3, clamp_min_v); - accum_data_v4 = _mm256_min_epi32(accum_data_v4, clamp_max_v); - accum_data_v4 = _mm256_max_epi32(accum_data_v4, clamp_min_v); - accum_data_v5 = _mm256_min_epi32(accum_data_v5, clamp_max_v); - accum_data_v5 = _mm256_max_epi32(accum_data_v5, clamp_min_v); - accum_data_v6 = _mm256_min_epi32(accum_data_v6, clamp_max_v); - accum_data_v6 = _mm256_max_epi32(accum_data_v6, clamp_min_v); - accum_data_v7 = _mm256_min_epi32(accum_data_v7, clamp_max_v); - accum_data_v7 = _mm256_max_epi32(accum_data_v7, clamp_min_v); - 
intrin_utils::mm256_storeu_cvtepi32_epi8(&tmp_ptr[0 * dst_stride], - accum_data_v0); - intrin_utils::mm256_storeu_cvtepi32_epi8(&tmp_ptr[1 * dst_stride], - accum_data_v1); - intrin_utils::mm256_storeu_cvtepi32_epi8(&tmp_ptr[2 * dst_stride], - accum_data_v2); - intrin_utils::mm256_storeu_cvtepi32_epi8(&tmp_ptr[3 * dst_stride], - accum_data_v3); - intrin_utils::mm256_storeu_cvtepi32_epi8(&tmp_ptr[4 * dst_stride], - accum_data_v4); - intrin_utils::mm256_storeu_cvtepi32_epi8(&tmp_ptr[5 * dst_stride], - accum_data_v5); - intrin_utils::mm256_storeu_cvtepi32_epi8(&tmp_ptr[6 * dst_stride], - accum_data_v6); - intrin_utils::mm256_storeu_cvtepi32_epi8(&tmp_ptr[7 * dst_stride], - accum_data_v7); - } else { - for (int j = 0; j < residual_cols; ++j) { - __m256 result = accum_data_v[j]; - result = _mm256_min_epi32(result, clamp_max_v); - result = _mm256_max_epi32(result, clamp_min_v); - intrin_utils::mm256_n_storeu_cvtepi32_epi8(tmp_ptr, residual_rows, - result); - tmp_ptr += dst_stride; - } - } - dst_ptr = static_cast(static_cast(dst_ptr) + - kAvx8bitBlockSize); - } else if (params.dst_type_id == DstTypeId::kValue) { - std::uint8_t* tmp_ptr = static_cast(dst_ptr); - if (store_full_block) { - accum_data_v0 = _mm256_min_epi32(accum_data_v0, clamp_max_v); - accum_data_v0 = _mm256_max_epi32(accum_data_v0, clamp_min_v); - accum_data_v1 = _mm256_min_epi32(accum_data_v1, clamp_max_v); - accum_data_v1 = _mm256_max_epi32(accum_data_v1, clamp_min_v); - accum_data_v2 = _mm256_min_epi32(accum_data_v2, clamp_max_v); - accum_data_v2 = _mm256_max_epi32(accum_data_v2, clamp_min_v); - accum_data_v3 = _mm256_min_epi32(accum_data_v3, clamp_max_v); - accum_data_v3 = _mm256_max_epi32(accum_data_v3, clamp_min_v); - accum_data_v4 = _mm256_min_epi32(accum_data_v4, clamp_max_v); - accum_data_v4 = _mm256_max_epi32(accum_data_v4, clamp_min_v); - accum_data_v5 = _mm256_min_epi32(accum_data_v5, clamp_max_v); - accum_data_v5 = _mm256_max_epi32(accum_data_v5, clamp_min_v); - accum_data_v6 = _mm256_min_epi32(accum_data_v6, clamp_max_v); - accum_data_v6 = _mm256_max_epi32(accum_data_v6, clamp_min_v); - accum_data_v7 = _mm256_min_epi32(accum_data_v7, clamp_max_v); - accum_data_v7 = _mm256_max_epi32(accum_data_v7, clamp_min_v); - intrin_utils::mm256_storeu_cvtepi32_epi8(&tmp_ptr[0], accum_data_v0); - intrin_utils::mm256_storeu_cvtepi32_epi8(&tmp_ptr[dst_stride], - accum_data_v1); - intrin_utils::mm256_storeu_cvtepi32_epi8(&tmp_ptr[2 * dst_stride], - accum_data_v2); - intrin_utils::mm256_storeu_cvtepi32_epi8(&tmp_ptr[3 * dst_stride], - accum_data_v3); - intrin_utils::mm256_storeu_cvtepi32_epi8(&tmp_ptr[4 * dst_stride], - accum_data_v4); - intrin_utils::mm256_storeu_cvtepi32_epi8(&tmp_ptr[5 * dst_stride], - accum_data_v5); - intrin_utils::mm256_storeu_cvtepi32_epi8(&tmp_ptr[6 * dst_stride], - accum_data_v6); - intrin_utils::mm256_storeu_cvtepi32_epi8(&tmp_ptr[7 * dst_stride], - accum_data_v7); - } else { - for (int j = 0; j < residual_cols; ++j) { - __m256 result = accum_data_v[j]; - result = _mm256_min_epi32(result, clamp_max_v); - result = _mm256_max_epi32(result, clamp_min_v); - intrin_utils::mm256_n_storeu_cvtepi32_epi8(tmp_ptr, residual_rows, - result); - tmp_ptr += dst_stride; - } - } - dst_ptr = static_cast(static_cast(dst_ptr) + - kAvx8bitBlockSize); - } else if (params.dst_type_id == DstTypeId::kValue) { - std::int16_t* tmp_ptr = static_cast(dst_ptr); - if (store_full_block) { - accum_data_v0 = _mm256_min_epi32(accum_data_v0, clamp_max_v); - accum_data_v0 = _mm256_max_epi32(accum_data_v0, clamp_min_v); - accum_data_v1 = 
_mm256_min_epi32(accum_data_v1, clamp_max_v); - accum_data_v1 = _mm256_max_epi32(accum_data_v1, clamp_min_v); - accum_data_v2 = _mm256_min_epi32(accum_data_v2, clamp_max_v); - accum_data_v2 = _mm256_max_epi32(accum_data_v2, clamp_min_v); - accum_data_v3 = _mm256_min_epi32(accum_data_v3, clamp_max_v); - accum_data_v3 = _mm256_max_epi32(accum_data_v3, clamp_min_v); - accum_data_v4 = _mm256_min_epi32(accum_data_v4, clamp_max_v); - accum_data_v4 = _mm256_max_epi32(accum_data_v4, clamp_min_v); - accum_data_v5 = _mm256_min_epi32(accum_data_v5, clamp_max_v); - accum_data_v5 = _mm256_max_epi32(accum_data_v5, clamp_min_v); - accum_data_v6 = _mm256_min_epi32(accum_data_v6, clamp_max_v); - accum_data_v6 = _mm256_max_epi32(accum_data_v6, clamp_min_v); - accum_data_v7 = _mm256_min_epi32(accum_data_v7, clamp_max_v); - accum_data_v7 = _mm256_max_epi32(accum_data_v7, clamp_min_v); - intrin_utils::mm256_storeu_cvtepi32_epi16(&tmp_ptr[0], accum_data_v0); - intrin_utils::mm256_storeu_cvtepi32_epi16(&tmp_ptr[dst_stride], - accum_data_v1); - intrin_utils::mm256_storeu_cvtepi32_epi16(&tmp_ptr[2 * dst_stride], - accum_data_v2); - intrin_utils::mm256_storeu_cvtepi32_epi16(&tmp_ptr[3 * dst_stride], - accum_data_v3); - intrin_utils::mm256_storeu_cvtepi32_epi16(&tmp_ptr[4 * dst_stride], - accum_data_v4); - intrin_utils::mm256_storeu_cvtepi32_epi16(&tmp_ptr[5 * dst_stride], - accum_data_v5); - intrin_utils::mm256_storeu_cvtepi32_epi16(&tmp_ptr[6 * dst_stride], - accum_data_v6); - intrin_utils::mm256_storeu_cvtepi32_epi16(&tmp_ptr[7 * dst_stride], - accum_data_v7); - } else { - for (int j = 0; j < residual_cols; ++j) { - __m256 result = accum_data_v[j]; - result = _mm256_min_epi32(result, clamp_max_v); - result = _mm256_max_epi32(result, clamp_min_v); - intrin_utils::mm256_n_storeu_cvtepi32_epi16(tmp_ptr, residual_rows, - result); - tmp_ptr += dst_stride; - } - } - dst_ptr = static_cast(static_cast(dst_ptr) + - kAvx8bitBlockSize); - } else if (params.dst_type_id == DstTypeId::kValue) { - if (store_full_block) { - std::int32_t* tmp_ptr = static_cast(dst_ptr); - intrin_utils::mm256_storeu_epi32(&tmp_ptr[0], accum_data_v0); - intrin_utils::mm256_storeu_epi32(&tmp_ptr[dst_stride], accum_data_v1); - intrin_utils::mm256_storeu_epi32(&tmp_ptr[2 * dst_stride], - accum_data_v2); - intrin_utils::mm256_storeu_epi32(&tmp_ptr[3 * dst_stride], - accum_data_v3); - intrin_utils::mm256_storeu_epi32(&tmp_ptr[4 * dst_stride], - accum_data_v4); - intrin_utils::mm256_storeu_epi32(&tmp_ptr[5 * dst_stride], - accum_data_v5); - intrin_utils::mm256_storeu_epi32(&tmp_ptr[6 * dst_stride], - accum_data_v6); - intrin_utils::mm256_storeu_epi32(&tmp_ptr[7 * dst_stride], - accum_data_v7); - } else { - std::int32_t* dst_block_ptr = static_cast(dst_ptr); - for (int j = 0; j < residual_cols; ++j) { - intrin_utils::mm256_n_storeu_epi32(dst_block_ptr, residual_rows, - accum_data_v[j]); - dst_block_ptr += dst_stride; - } - } - dst_ptr = static_cast(static_cast(dst_ptr) + - kAvx8bitBlockSize); - } else { - RUY_DCHECK(false); - } - - lhs_col_ptr += kAvx8bitBlockSize * params.lhs_stride; - } // End row-block loop. - - dst_col_ptr = static_cast(static_cast(dst_col_ptr) + - kAvx8bitBlockSize * params.dst_stride); - rhs_col_ptr += kAvx8bitBlockSize * params.rhs_stride; - } // End col-block loop. 
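For reference, every column block above runs the same requantization sequence: shift left by max(exponent, 0), multiply by the 32-bit fixed-point multiplier in 64-bit lanes, add a rounding bias, shift right by 31 + max(-exponent, 0), repack the lanes, and fold in the destination zero point before clamping. The scalar sketch below is not part of ruy; the function and parameter names are illustrative, and it assumes that the sign-bias added before the logical 64-bit shifts is only there to emulate an arithmetic shift.

// Scalar sketch of the per-lane requantization that the AVX2 code above
// vectorizes. Illustrative only; not ruy API.
#include <algorithm>
#include <cstdint>

std::int32_t RequantizeOneLane(std::int32_t accum, std::int32_t mult_fixedpoint,
                               int exponent, std::int32_t dst_zero_point,
                               std::int32_t clamp_min, std::int32_t clamp_max) {
  const int left_shift = exponent > 0 ? exponent : 0;
  const int right_shift = exponent > 0 ? 0 : -exponent;
  // 32-bit left shift, as _mm256_sllv_epi32 does above.
  const std::int32_t shifted_accum = static_cast<std::int32_t>(
      static_cast<std::uint32_t>(accum) << left_shift);
  // Widen and apply the fixed-point part of the multiplier (_mm256_mul_epi32).
  std::int64_t x = static_cast<std::int64_t>(shifted_accum) * mult_fixedpoint;
  // Rounding bias, corresponding to offset_vector = (1 << 30) << right_shift.
  x += static_cast<std::int64_t>(1) << (30 + right_shift);
  // Arithmetic shift; the vector code uses a logical shift plus a sign bias.
  x >>= (31 + right_shift);
  const std::int32_t result = static_cast<std::int32_t>(x) + dst_zero_point;
  return std::min(clamp_max, std::max(clamp_min, result));
}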
-} // NOLINT(readability/fn_size) - -void Kernel8bitAvx2SingleCol(const KernelParams8bit<8, 8>& params) { - profiler::ScopeLabel label("Kernel kAvx2 8-bit GEMV"); - - RUY_DCHECK_EQ(params.dst_cols, 1); - RUY_DCHECK_EQ(params.last_col, 0); - RUY_DCHECK_EQ(params.start_col, 0); - - const std::int8_t splitter_idx_data[32] = { - 0, 1, 4, 5, 8, 9, 12, 13, // - 2, 3, 6, 7, 10, 11, 14, 15, // - 0, 1, 4, 5, 8, 9, 12, 13, // - 2, 3, 6, 7, 10, 11, 14, 15 // - }; - - int bias_ptr_block_increment = - params.flags & RUY_ASM_FLAG_HAS_BIAS ? kAvx8bitBlockSize : 0; - - const std::int8_t* rhs_col_ptr = params.rhs_base_ptr; - void* dst_col_ptr = params.dst_base_ptr; - const std::int32_t* bias_col_ptr = params.bias; - if (params.flags & RUY_ASM_FLAG_HAS_BIAS) { - bias_col_ptr += params.start_row; - } - - const std::int8_t* lhs_col_ptr = params.lhs_base_ptr; - void* dst_ptr = dst_col_ptr; - const std::int32_t* bias_ptr = bias_col_ptr; - - const std::int32_t lhs_zero_point = params.lhs_zero_point; - const bool has_rhs_sums_offsets = - (params.flags & RUY_ASM_FLAG_HAS_RHS_SUMS) && lhs_zero_point; - std::int32_t rhs_sums_offsets[8]; - if (has_rhs_sums_offsets) { - const __m256i rhs_sums_offset_v = _mm256_mullo_epi32( - _mm256_set1_epi32(lhs_zero_point), - _mm256_loadu_si256( - reinterpret_cast<__m256i const*>(¶ms.rhs_sums[0]))); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(rhs_sums_offsets), - rhs_sums_offset_v); - } - - for (int row = params.start_row; row <= params.last_row; - row += kAvx8bitBlockSize) { - const int residual_rows = - std::min(params.dst_rows - row, kAvx8bitBlockSize); - - const __m256i splitter_idx = - _mm256_loadu_si256(reinterpret_cast<__m256i const*>(splitter_idx_data)); - - __m256i accum_data_v0; - - // Initialize with bias. - __m256i initial_accum_data = - intrin_utils::mm256_n_loadu_epi32(residual_rows, bias_ptr); - bias_ptr += bias_ptr_block_increment; - - // Adjustments common across columns. - const std::int32_t rhs_zero_point = params.rhs_zero_point; - if ((params.flags & RUY_ASM_FLAG_HAS_LHS_SUMS) && rhs_zero_point) { - const __m256i lhs_sums_offset = _mm256_mullo_epi32( - _mm256_set1_epi32(rhs_zero_point), - _mm256_loadu_si256( - reinterpret_cast<__m256i const*>(¶ms.lhs_sums[row]))); - initial_accum_data = - _mm256_sub_epi32(initial_accum_data, lhs_sums_offset); - } - const std::int32_t prod_zp_depth = params.prod_zp_depth; - if (prod_zp_depth) { - initial_accum_data = _mm256_add_epi32(initial_accum_data, - _mm256_set1_epi32(prod_zp_depth)); - } - - // Adjustments differing across columns. - if (has_rhs_sums_offsets) { - accum_data_v0 = _mm256_sub_epi32(initial_accum_data, - _mm256_set1_epi32(rhs_sums_offsets[0])); - } else { - accum_data_v0 = initial_accum_data; - } - - const std::int8_t* lhs_ptr = lhs_col_ptr; - const std::int8_t* rhs_ptr = rhs_col_ptr; - for (int d = 0; d < params.depth; d += kAvx8bitInnerSize) { - const __m256i lhs_data = - _mm256_load_si256(reinterpret_cast(lhs_ptr)); - const __m128i rhs_data_8bit = _mm_loadu_si32(rhs_ptr); - - // Each "int32" is two 16-bit RHS values, sign extended from 8-bit. - // For simplicity we load 4x the data that we need and process twice the - // data that we need and store only the data we need. - std::int32_t rhs_data[2]; - const __m128i rhs_16_bit_dup = _mm_cvtepi8_epi16(rhs_data_8bit); - // Now that we have cast the RHS data, we store it so that each value - // can be separately loaded in the accumulation loop. 
- _mm_storeu_si64(reinterpret_cast<__m128i*>(rhs_data), rhs_16_bit_dup); - - // NOTE: There may be opportunities for permuting the data in the packing - // code instead of here. - const __m256i lhs_data_split = - _mm256_shuffle_epi8(lhs_data, splitter_idx); - const __m256i lhs_data_split_expand_bottom = - _mm256_cvtepi8_epi16(_mm256_extracti128_si256(lhs_data_split, 0)); - const __m256i lhs_data_split_expand_top = - _mm256_cvtepi8_epi16(_mm256_extracti128_si256(lhs_data_split, 1)); - - // Take bytes 0, 1, 4, 5, 8, 9, ... expanded to 16-bit. - const __m256i lhs_16_bit_low = _mm256_permute2x128_si256( - lhs_data_split_expand_bottom, lhs_data_split_expand_top, 0x20); - // Take bytes 2, 3, 6, 7, 10, 11, ... expanded to 16-bit. - const __m256i lhs_16_bit_high = _mm256_permute2x128_si256( - lhs_data_split_expand_bottom, lhs_data_split_expand_top, 0x31); - // Accumulate for column 0. - const std::int32_t low_rhs_value = rhs_data[0]; - const std::int32_t high_rhs_value = rhs_data[1]; - - const __m256i rhs_16_bit_dup_low = _mm256_set1_epi32(low_rhs_value); - const __m256i rhs_16_bit_dup_high = _mm256_set1_epi32(high_rhs_value); - - accum_data_v0 = _mm256_add_epi32( - accum_data_v0, _mm256_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); - accum_data_v0 = _mm256_add_epi32( - accum_data_v0, - _mm256_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); - - lhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize; - rhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize; - } - - if (params.dst_type_id != DstTypeId::kValue) { - __m256i m_vector; - __m256i e_vector; - // Does not make use of RUY_ASM_FLAG_NEEDS_LEFT_SHIFT. - if (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) { - m_vector = intrin_utils::mm256_n_loadu_epi32( - residual_rows, ¶ms.multiplier_fixedpoint[row]); - e_vector = intrin_utils::mm256_n_loadu_epi32( - residual_rows, ¶ms.multiplier_exponent[row]); - } else { - // These arrays have size LhsCols, and are pre-filled. - m_vector = _mm256_set1_epi32(params.multiplier_fixedpoint[0]); - e_vector = _mm256_set1_epi32(params.multiplier_exponent[0]); - } - - const __m256i m_64bit_low = - _mm256_cvtepi32_epi64(_mm256_extracti128_si256(m_vector, 0)); - const __m256i m_64bit_high = - _mm256_cvtepi32_epi64(_mm256_extracti128_si256(m_vector, 1)); - - const __m256i zero_vector = _mm256_setzero_si256(); - const __m256i left_shift = _mm256_max_epi32(e_vector, zero_vector); - const __m256i neg_e_vector = _mm256_sub_epi32(zero_vector, e_vector); - const __m256i right_shift = _mm256_max_epi32(neg_e_vector, zero_vector); - const __m256i final_right_shift = - _mm256_add_epi32(right_shift, _mm256_set1_epi32(31)); - const __m256i final_right_shift_low = - _mm256_cvtepi32_epi64(_mm256_extracti128_si256(final_right_shift, 0)); - const __m256i final_right_shift_high = - _mm256_cvtepi32_epi64(_mm256_extracti128_si256(final_right_shift, 1)); - // Really we want 0x100000000, but use half to avoid overflowing. - const __m256i convert_to_signed_halved = - _mm256_srlv_epi32(_mm256_set1_epi32(0x80000000), right_shift); - const __m256i convert_to_unsigned_64 = - _mm256_set1_epi64x(0x8000000000000000); - - __m256i post_scaling_offset = - _mm256_add_epi32(convert_to_signed_halved, convert_to_signed_halved); - - const __m256i offset_vector = - _mm256_slli_epi64(_mm256_set1_epi64x(1), 30); - // Really these should be shifted by neg_e_vector, but tests pass when - // using right_shift. 
- const __m256i offset_vector_low = _mm256_add_epi64( - _mm256_sllv_epi64( - offset_vector, - _mm256_cvtepi32_epi64(_mm256_extracti128_si256(right_shift, 0))), - convert_to_unsigned_64); - const __m256i offset_vector_high = _mm256_add_epi64( - _mm256_sllv_epi64( - offset_vector, - _mm256_cvtepi32_epi64(_mm256_extracti128_si256(right_shift, 1))), - convert_to_unsigned_64); - - if (params.dst_zero_point) { - const __m256i dst_zero_point = _mm256_set1_epi32(params.dst_zero_point); - // The post-scaling offset is subtracted later, so this has the effect - // of adding the zero point. - post_scaling_offset = - _mm256_sub_epi32(post_scaling_offset, dst_zero_point); - } - -#if !RUY_OPT_ENABLED(RUY_OPT_NATIVE_ROUNDING) - RUY_DCHECK(false); -#endif - const __m256i repack_perm = _mm256_setr_epi32(0, 2, 4, 6, 1, 3, 5, 7); - - // See GEMM version for details of this process. - { - __m256i shifted_accum = _mm256_sllv_epi32(accum_data_v0, left_shift); - // Apply the fixed-point part of the multiplier. - __m256i scaled_v_low = _mm256_mul_epi32( - _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 0)), - m_64bit_low); - __m256i scaled_v_high = _mm256_mul_epi32( - _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 1)), - m_64bit_high); - - scaled_v_low = _mm256_add_epi64(scaled_v_low, offset_vector_low); - scaled_v_high = _mm256_add_epi64(scaled_v_high, offset_vector_high); - - scaled_v_low = _mm256_srlv_epi64(scaled_v_low, final_right_shift_low); - scaled_v_high = - _mm256_srlv_epi64(scaled_v_high, final_right_shift_high); - - scaled_v_high = _mm256_slli_epi64(scaled_v_high, 32); - __m256i results = _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa); - results = _mm256_permutevar8x32_epi32(results, repack_perm); - - accum_data_v0 = _mm256_sub_epi32(results, post_scaling_offset); - } - } - const __m256i clamp_max_v = _mm256_set1_epi32(params.clamp_max); - const __m256i clamp_min_v = _mm256_set1_epi32(params.clamp_min); - - if (params.dst_type_id == DstTypeId::kValue) { - std::int8_t* tmp_ptr = static_cast(dst_ptr); - __m256 result = accum_data_v0; - result = _mm256_min_epi32(result, clamp_max_v); - result = _mm256_max_epi32(result, clamp_min_v); - intrin_utils::mm256_n_storeu_cvtepi32_epi8(tmp_ptr, residual_rows, - result); - dst_ptr = static_cast(static_cast(dst_ptr) + - kAvx8bitBlockSize); - } else if (params.dst_type_id == DstTypeId::kValue) { - std::uint8_t* tmp_ptr = static_cast(dst_ptr); - __m256 result = accum_data_v0; - result = _mm256_min_epi32(result, clamp_max_v); - result = _mm256_max_epi32(result, clamp_min_v); - intrin_utils::mm256_n_storeu_cvtepi32_epi8(tmp_ptr, residual_rows, - result); - dst_ptr = static_cast(static_cast(dst_ptr) + - kAvx8bitBlockSize); - } else if (params.dst_type_id == DstTypeId::kValue) { - std::int16_t* tmp_ptr = static_cast(dst_ptr); - __m256 result = accum_data_v0; - result = _mm256_min_epi32(result, clamp_max_v); - result = _mm256_max_epi32(result, clamp_min_v); - intrin_utils::mm256_n_storeu_cvtepi32_epi16(tmp_ptr, residual_rows, - result); - dst_ptr = static_cast(static_cast(dst_ptr) + - kAvx8bitBlockSize); - } else if (params.dst_type_id == DstTypeId::kValue) { - std::int32_t* dst_block_ptr = static_cast(dst_ptr); - intrin_utils::mm256_n_storeu_epi32(dst_block_ptr, residual_rows, - accum_data_v0); - dst_ptr = static_cast(static_cast(dst_ptr) + - kAvx8bitBlockSize); - } else { - RUY_DCHECK(false); - } - - lhs_col_ptr += kAvx8bitBlockSize * params.lhs_stride; - } // End row-block loop. 
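The GEMV accumulator initialization above folds all zero-point terms into the starting value, so the inner loop only accumulates raw int8 dot products. A scalar sketch of that bookkeeping follows; it is not ruy code, the function name is made up, and it assumes params.prod_zp_depth holds depth * lhs_zero_point * rhs_zero_point, as its name suggests.

// Scalar sketch of the accumulator initialization used above. Illustrative
// only; names are not ruy API.
#include <cstdint>

std::int32_t InitialAccum(std::int32_t bias, std::int32_t lhs_sum_row,
                          std::int32_t rhs_sum_col,
                          std::int32_t lhs_zero_point,
                          std::int32_t rhs_zero_point,
                          std::int32_t prod_zp_depth) {
  std::int32_t accum = bias;
  accum -= rhs_zero_point * lhs_sum_row;  // the lhs_sums_offset term
  accum -= lhs_zero_point * rhs_sum_col;  // the rhs_sums_offsets term
  accum += prod_zp_depth;                 // assumed: depth * lhs_zp * rhs_zp
  return accum;
}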
- - dst_col_ptr = static_cast(static_cast(dst_col_ptr) + - kAvx8bitBlockSize * params.dst_stride); - rhs_col_ptr += kAvx8bitBlockSize * params.rhs_stride; -} // NOLINT(readability/fn_size) - -void KernelFloatAvx2(const KernelParamsFloat<8, 8>& params) { - profiler::ScopeLabel label("Kernel kAvx2 float"); - - // As parameters are defined, we need to scale by sizeof(float). - const std::int64_t lhs_stride = params.lhs_stride >> 2; - const std::int64_t dst_stride = params.dst_stride >> 2; - const std::int64_t rhs_stride = params.rhs_stride >> 2; - // - int bias_ptr_block_increment = params.flags & RUY_ASM_FLAG_HAS_BIAS ? 1 : 0; - // AVX2 float block size = 8. - const int end_row = std::min(params.dst_rows, params.last_row + 8); - const int end_col = std::min(params.dst_cols, params.last_col + 8); - // - const float* adj_rhs_col_ptr = - params.rhs_base_ptr - params.start_col * rhs_stride; - float* adj_dst_col_ptr = - params.dst_base_ptr - params.start_col * dst_stride - params.start_row; - const float* adj_lhs_col_ptr = - params.lhs_base_ptr - params.start_row * lhs_stride; - const float* bias_col_ptr = params.bias; - - const __m256 clamp_max_v = _mm256_set1_ps(params.clamp_max); - const __m256 clamp_min_v = _mm256_set1_ps(params.clamp_min); - - int col = params.start_col; - // Loop through cols by float block size, leaving incomplete remainder - for (; col <= end_col - 8; col += 8) { - __m256 accum_data_v[8]; - - const float* rhs_col_ptr = adj_rhs_col_ptr + col * rhs_stride; - float* dst_col_ptr = adj_dst_col_ptr + col * dst_stride; - - for (int row = params.start_row; row < end_row; row += 8) { - const int residual_rows = std::min(end_row - row, 8); - - const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride; - float* dst_ptr = dst_col_ptr + row; - const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment; - - // Initialize with bias. - const __m256 initial_accum_data = - intrin_utils::mm256_n_loadu_ps(residual_rows, bias_ptr); - - for (int j = 0; j < 8; ++j) { - accum_data_v[j] = initial_accum_data; - } - - const float* lhs_ptr = lhs_col_ptr; - const float* rhs_ptr = rhs_col_ptr; - for (int d = 0; d < params.depth; ++d) { - const __m256 lhs_data = _mm256_loadu_ps(lhs_ptr); - // In this version RHS values are loaded individually rather than first - // loading together and then extract with broadcasting. This is because - // AVX flavours and instrinsics and compilers in combination do not - // handle this pattern of extraction very well. - const float* rhs_data = rhs_ptr; - - for (int j = 0; j < 8; ++j) { - const __m256 dup_rhs_element_j = _mm256_set1_ps(rhs_data[j]); - accum_data_v[j] = - _mm256_fmadd_ps(lhs_data, dup_rhs_element_j, accum_data_v[j]); - } - lhs_ptr += 8; - rhs_ptr += 8; - } - - if (residual_rows == 8) { - for (int j = 0; j < 8; ++j) { - float* block_ptr = dst_ptr + j * dst_stride; - accum_data_v[j] = _mm256_min_ps(accum_data_v[j], clamp_max_v); - accum_data_v[j] = _mm256_max_ps(accum_data_v[j], clamp_min_v); - _mm256_storeu_ps(block_ptr, accum_data_v[j]); - } - } else { - for (int j = 0; j < 8; ++j) { - float* block_ptr = dst_ptr + j * dst_stride; - accum_data_v[j] = _mm256_min_ps(accum_data_v[j], clamp_max_v); - accum_data_v[j] = _mm256_max_ps(accum_data_v[j], clamp_min_v); - intrin_utils::mm256_n_storeu_ps(block_ptr, residual_rows, - accum_data_v[j]); - } - } - } // End row-block loop. - } // End col-block loop. - - if (col < end_col) { - // Remaining cols in [0, float block size). 
- RUY_DCHECK_GE(end_col - col, 0); - RUY_DCHECK_LT(end_col - col, 8); - - __m256 accum_data_v[8]; - - const float* rhs_col_ptr = adj_rhs_col_ptr + col * rhs_stride; - float* dst_col_ptr = adj_dst_col_ptr + col * dst_stride; - const int residual_cols = std::min(end_col - col, 8); - - for (int row = params.start_row; row < end_row; row += 8) { - const int residual_rows = std::min(end_row - row, 8); - - const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride; - float* dst_ptr = dst_col_ptr + row; - const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment; - - // Initialize with bias. - const __m256 initial_accum_data = - intrin_utils::mm256_n_loadu_ps(residual_rows, bias_ptr); - - for (int j = 0; j < 8; ++j) { - accum_data_v[j] = initial_accum_data; - } - - const float* lhs_ptr = lhs_col_ptr; - const float* rhs_ptr = rhs_col_ptr; - for (int d = 0; d < params.depth; ++d) { - const __m256 lhs_data = _mm256_loadu_ps(lhs_ptr); - const float* rhs_data = rhs_ptr; - - for (int j = 0; j < 8; ++j) { - const __m256 dup_rhs_element_j = _mm256_set1_ps(rhs_data[j]); - accum_data_v[j] = - _mm256_fmadd_ps(lhs_data, dup_rhs_element_j, accum_data_v[j]); - } - lhs_ptr += 8; - rhs_ptr += 8; - } - - for (int j = 0; j < residual_cols; ++j) { - float* block_ptr = dst_ptr + j * dst_stride; - accum_data_v[j] = _mm256_min_ps(accum_data_v[j], clamp_max_v); - accum_data_v[j] = _mm256_max_ps(accum_data_v[j], clamp_min_v); - intrin_utils::mm256_n_storeu_ps(block_ptr, residual_rows, - accum_data_v[j]); - } - } // End row-block loop. - } // End col-block terminal conditional. -} - -void KernelFloatAvx2SingleCol(const KernelParamsFloat<8, 8>& params) { - profiler::ScopeLabel label("Kernel kAvx2 float GEMV"); - - RUY_DCHECK_EQ(params.dst_cols, 1); - RUY_DCHECK_EQ(params.last_col, 0); - RUY_DCHECK_EQ(params.start_col, 0); - - // As parameters are defined, we need to scale by sizeof(float). - const std::int64_t lhs_stride = params.lhs_stride >> 2; - // - int bias_ptr_block_increment = params.flags & RUY_ASM_FLAG_HAS_BIAS ? 1 : 0; - // AVX2 float block size = 8. - const int end_row = std::min(params.dst_rows, params.last_row + 8); - - float* adj_dst_col_ptr = params.dst_base_ptr - params.start_row; - const float* adj_lhs_col_ptr = - params.lhs_base_ptr - params.start_row * lhs_stride; - const float* bias_col_ptr = params.bias; - - const __m256 clamp_max_v = _mm256_set1_ps(params.clamp_max); - const __m256 clamp_min_v = _mm256_set1_ps(params.clamp_min); - - __m256 accum_data_v; - - const float* rhs_col_ptr = params.rhs_base_ptr; - float* dst_col_ptr = adj_dst_col_ptr; - - int row = params.start_row; - for (; row <= end_row - 8; row += 8) { - const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride; - float* dst_ptr = dst_col_ptr + row; - const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment; - - // Initialize with bias. 
- accum_data_v = _mm256_loadu_ps(bias_ptr); - - const float* lhs_ptr = lhs_col_ptr; - const float* rhs_ptr = rhs_col_ptr; - int d = 0; - for (; d <= params.depth - 4; d += 4) { - const __m256 lhs_data_0 = _mm256_loadu_ps(lhs_ptr); - const __m256 dup_rhs_element_0 = _mm256_set1_ps(rhs_ptr[0]); - accum_data_v = - _mm256_fmadd_ps(lhs_data_0, dup_rhs_element_0, accum_data_v); - const __m256 dup_rhs_element_1 = _mm256_set1_ps(rhs_ptr[8]); - const __m256 lhs_data_1 = _mm256_loadu_ps(lhs_ptr + 8); - accum_data_v = - _mm256_fmadd_ps(lhs_data_1, dup_rhs_element_1, accum_data_v); - - const __m256 lhs_data_2 = _mm256_loadu_ps(lhs_ptr + 16); - const __m256 dup_rhs_element_2 = _mm256_set1_ps(rhs_ptr[16]); - accum_data_v = - _mm256_fmadd_ps(lhs_data_2, dup_rhs_element_2, accum_data_v); - const __m256 dup_rhs_element_3 = _mm256_set1_ps(rhs_ptr[24]); - const __m256 lhs_data_3 = _mm256_loadu_ps(lhs_ptr + 24); - accum_data_v = - _mm256_fmadd_ps(lhs_data_3, dup_rhs_element_3, accum_data_v); - lhs_ptr += 32; // Loaded 8 * 4 floats. - rhs_ptr += 32; - } - for (; d < params.depth; ++d) { - const __m256 lhs_data = _mm256_loadu_ps(lhs_ptr); - const float* rhs_data = rhs_ptr; - - const __m256 dup_rhs_element_j = _mm256_set1_ps(rhs_data[0]); - accum_data_v = _mm256_fmadd_ps(lhs_data, dup_rhs_element_j, accum_data_v); - lhs_ptr += 8; - rhs_ptr += 8; - } - - accum_data_v = _mm256_min_ps(accum_data_v, clamp_max_v); - accum_data_v = _mm256_max_ps(accum_data_v, clamp_min_v); - _mm256_storeu_ps(dst_ptr, accum_data_v); - } // End row-block loop. - - if (row < end_row) { - const int residual_rows = end_row - row; - RUY_CHECK_GE(residual_rows, 1); - RUY_CHECK_LT(residual_rows, 8); - - const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride; - float* dst_ptr = dst_col_ptr + row; - const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment; - - // Initialize with bias. - accum_data_v = intrin_utils::mm256_n_loadu_ps(residual_rows, bias_ptr); - - const float* lhs_ptr = lhs_col_ptr; - const float* rhs_ptr = rhs_col_ptr; - for (int d = 0; d < params.depth; ++d) { - const __m256 lhs_data = _mm256_loadu_ps(lhs_ptr); - const float* rhs_data = rhs_ptr; - - const __m256 dup_rhs_element_j = _mm256_set1_ps(rhs_data[0]); - accum_data_v = _mm256_fmadd_ps(lhs_data, dup_rhs_element_j, accum_data_v); - lhs_ptr += 8; - rhs_ptr += 8; - } - - accum_data_v = _mm256_min_ps(accum_data_v, clamp_max_v); - accum_data_v = _mm256_max_ps(accum_data_v, clamp_min_v); - intrin_utils::mm256_n_storeu_ps(dst_ptr, residual_rows, accum_data_v); - } // End handling of residual rows. -} - -#endif // RUY_PLATFORM(AVX2) && RUY_OPT_ENABLED(RUY_OPT_ASM) - -} // namespace ruy diff --git a/tensorflow/lite/experimental/ruy/ruy/kernel_avx512.cc b/tensorflow/lite/experimental/ruy/ruy/kernel_avx512.cc deleted file mode 100644 index e51876fcc02..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/kernel_avx512.cc +++ /dev/null @@ -1,1820 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/
-
-#include <algorithm>
-#include <cstdint>
-
-#include "tensorflow/lite/experimental/ruy/ruy/check_macros.h"
-#include "tensorflow/lite/experimental/ruy/ruy/kernel.h"
-#include "tensorflow/lite/experimental/ruy/ruy/opt_set.h"
-#include "tensorflow/lite/experimental/ruy/ruy/platform.h"
-#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h"
-
-#if RUY_PLATFORM(AVX512) && RUY_OPT_ENABLED(RUY_OPT_ASM)
-#include <immintrin.h>  // IWYU pragma: keep
-#endif
-
-namespace ruy {
-
-#if !(RUY_PLATFORM(AVX512) && RUY_OPT_ENABLED(RUY_OPT_ASM))
-
-void Kernel8bitAvx512(const KernelParams8bit<16, 16>& params) {
-  // CPU-ID-based checks should disable the path that would reach this point.
-  RUY_DCHECK(false);
-}
-
-void Kernel8bitAvx512SingleCol(const KernelParams8bit<16, 16>& params) {
-  // CPU-ID-based checks should disable the path that would reach this point.
-  RUY_DCHECK(false);
-}
-
-void KernelFloatAvx512(const KernelParamsFloat<16, 16>& params) {
-  // CPU-ID-based checks should disable the path that would reach this point.
-  RUY_DCHECK(false);
-}
-
-void KernelFloatAvx512SingleCol(const KernelParamsFloat<16, 16>& params) {
-  // CPU-ID-based checks should disable the path that would reach this point.
-  RUY_DCHECK(false);
-}
-
-#else  // RUY_PLATFORM(AVX512) && RUY_OPT_ENABLED(RUY_OPT_ASM)
-
-void Kernel8bitAvx512(const KernelParams8bit<16, 16>& params) {
-  profiler::ScopeLabel label("Kernel kAvx512 8-bit");
-
-  std::int32_t dst_stride;
-  if ((params.dst_type_id == DstTypeId<std::int8_t>::kValue) ||
-      (params.dst_type_id == DstTypeId<std::uint8_t>::kValue)) {
-    dst_stride = params.dst_stride;
-  } else if (params.dst_type_id == DstTypeId<std::int16_t>::kValue) {
-    dst_stride = params.dst_stride / sizeof(std::int16_t);
-  } else if (params.dst_type_id == DstTypeId<std::int32_t>::kValue) {
-    dst_stride = params.dst_stride / sizeof(std::int32_t);
-  } else {
-    RUY_DCHECK(false);
-  }
-
-  int bias_ptr_block_increment = params.flags & RUY_ASM_FLAG_HAS_BIAS ? 16 : 0;
-
-  const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
-  void* dst_col_ptr = params.dst_base_ptr;
-  const std::int32_t* bias_col_ptr = params.bias;
-  if (params.flags & RUY_ASM_FLAG_HAS_BIAS) {
-    bias_col_ptr += params.start_row;
-  }
-
-  for (int col = params.start_col; col <= params.last_col; col += 16) {
-    const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
-    void* dst_ptr = dst_col_ptr;
-    const std::int32_t* bias_ptr = bias_col_ptr;
-
-    const std::int32_t lhs_zero_point = params.lhs_zero_point;
-    const bool has_rhs_sums_offsets =
-        (params.flags & RUY_ASM_FLAG_HAS_RHS_SUMS) && lhs_zero_point;
-    std::int32_t rhs_sums_offsets[16];
-    if (has_rhs_sums_offsets) {
-      const __m512i rhs_sums_offset_v =
-          _mm512_mullo_epi32(_mm512_set1_epi32(lhs_zero_point),
-                             _mm512_loadu_epi32(&params.rhs_sums[col]));
-      _mm512_storeu_si512(reinterpret_cast<__m512i*>(rhs_sums_offsets),
-                          rhs_sums_offset_v);
-    }
-
-    for (int row = params.start_row; row <= params.last_row; row += 16) {
-      const int residual_rows = std::min(params.dst_rows - row, 16);
-      const int residual_cols = std::min(params.dst_cols - col, 16);
-
-      __m512i accum_data_v0;
-      __m512i accum_data_v1;
-      __m512i accum_data_v2;
-      __m512i accum_data_v3;
-      __m512i accum_data_v4;
-      __m512i accum_data_v5;
-      __m512i accum_data_v6;
-      __m512i accum_data_v7;
-      __m512i accum_data_v8;
-      __m512i accum_data_v9;
-      __m512i accum_data_va;
-      __m512i accum_data_vb;
-      __m512i accum_data_vc;
-      __m512i accum_data_vd;
-      __m512i accum_data_ve;
-      __m512i accum_data_vf;
-
-      // Initialize with bias.
- const __mmask16 row_mask = - (static_cast(1) << residual_rows) - 1; - __m512i initial_accum_data = _mm512_maskz_loadu_epi32(row_mask, bias_ptr); - bias_ptr += bias_ptr_block_increment; - - const std::int32_t rhs_zero_point = params.rhs_zero_point; - if ((params.flags & RUY_ASM_FLAG_HAS_LHS_SUMS) && rhs_zero_point) { - const __m512i lhs_sums_offset = - _mm512_mullo_epi32(_mm512_set1_epi32(rhs_zero_point), - _mm512_loadu_epi32(¶ms.lhs_sums[row])); - initial_accum_data = - _mm512_sub_epi32(initial_accum_data, lhs_sums_offset); - } - - const std::int32_t prod_zp_depth = params.prod_zp_depth; - if (prod_zp_depth != 0) { - initial_accum_data = _mm512_add_epi32(initial_accum_data, - _mm512_set1_epi32(prod_zp_depth)); - } - - // Adjustments differing across columns. - if (has_rhs_sums_offsets) { - accum_data_v0 = _mm512_sub_epi32( - initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[0])); - accum_data_v1 = _mm512_sub_epi32( - initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[1])); - accum_data_v2 = _mm512_sub_epi32( - initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[2])); - accum_data_v3 = _mm512_sub_epi32( - initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[3])); - accum_data_v4 = _mm512_sub_epi32( - initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[4])); - accum_data_v5 = _mm512_sub_epi32( - initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[5])); - accum_data_v6 = _mm512_sub_epi32( - initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[6])); - accum_data_v7 = _mm512_sub_epi32( - initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[7])); - accum_data_v8 = _mm512_sub_epi32( - initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[8])); - accum_data_v9 = _mm512_sub_epi32( - initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[9])); - accum_data_va = _mm512_sub_epi32( - initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[10])); - accum_data_vb = _mm512_sub_epi32( - initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[11])); - accum_data_vc = _mm512_sub_epi32( - initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[12])); - accum_data_vd = _mm512_sub_epi32( - initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[13])); - accum_data_ve = _mm512_sub_epi32( - initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[14])); - accum_data_vf = _mm512_sub_epi32( - initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[15])); - } else { - accum_data_v0 = initial_accum_data; - accum_data_v1 = initial_accum_data; - accum_data_v2 = initial_accum_data; - accum_data_v3 = initial_accum_data; - accum_data_v4 = initial_accum_data; - accum_data_v5 = initial_accum_data; - accum_data_v6 = initial_accum_data; - accum_data_v7 = initial_accum_data; - accum_data_v8 = initial_accum_data; - accum_data_v9 = initial_accum_data; - accum_data_va = initial_accum_data; - accum_data_vb = initial_accum_data; - accum_data_vc = initial_accum_data; - accum_data_vd = initial_accum_data; - accum_data_ve = initial_accum_data; - accum_data_vf = initial_accum_data; - } - - const std::int8_t* lhs_ptr = lhs_col_ptr; - const std::int8_t* rhs_ptr = rhs_col_ptr; - for (int d = 0; d < params.depth; d += 4) { - const __m512i lhs_data = _mm512_loadu_epi8(lhs_ptr); - __m512i rhs_data_8bit = _mm512_loadu_epi8(rhs_ptr); - - // Each "int32" is two 16-bit RHS values, sign extended from 8-bit. 
- std::int32_t rhs_data[32]; - const __m256i rhs_data_bottom_lane = - _mm512_castsi512_si256(rhs_data_8bit); - const __m256i rhs_data_top_lane = - _mm512_extracti32x8_epi32(rhs_data_8bit, 1); - const __m512i rhs_16_bit_dup_low = - _mm512_cvtepi8_epi16(rhs_data_bottom_lane); - const __m512i rhs_16_bit_dup_high = - _mm512_cvtepi8_epi16(rhs_data_top_lane); - // Now that we have cast the RHS data, we store it so that each value - // can be separately loaded in the accumulation loop. - _mm512_storeu_si512(reinterpret_cast<__m256i*>(rhs_data), - rhs_16_bit_dup_low); - _mm512_storeu_si512(reinterpret_cast<__m256i*>(rhs_data + 16), - rhs_16_bit_dup_high); - - // Take bytes 0, 1, 4, 5, 8, 9, ... and expand to 16-bit. - const __m512i lhs_16_bit_low = - _mm512_cvtepi8_epi16(_mm512_cvtepi32_epi16(lhs_data)); - // Take bytes 2, 3, 6, 7, 10, 11, ... and expand to 16-bit. - const __m512i lhs_16_bit_high = _mm512_cvtepi8_epi16( - _mm512_cvtepi32_epi16(_mm512_srli_epi32(lhs_data, 16))); - - // Process column 0. - { - __m512i accum_v = accum_data_v0; - constexpr int index = 0; - - const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]); - const __m512i rhs_16_bit_dup_high = - _mm512_set1_epi32(rhs_data[index + 1]); - - accum_v = _mm512_add_epi32( - accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); - accum_v = _mm512_add_epi32( - accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); - accum_data_v0 = accum_v; - } - // Process column 1. - { - __m512i accum_v = accum_data_v1; - constexpr int index = 2; - - const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]); - const __m512i rhs_16_bit_dup_high = - _mm512_set1_epi32(rhs_data[index + 1]); - - accum_v = _mm512_add_epi32( - accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); - accum_v = _mm512_add_epi32( - accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); - accum_data_v1 = accum_v; - } - // Process column 2. - { - __m512i accum_v = accum_data_v2; - constexpr int index = 4; - - const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]); - const __m512i rhs_16_bit_dup_high = - _mm512_set1_epi32(rhs_data[index + 1]); - - accum_v = _mm512_add_epi32( - accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); - accum_v = _mm512_add_epi32( - accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); - accum_data_v2 = accum_v; - } - // Process column 3. - { - __m512i accum_v = accum_data_v3; - constexpr int index = 6; - - const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]); - const __m512i rhs_16_bit_dup_high = - _mm512_set1_epi32(rhs_data[index + 1]); - - accum_v = _mm512_add_epi32( - accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); - accum_v = _mm512_add_epi32( - accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); - accum_data_v3 = accum_v; - } - // Process column 4. - { - __m512i accum_v = accum_data_v4; - constexpr int index = 8; - - const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]); - const __m512i rhs_16_bit_dup_high = - _mm512_set1_epi32(rhs_data[index + 1]); - - accum_v = _mm512_add_epi32( - accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); - accum_v = _mm512_add_epi32( - accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); - accum_data_v4 = accum_v; - } - // Process column 5. 
- { - __m512i accum_v = accum_data_v5; - constexpr int index = 10; - - const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]); - const __m512i rhs_16_bit_dup_high = - _mm512_set1_epi32(rhs_data[index + 1]); - - accum_v = _mm512_add_epi32( - accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); - accum_v = _mm512_add_epi32( - accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); - accum_data_v5 = accum_v; - } - // Process column 6. - { - __m512i accum_v = accum_data_v6; - constexpr int index = 12; - - const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]); - const __m512i rhs_16_bit_dup_high = - _mm512_set1_epi32(rhs_data[index + 1]); - - accum_v = _mm512_add_epi32( - accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); - accum_v = _mm512_add_epi32( - accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); - accum_data_v6 = accum_v; - } - // Process column 7. - { - __m512i accum_v = accum_data_v7; - constexpr int index = 14; - - const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]); - const __m512i rhs_16_bit_dup_high = - _mm512_set1_epi32(rhs_data[index + 1]); - - accum_v = _mm512_add_epi32( - accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); - accum_v = _mm512_add_epi32( - accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); - accum_data_v7 = accum_v; - } - // Process column 8. - { - __m512i accum_v = accum_data_v8; - constexpr int index = 16; - - const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]); - const __m512i rhs_16_bit_dup_high = - _mm512_set1_epi32(rhs_data[index + 1]); - - accum_v = _mm512_add_epi32( - accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); - accum_v = _mm512_add_epi32( - accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); - accum_data_v8 = accum_v; - } - // Process column 9. - { - __m512i accum_v = accum_data_v9; - constexpr int index = 18; - - const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]); - const __m512i rhs_16_bit_dup_high = - _mm512_set1_epi32(rhs_data[index + 1]); - - accum_v = _mm512_add_epi32( - accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); - accum_v = _mm512_add_epi32( - accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); - accum_data_v9 = accum_v; - } - // Process column 10. - { - __m512i accum_v = accum_data_va; - constexpr int index = 20; - - const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]); - const __m512i rhs_16_bit_dup_high = - _mm512_set1_epi32(rhs_data[index + 1]); - - accum_v = _mm512_add_epi32( - accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); - accum_v = _mm512_add_epi32( - accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); - accum_data_va = accum_v; - } - // Process column 11. - { - __m512i accum_v = accum_data_vb; - constexpr int index = 22; - - const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]); - const __m512i rhs_16_bit_dup_high = - _mm512_set1_epi32(rhs_data[index + 1]); - - accum_v = _mm512_add_epi32( - accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); - accum_v = _mm512_add_epi32( - accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); - accum_data_vb = accum_v; - } - // Process column 12. 
- { - __m512i accum_v = accum_data_vc; - constexpr int index = 24; - - const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]); - const __m512i rhs_16_bit_dup_high = - _mm512_set1_epi32(rhs_data[index + 1]); - - accum_v = _mm512_add_epi32( - accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); - accum_v = _mm512_add_epi32( - accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); - accum_data_vc = accum_v; - } - // Process column 13. - { - __m512i accum_v = accum_data_vd; - constexpr int index = 26; - - const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]); - const __m512i rhs_16_bit_dup_high = - _mm512_set1_epi32(rhs_data[index + 1]); - - accum_v = _mm512_add_epi32( - accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); - accum_v = _mm512_add_epi32( - accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); - accum_data_vd = accum_v; - } - // Process column 14. - { - __m512i accum_v = accum_data_ve; - constexpr int index = 28; - - const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]); - const __m512i rhs_16_bit_dup_high = - _mm512_set1_epi32(rhs_data[index + 1]); - - accum_v = _mm512_add_epi32( - accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); - accum_v = _mm512_add_epi32( - accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); - accum_data_ve = accum_v; - } - // Process column 15. - { - __m512i accum_v = accum_data_vf; - constexpr int index = 30; - - const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]); - const __m512i rhs_16_bit_dup_high = - _mm512_set1_epi32(rhs_data[index + 1]); - - accum_v = _mm512_add_epi32( - accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); - accum_v = _mm512_add_epi32( - accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); - accum_data_vf = accum_v; - } - - lhs_ptr += 16 * 4; - rhs_ptr += 16 * 4; - } - - if (params.dst_type_id != DstTypeId::kValue) { - __m512i m_vector; - __m512i e_vector; - // Does not make use of RUY_ASM_FLAG_NEEDS_LEFT_SHIFT. - if (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) { - m_vector = _mm512_maskz_loadu_epi32( - row_mask, ¶ms.multiplier_fixedpoint[row]); - e_vector = _mm512_maskz_loadu_epi32(row_mask, - ¶ms.multiplier_exponent[row]); - } else { - // These arrays have size LhsCols, and are pre-filled. - m_vector = _mm512_set1_epi32(params.multiplier_fixedpoint[0]); - e_vector = _mm512_set1_epi32(params.multiplier_exponent[0]); - } - - const __m512i m_64bit_low = - _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(m_vector, 0)); - const __m512i m_64bit_high = - _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(m_vector, 1)); - - const __m512i zero_vector = _mm512_setzero_epi32(); - const __m512i left_shift = _mm512_max_epi32(e_vector, zero_vector); - const __m512i neg_e_vector = _mm512_sub_epi32(zero_vector, e_vector); - const __m512i right_shift = _mm512_max_epi32(neg_e_vector, zero_vector); - const __m512i final_right_shift = - _mm512_add_epi32(right_shift, _mm512_set1_epi32(31)); - const __m512i final_right_shift_low = _mm512_cvtepi32_epi64( - _mm512_extracti32x8_epi32(final_right_shift, 0)); - const __m512i final_right_shift_high = _mm512_cvtepi32_epi64( - _mm512_extracti32x8_epi32(final_right_shift, 1)); - - const __m512i offset_vector = - _mm512_slli_epi64(_mm512_set1_epi64(1), 30); - // Really these should be shifted by neg_e_vector, but tests pass when - // using right_shift. 
- const __m512i offset_vector_low = _mm512_sllv_epi64( - offset_vector, - _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(right_shift, 0))); - const __m512i offset_vector_high = _mm512_sllv_epi64( - offset_vector, - _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(right_shift, 1))); - - // Shift and round column 0. - { - accum_data_v0 = _mm512_sllv_epi32(accum_data_v0, left_shift); - // Apply the fixed-point part of the multiplier. - __m512i scaled_v_low = - _mm512_mul_epi32(_mm512_cvtepi32_epi64( - _mm512_extracti32x8_epi32(accum_data_v0, 0)), - m_64bit_low); - __m512i scaled_v_high = - _mm512_mul_epi32(_mm512_cvtepi32_epi64( - _mm512_extracti32x8_epi32(accum_data_v0, 1)), - m_64bit_high); - - scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low); - scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high); - - scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low); - scaled_v_high = - _mm512_srav_epi64(scaled_v_high, final_right_shift_high); - - accum_data_v0 = - _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low)); - accum_data_v0 = _mm512_inserti32x8( - accum_data_v0, _mm512_cvtepi64_epi32(scaled_v_high), 1); - } - // Shift and round column 1. - { - accum_data_v1 = _mm512_sllv_epi32(accum_data_v1, left_shift); - // Apply the fixed-point part of the multiplier. - __m512i scaled_v_low = - _mm512_mul_epi32(_mm512_cvtepi32_epi64( - _mm512_extracti32x8_epi32(accum_data_v1, 0)), - m_64bit_low); - __m512i scaled_v_high = - _mm512_mul_epi32(_mm512_cvtepi32_epi64( - _mm512_extracti32x8_epi32(accum_data_v1, 1)), - m_64bit_high); - - scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low); - scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high); - - scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low); - scaled_v_high = - _mm512_srav_epi64(scaled_v_high, final_right_shift_high); - - accum_data_v1 = - _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low)); - accum_data_v1 = _mm512_inserti32x8( - accum_data_v1, _mm512_cvtepi64_epi32(scaled_v_high), 1); - } - // Shift and round column 2. - { - accum_data_v2 = _mm512_sllv_epi32(accum_data_v2, left_shift); - // Apply the fixed-point part of the multiplier. - __m512i scaled_v_low = - _mm512_mul_epi32(_mm512_cvtepi32_epi64( - _mm512_extracti32x8_epi32(accum_data_v2, 0)), - m_64bit_low); - __m512i scaled_v_high = - _mm512_mul_epi32(_mm512_cvtepi32_epi64( - _mm512_extracti32x8_epi32(accum_data_v2, 1)), - m_64bit_high); - - scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low); - scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high); - - scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low); - scaled_v_high = - _mm512_srav_epi64(scaled_v_high, final_right_shift_high); - - accum_data_v2 = - _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low)); - accum_data_v2 = _mm512_inserti32x8( - accum_data_v2, _mm512_cvtepi64_epi32(scaled_v_high), 1); - } - // Shift and round column 3. - { - accum_data_v3 = _mm512_sllv_epi32(accum_data_v3, left_shift); - // Apply the fixed-point part of the multiplier. 
- __m512i scaled_v_low = - _mm512_mul_epi32(_mm512_cvtepi32_epi64( - _mm512_extracti32x8_epi32(accum_data_v3, 0)), - m_64bit_low); - __m512i scaled_v_high = - _mm512_mul_epi32(_mm512_cvtepi32_epi64( - _mm512_extracti32x8_epi32(accum_data_v3, 1)), - m_64bit_high); - - scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low); - scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high); - - scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low); - scaled_v_high = - _mm512_srav_epi64(scaled_v_high, final_right_shift_high); - - accum_data_v3 = - _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low)); - accum_data_v3 = _mm512_inserti32x8( - accum_data_v3, _mm512_cvtepi64_epi32(scaled_v_high), 1); - } - // Shift and round column 4. - { - accum_data_v4 = _mm512_sllv_epi32(accum_data_v4, left_shift); - // Apply the fixed-point part of the multiplier. - __m512i scaled_v_low = - _mm512_mul_epi32(_mm512_cvtepi32_epi64( - _mm512_extracti32x8_epi32(accum_data_v4, 0)), - m_64bit_low); - __m512i scaled_v_high = - _mm512_mul_epi32(_mm512_cvtepi32_epi64( - _mm512_extracti32x8_epi32(accum_data_v4, 1)), - m_64bit_high); - - scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low); - scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high); - - scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low); - scaled_v_high = - _mm512_srav_epi64(scaled_v_high, final_right_shift_high); - - accum_data_v4 = - _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low)); - accum_data_v4 = _mm512_inserti32x8( - accum_data_v4, _mm512_cvtepi64_epi32(scaled_v_high), 1); - } - // Shift and round column 5. - { - accum_data_v5 = _mm512_sllv_epi32(accum_data_v5, left_shift); - // Apply the fixed-point part of the multiplier. - __m512i scaled_v_low = - _mm512_mul_epi32(_mm512_cvtepi32_epi64( - _mm512_extracti32x8_epi32(accum_data_v5, 0)), - m_64bit_low); - __m512i scaled_v_high = - _mm512_mul_epi32(_mm512_cvtepi32_epi64( - _mm512_extracti32x8_epi32(accum_data_v5, 1)), - m_64bit_high); - - scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low); - scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high); - - scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low); - scaled_v_high = - _mm512_srav_epi64(scaled_v_high, final_right_shift_high); - - accum_data_v5 = - _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low)); - accum_data_v5 = _mm512_inserti32x8( - accum_data_v5, _mm512_cvtepi64_epi32(scaled_v_high), 1); - } - // Shift and round column 6. - { - accum_data_v6 = _mm512_sllv_epi32(accum_data_v6, left_shift); - // Apply the fixed-point part of the multiplier. - __m512i scaled_v_low = - _mm512_mul_epi32(_mm512_cvtepi32_epi64( - _mm512_extracti32x8_epi32(accum_data_v6, 0)), - m_64bit_low); - __m512i scaled_v_high = - _mm512_mul_epi32(_mm512_cvtepi32_epi64( - _mm512_extracti32x8_epi32(accum_data_v6, 1)), - m_64bit_high); - - scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low); - scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high); - - scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low); - scaled_v_high = - _mm512_srav_epi64(scaled_v_high, final_right_shift_high); - - accum_data_v6 = - _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low)); - accum_data_v6 = _mm512_inserti32x8( - accum_data_v6, _mm512_cvtepi64_epi32(scaled_v_high), 1); - } - // Shift and round column 7. 
- { - accum_data_v7 = _mm512_sllv_epi32(accum_data_v7, left_shift); - // Apply the fixed-point part of the multiplier. - __m512i scaled_v_low = - _mm512_mul_epi32(_mm512_cvtepi32_epi64( - _mm512_extracti32x8_epi32(accum_data_v7, 0)), - m_64bit_low); - __m512i scaled_v_high = - _mm512_mul_epi32(_mm512_cvtepi32_epi64( - _mm512_extracti32x8_epi32(accum_data_v7, 1)), - m_64bit_high); - - scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low); - scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high); - - scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low); - scaled_v_high = - _mm512_srav_epi64(scaled_v_high, final_right_shift_high); - - accum_data_v7 = - _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low)); - accum_data_v7 = _mm512_inserti32x8( - accum_data_v7, _mm512_cvtepi64_epi32(scaled_v_high), 1); - } - // Shift and round column 8. - { - accum_data_v8 = _mm512_sllv_epi32(accum_data_v8, left_shift); - // Apply the fixed-point part of the multiplier. - __m512i scaled_v_low = - _mm512_mul_epi32(_mm512_cvtepi32_epi64( - _mm512_extracti32x8_epi32(accum_data_v8, 0)), - m_64bit_low); - __m512i scaled_v_high = - _mm512_mul_epi32(_mm512_cvtepi32_epi64( - _mm512_extracti32x8_epi32(accum_data_v8, 1)), - m_64bit_high); - - scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low); - scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high); - - scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low); - scaled_v_high = - _mm512_srav_epi64(scaled_v_high, final_right_shift_high); - - accum_data_v8 = - _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low)); - accum_data_v8 = _mm512_inserti32x8( - accum_data_v8, _mm512_cvtepi64_epi32(scaled_v_high), 1); - } - // Shift and round column 9. - { - accum_data_v9 = _mm512_sllv_epi32(accum_data_v9, left_shift); - // Apply the fixed-point part of the multiplier. - __m512i scaled_v_low = - _mm512_mul_epi32(_mm512_cvtepi32_epi64( - _mm512_extracti32x8_epi32(accum_data_v9, 0)), - m_64bit_low); - __m512i scaled_v_high = - _mm512_mul_epi32(_mm512_cvtepi32_epi64( - _mm512_extracti32x8_epi32(accum_data_v9, 1)), - m_64bit_high); - - scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low); - scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high); - - scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low); - scaled_v_high = - _mm512_srav_epi64(scaled_v_high, final_right_shift_high); - - accum_data_v9 = - _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low)); - accum_data_v9 = _mm512_inserti32x8( - accum_data_v9, _mm512_cvtepi64_epi32(scaled_v_high), 1); - } - // Shift and round column 10. - { - accum_data_va = _mm512_sllv_epi32(accum_data_va, left_shift); - // Apply the fixed-point part of the multiplier. 
- __m512i scaled_v_low = - _mm512_mul_epi32(_mm512_cvtepi32_epi64( - _mm512_extracti32x8_epi32(accum_data_va, 0)), - m_64bit_low); - __m512i scaled_v_high = - _mm512_mul_epi32(_mm512_cvtepi32_epi64( - _mm512_extracti32x8_epi32(accum_data_va, 1)), - m_64bit_high); - - scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low); - scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high); - - scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low); - scaled_v_high = - _mm512_srav_epi64(scaled_v_high, final_right_shift_high); - - accum_data_va = - _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low)); - accum_data_va = _mm512_inserti32x8( - accum_data_va, _mm512_cvtepi64_epi32(scaled_v_high), 1); - } - // Shift and round column 11. - { - accum_data_vb = _mm512_sllv_epi32(accum_data_vb, left_shift); - // Apply the fixed-point part of the multiplier. - __m512i scaled_v_low = - _mm512_mul_epi32(_mm512_cvtepi32_epi64( - _mm512_extracti32x8_epi32(accum_data_vb, 0)), - m_64bit_low); - __m512i scaled_v_high = - _mm512_mul_epi32(_mm512_cvtepi32_epi64( - _mm512_extracti32x8_epi32(accum_data_vb, 1)), - m_64bit_high); - - scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low); - scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high); - - scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low); - scaled_v_high = - _mm512_srav_epi64(scaled_v_high, final_right_shift_high); - - accum_data_vb = - _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low)); - accum_data_vb = _mm512_inserti32x8( - accum_data_vb, _mm512_cvtepi64_epi32(scaled_v_high), 1); - } - // Shift and round column 12. - { - accum_data_vc = _mm512_sllv_epi32(accum_data_vc, left_shift); - // Apply the fixed-point part of the multiplier. - __m512i scaled_v_low = - _mm512_mul_epi32(_mm512_cvtepi32_epi64( - _mm512_extracti32x8_epi32(accum_data_vc, 0)), - m_64bit_low); - __m512i scaled_v_high = - _mm512_mul_epi32(_mm512_cvtepi32_epi64( - _mm512_extracti32x8_epi32(accum_data_vc, 1)), - m_64bit_high); - - scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low); - scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high); - - scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low); - scaled_v_high = - _mm512_srav_epi64(scaled_v_high, final_right_shift_high); - - accum_data_vc = - _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low)); - accum_data_vc = _mm512_inserti32x8( - accum_data_vc, _mm512_cvtepi64_epi32(scaled_v_high), 1); - } - // Shift and round column 13. - { - accum_data_vd = _mm512_sllv_epi32(accum_data_vd, left_shift); - // Apply the fixed-point part of the multiplier. - __m512i scaled_v_low = - _mm512_mul_epi32(_mm512_cvtepi32_epi64( - _mm512_extracti32x8_epi32(accum_data_vd, 0)), - m_64bit_low); - __m512i scaled_v_high = - _mm512_mul_epi32(_mm512_cvtepi32_epi64( - _mm512_extracti32x8_epi32(accum_data_vd, 1)), - m_64bit_high); - - scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low); - scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high); - - scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low); - scaled_v_high = - _mm512_srav_epi64(scaled_v_high, final_right_shift_high); - - accum_data_vd = - _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low)); - accum_data_vd = _mm512_inserti32x8( - accum_data_vd, _mm512_cvtepi64_epi32(scaled_v_high), 1); - } - // Shift and round column 14. 
- { - accum_data_ve = _mm512_sllv_epi32(accum_data_ve, left_shift); - // Apply the fixed-point part of the multiplier. - __m512i scaled_v_low = - _mm512_mul_epi32(_mm512_cvtepi32_epi64( - _mm512_extracti32x8_epi32(accum_data_ve, 0)), - m_64bit_low); - __m512i scaled_v_high = - _mm512_mul_epi32(_mm512_cvtepi32_epi64( - _mm512_extracti32x8_epi32(accum_data_ve, 1)), - m_64bit_high); - - scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low); - scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high); - - scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low); - scaled_v_high = - _mm512_srav_epi64(scaled_v_high, final_right_shift_high); - - accum_data_ve = - _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low)); - accum_data_ve = _mm512_inserti32x8( - accum_data_ve, _mm512_cvtepi64_epi32(scaled_v_high), 1); - } - // Shift and round column 15. - { - accum_data_vf = _mm512_sllv_epi32(accum_data_vf, left_shift); - // Apply the fixed-point part of the multiplier. - __m512i scaled_v_low = - _mm512_mul_epi32(_mm512_cvtepi32_epi64( - _mm512_extracti32x8_epi32(accum_data_vf, 0)), - m_64bit_low); - __m512i scaled_v_high = - _mm512_mul_epi32(_mm512_cvtepi32_epi64( - _mm512_extracti32x8_epi32(accum_data_vf, 1)), - m_64bit_high); - - scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low); - scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high); - - scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low); - scaled_v_high = - _mm512_srav_epi64(scaled_v_high, final_right_shift_high); - - accum_data_vf = - _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low)); - accum_data_vf = _mm512_inserti32x8( - accum_data_vf, _mm512_cvtepi64_epi32(scaled_v_high), 1); - } -#if !RUY_OPT_ENABLED(RUY_OPT_NATIVE_ROUNDING) - RUY_DCHECK(false); -#endif - - if (params.dst_zero_point != 0) { - __m512i dst_zero_point = _mm512_set1_epi32(params.dst_zero_point); - accum_data_v0 = _mm512_add_epi32(accum_data_v0, dst_zero_point); - accum_data_v1 = _mm512_add_epi32(accum_data_v1, dst_zero_point); - accum_data_v2 = _mm512_add_epi32(accum_data_v2, dst_zero_point); - accum_data_v3 = _mm512_add_epi32(accum_data_v3, dst_zero_point); - accum_data_v4 = _mm512_add_epi32(accum_data_v4, dst_zero_point); - accum_data_v5 = _mm512_add_epi32(accum_data_v5, dst_zero_point); - accum_data_v6 = _mm512_add_epi32(accum_data_v6, dst_zero_point); - accum_data_v7 = _mm512_add_epi32(accum_data_v7, dst_zero_point); - accum_data_v8 = _mm512_add_epi32(accum_data_v8, dst_zero_point); - accum_data_v9 = _mm512_add_epi32(accum_data_v9, dst_zero_point); - accum_data_va = _mm512_add_epi32(accum_data_va, dst_zero_point); - accum_data_vb = _mm512_add_epi32(accum_data_vb, dst_zero_point); - accum_data_vc = _mm512_add_epi32(accum_data_vc, dst_zero_point); - accum_data_vd = _mm512_add_epi32(accum_data_vd, dst_zero_point); - accum_data_ve = _mm512_add_epi32(accum_data_ve, dst_zero_point); - accum_data_vf = _mm512_add_epi32(accum_data_vf, dst_zero_point); - } - } - - const __m512i clamp_max_v = _mm512_set1_epi32(params.clamp_max); - const __m512i clamp_min_v = _mm512_set1_epi32(params.clamp_min); - - const bool store_full_block = - (residual_rows == 16) && (residual_cols == 16); - - __m512i accum_data_v[16]; - - // In most cases we would make this conditional on (!store_full_block) and - // unwind the clamp-and-store loop, but the benefit appears small. 
- { - accum_data_v[0] = accum_data_v0; - accum_data_v[1] = accum_data_v1; - accum_data_v[2] = accum_data_v2; - accum_data_v[3] = accum_data_v3; - accum_data_v[4] = accum_data_v4; - accum_data_v[5] = accum_data_v5; - accum_data_v[6] = accum_data_v6; - accum_data_v[7] = accum_data_v7; - accum_data_v[8] = accum_data_v8; - accum_data_v[9] = accum_data_v9; - accum_data_v[10] = accum_data_va; - accum_data_v[11] = accum_data_vb; - accum_data_v[12] = accum_data_vc; - accum_data_v[13] = accum_data_vd; - accum_data_v[14] = accum_data_ve; - accum_data_v[15] = accum_data_vf; - } - - if (params.dst_type_id == DstTypeId::kValue) { - std::int8_t* tmp_ptr = static_cast(dst_ptr); - const int block_col_offset = dst_stride; - if (store_full_block) { - for (int j = 0; j < 16; ++j) { - __m512i result = accum_data_v[j]; - result = _mm512_min_epi32(result, clamp_max_v); - result = _mm512_max_epi32(result, clamp_min_v); - _mm_storeu_epi8(tmp_ptr + j * block_col_offset, - _mm512_cvtepi32_epi8(result)); - } - } else { - for (int j = 0; j < residual_cols; ++j) { - __m512i result = accum_data_v[j]; - result = _mm512_min_epi32(result, clamp_max_v); - result = _mm512_max_epi32(result, clamp_min_v); - _mm_mask_storeu_epi8(tmp_ptr + j * block_col_offset, row_mask, - _mm512_cvtepi32_epi8(result)); - } - } - dst_ptr = static_cast(static_cast(dst_ptr) + 16); - } else if (params.dst_type_id == DstTypeId::kValue) { - std::uint8_t* tmp_ptr = static_cast(dst_ptr); - const int block_col_offset = dst_stride; - if (store_full_block) { - for (int j = 0; j < residual_cols; ++j) { - __m512i result = accum_data_v[j]; - result = _mm512_min_epi32(result, clamp_max_v); - result = _mm512_max_epi32(result, clamp_min_v); - _mm_storeu_epi8(tmp_ptr + j * block_col_offset, - _mm512_cvtepi32_epi8(result)); - } - } else { - for (int j = 0; j < residual_cols; ++j) { - __m512i result = accum_data_v[j]; - result = _mm512_min_epi32(result, clamp_max_v); - result = _mm512_max_epi32(result, clamp_min_v); - _mm_mask_storeu_epi8(tmp_ptr + j * block_col_offset, row_mask, - _mm512_cvtepi32_epi8(result)); - } - } - dst_ptr = static_cast(static_cast(dst_ptr) + 16); - } else if (params.dst_type_id == DstTypeId::kValue) { - std::int16_t* tmp_ptr = static_cast(dst_ptr); - const int block_col_offset = dst_stride; - if (store_full_block) { - for (int j = 0; j < 16; ++j) { - __m512i result = accum_data_v[j]; - result = _mm512_min_epi32(result, clamp_max_v); - result = _mm512_max_epi32(result, clamp_min_v); - _mm256_storeu_epi16(tmp_ptr + j * block_col_offset, - _mm512_cvtepi32_epi16(result)); - } - } else { - for (int j = 0; j < residual_cols; ++j) { - __m512i result = accum_data_v[j]; - result = _mm512_min_epi32(result, clamp_max_v); - result = _mm512_max_epi32(result, clamp_min_v); - _mm256_mask_storeu_epi16(tmp_ptr + j * block_col_offset, row_mask, - _mm512_cvtepi32_epi16(result)); - } - } - dst_ptr = static_cast(static_cast(dst_ptr) + 16); - } else if (params.dst_type_id == DstTypeId::kValue) { - if (store_full_block) { - std::int32_t* tmp_ptr = static_cast(dst_ptr); - for (int j = 0; j < 16; ++j) { - _mm512_storeu_epi32(tmp_ptr + j * dst_stride, accum_data_v[j]); - } - } else { - std::int32_t* tmp_ptr = static_cast(dst_ptr); - for (int j = 0; j < residual_cols; ++j) { - _mm512_mask_storeu_epi32(tmp_ptr + j * dst_stride, row_mask, - accum_data_v[j]); - } - } - dst_ptr = static_cast(static_cast(dst_ptr) + 16); - } else { - RUY_DCHECK(false); - } - - lhs_col_ptr += 16 * params.lhs_stride; - } // End row-block loop. 
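Residual rows in the AVX-512 path above are handled with lane masks rather than a scalar tail loop: a 16-bit mask with the low residual_rows bits set drives the masked narrowing stores. The standalone sketch below uses the same mask construction and store intrinsics in isolation; the function name is illustrative and the clamping step is omitted for brevity.

// Minimal sketch of the residual-row store used above: narrow 16 x int32 to
// 16 x int8 and write only the lanes selected by the mask. Requires
// AVX-512F/BW/VL; not ruy code.
#include <immintrin.h>
#include <cstdint>

void StoreResidualColumn(std::int8_t* dst, __m512i accum, int residual_rows) {
  // Low `residual_rows` bits set, e.g. residual_rows == 3 gives 0b0000000000000111.
  const __mmask16 row_mask =
      (static_cast<std::uint32_t>(1) << residual_rows) - 1;
  _mm_mask_storeu_epi8(dst, row_mask, _mm512_cvtepi32_epi8(accum));
}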
- - dst_col_ptr = static_cast(static_cast(dst_col_ptr) + - 16 * params.dst_stride); - rhs_col_ptr += 16 * params.rhs_stride; - } // End col-block loop. -} // NOLINT(readability/fn_size) - -void Kernel8bitAvx512SingleCol(const KernelParams8bit<16, 16>& params) { - profiler::ScopeLabel label("Kernel kAvx512 8-bit GEMV"); - - RUY_DCHECK_EQ(params.dst_cols, 1); - RUY_DCHECK_EQ(params.last_col, 0); - RUY_DCHECK_EQ(params.start_col, 0); - - std::int32_t dst_stride; - if ((params.dst_type_id == DstTypeId::kValue) || - (params.dst_type_id == DstTypeId::kValue)) { - dst_stride = params.dst_stride; - } else if (params.dst_type_id == DstTypeId::kValue) { - dst_stride = params.dst_stride / sizeof(std::int16_t); - } else if (params.dst_type_id == DstTypeId::kValue) { - dst_stride = params.dst_stride / sizeof(std::int32_t); - } else { - RUY_DCHECK(false); - } - - int bias_ptr_block_increment = params.flags & RUY_ASM_FLAG_HAS_BIAS ? 16 : 0; - - const std::int8_t* rhs_col_ptr = params.rhs_base_ptr; - void* dst_col_ptr = params.dst_base_ptr; - const std::int32_t* bias_col_ptr = params.bias; - if (params.flags & RUY_ASM_FLAG_HAS_BIAS) { - bias_col_ptr += params.start_row; - } - - const std::int8_t* lhs_col_ptr = params.lhs_base_ptr; - void* dst_ptr = dst_col_ptr; - const std::int32_t* bias_ptr = bias_col_ptr; - - const std::int32_t lhs_zero_point = params.lhs_zero_point; - const bool has_rhs_sums_offsets = - (params.flags & RUY_ASM_FLAG_HAS_RHS_SUMS) && lhs_zero_point; - std::int32_t rhs_sums_offsets[16]; - if (has_rhs_sums_offsets) { - const __m512i rhs_sums_offset_v = - _mm512_mullo_epi32(_mm512_set1_epi32(lhs_zero_point), - _mm512_loadu_epi32(¶ms.rhs_sums[0])); - _mm512_storeu_si512(reinterpret_cast<__m512i*>(rhs_sums_offsets), - rhs_sums_offset_v); - } - - for (int row = params.start_row; row <= params.last_row; row += 16) { - const int residual_rows = std::min(params.dst_rows - row, 16); - - __m512i accum_data_v0; - - // Initialize with bias. - const __mmask16 row_mask = - (static_cast(1) << residual_rows) - 1; - __m512i initial_accum_data = _mm512_maskz_loadu_epi32(row_mask, bias_ptr); - bias_ptr += bias_ptr_block_increment; - - const std::int32_t rhs_zero_point = params.rhs_zero_point; - if ((params.flags & RUY_ASM_FLAG_HAS_LHS_SUMS) && rhs_zero_point) { - const __m512i lhs_sums_offset = - _mm512_mullo_epi32(_mm512_set1_epi32(rhs_zero_point), - _mm512_loadu_epi32(¶ms.lhs_sums[row])); - initial_accum_data = - _mm512_sub_epi32(initial_accum_data, lhs_sums_offset); - } - - const std::int32_t prod_zp_depth = params.prod_zp_depth; - if (prod_zp_depth != 0) { - initial_accum_data = _mm512_add_epi32(initial_accum_data, - _mm512_set1_epi32(prod_zp_depth)); - } - - // Adjustments differing across columns. - if (has_rhs_sums_offsets) { - accum_data_v0 = _mm512_sub_epi32(initial_accum_data, - _mm512_set1_epi32(rhs_sums_offsets[0])); - } else { - accum_data_v0 = initial_accum_data; - } - - const std::int8_t* lhs_ptr = lhs_col_ptr; - const std::int8_t* rhs_ptr = rhs_col_ptr; - for (int d = 0; d < params.depth; d += 4) { - const __m512i lhs_data = _mm512_loadu_epi8(lhs_ptr); - const __m128i rhs_data_8bit = _mm_loadu_epi8(rhs_ptr); - - // Each "int32" is two 16-bit RHS values, sign extended from 8-bit. - // For simplicity we load 4x the data that we need and process twice the - // data that we need and store only the data we need. 
- std::int32_t rhs_data[2]; - const __m128i rhs_16_bit_dup = _mm_cvtepi8_epi16(rhs_data_8bit); - // Now that we have cast the RHS data, we store it so that each value - // can be separately loaded in the accumulation loop. - _mm_storeu_si64(reinterpret_cast<__m128i*>(rhs_data), rhs_16_bit_dup); - - // Take bytes 0, 1, 4, 5, 8, 9, ... and expand to 16-bit. - const __m512i lhs_16_bit_low = - _mm512_cvtepi8_epi16(_mm512_cvtepi32_epi16(lhs_data)); - // Take bytes 2, 3, 6, 7, 10, 11, ... and expand to 16-bit. - const __m512i lhs_16_bit_high = _mm512_cvtepi8_epi16( - _mm512_cvtepi32_epi16(_mm512_srli_epi32(lhs_data, 16))); - - // Process column 0. - __m512i accum_v = accum_data_v0; - constexpr int index = 0; - - const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]); - const __m512i rhs_16_bit_dup_high = - _mm512_set1_epi32(rhs_data[index + 1]); - - accum_v = _mm512_add_epi32( - accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); - accum_v = _mm512_add_epi32( - accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); - accum_data_v0 = accum_v; - - lhs_ptr += 16 * 4; - rhs_ptr += 16 * 4; - } - - if (params.dst_type_id != DstTypeId::kValue) { - __m512i m_vector; - __m512i e_vector; - // Does not make use of RUY_ASM_FLAG_NEEDS_LEFT_SHIFT. - if (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) { - m_vector = _mm512_maskz_loadu_epi32(row_mask, - ¶ms.multiplier_fixedpoint[row]); - e_vector = _mm512_maskz_loadu_epi32(row_mask, - ¶ms.multiplier_exponent[row]); - } else { - // These arrays have size LhsCols, and are pre-filled. - m_vector = _mm512_set1_epi32(params.multiplier_fixedpoint[0]); - e_vector = _mm512_set1_epi32(params.multiplier_exponent[0]); - } - - const __m512i m_64bit_low = - _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(m_vector, 0)); - const __m512i m_64bit_high = - _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(m_vector, 1)); - - const __m512i zero_vector = _mm512_setzero_epi32(); - const __m512i left_shift = _mm512_max_epi32(e_vector, zero_vector); - const __m512i neg_e_vector = _mm512_sub_epi32(zero_vector, e_vector); - const __m512i right_shift = _mm512_max_epi32(neg_e_vector, zero_vector); - const __m512i final_right_shift = - _mm512_add_epi32(right_shift, _mm512_set1_epi32(31)); - const __m512i final_right_shift_low = _mm512_cvtepi32_epi64( - _mm512_extracti32x8_epi32(final_right_shift, 0)); - const __m512i final_right_shift_high = _mm512_cvtepi32_epi64( - _mm512_extracti32x8_epi32(final_right_shift, 1)); - - const __m512i offset_vector = _mm512_slli_epi64(_mm512_set1_epi64(1), 30); - // Really these should be shifted by neg_e_vector, but tests pass when - // using right_shift. - const __m512i offset_vector_low = _mm512_sllv_epi64( - offset_vector, - _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(right_shift, 0))); - const __m512i offset_vector_high = _mm512_sllv_epi64( - offset_vector, - _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(right_shift, 1))); - - // Shift and round column 0. - accum_data_v0 = _mm512_sllv_epi32(accum_data_v0, left_shift); - // Apply the fixed-point part of the multiplier. 
- __m512i scaled_v_low = _mm512_mul_epi32( - _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(accum_data_v0, 0)), - m_64bit_low); - __m512i scaled_v_high = _mm512_mul_epi32( - _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(accum_data_v0, 1)), - m_64bit_high); - - scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low); - scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high); - - scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low); - scaled_v_high = _mm512_srav_epi64(scaled_v_high, final_right_shift_high); - - accum_data_v0 = - _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low)); - accum_data_v0 = _mm512_inserti32x8( - accum_data_v0, _mm512_cvtepi64_epi32(scaled_v_high), 1); -#if !RUY_OPT_ENABLED(RUY_OPT_NATIVE_ROUNDING) - RUY_DCHECK(false); -#endif - - if (params.dst_zero_point != 0) { - __m512i dst_zero_point = _mm512_set1_epi32(params.dst_zero_point); - accum_data_v0 = _mm512_add_epi32(accum_data_v0, dst_zero_point); - } - } - - const __m512i clamp_max_v = _mm512_set1_epi32(params.clamp_max); - const __m512i clamp_min_v = _mm512_set1_epi32(params.clamp_min); - - if (params.dst_type_id == DstTypeId::kValue) { - std::int8_t* tmp_ptr = static_cast(dst_ptr); - __m512i result = accum_data_v0; - result = _mm512_min_epi32(result, clamp_max_v); - result = _mm512_max_epi32(result, clamp_min_v); - _mm_mask_storeu_epi8(tmp_ptr, row_mask, _mm512_cvtepi32_epi8(result)); - dst_ptr = static_cast(static_cast(dst_ptr) + 16); - } else if (params.dst_type_id == DstTypeId::kValue) { - std::uint8_t* tmp_ptr = static_cast(dst_ptr); - __m512i result = accum_data_v0; - result = _mm512_min_epi32(result, clamp_max_v); - result = _mm512_max_epi32(result, clamp_min_v); - _mm_mask_storeu_epi8(tmp_ptr, row_mask, _mm512_cvtepi32_epi8(result)); - dst_ptr = static_cast(static_cast(dst_ptr) + 16); - } else if (params.dst_type_id == DstTypeId::kValue) { - std::int16_t* tmp_ptr = static_cast(dst_ptr); - __m512i result = accum_data_v0; - result = _mm512_min_epi32(result, clamp_max_v); - result = _mm512_max_epi32(result, clamp_min_v); - _mm256_mask_storeu_epi16(tmp_ptr, row_mask, - _mm512_cvtepi32_epi16(result)); - dst_ptr = static_cast(static_cast(dst_ptr) + 16); - } else if (params.dst_type_id == DstTypeId::kValue) { - std::int32_t* tmp_ptr = static_cast(dst_ptr); - _mm512_mask_storeu_epi32(tmp_ptr, row_mask, accum_data_v0); - dst_ptr = static_cast(static_cast(dst_ptr) + 16); - } else { - RUY_DCHECK(false); - } - - lhs_col_ptr += 16 * params.lhs_stride; - } // End row-block loop. -} // NOLINT(readability/fn_size) - -void KernelFloatAvx512(const KernelParamsFloat<16, 16>& params) { - profiler::ScopeLabel label("Kernel kAvx512 float"); - - // As parameters are defined, we need to scale by sizeof(float). - const std::int64_t lhs_stride = params.lhs_stride >> 2; - const std::int64_t dst_stride = params.dst_stride >> 2; - const std::int64_t rhs_stride = params.rhs_stride >> 2; - - int bias_ptr_block_increment = params.flags & RUY_ASM_FLAG_HAS_BIAS ? 
1 : 0; - const int end_row = std::min(params.dst_rows, params.last_row + 16); - const int end_col = std::min(params.dst_cols, params.last_col + 16); - - const float* adj_rhs_col_ptr = - params.rhs_base_ptr - params.start_col * rhs_stride; - float* adj_dst_col_ptr = - params.dst_base_ptr - params.start_col * dst_stride - params.start_row; - const float* adj_lhs_col_ptr = - params.lhs_base_ptr - params.start_row * lhs_stride; - const float* bias_col_ptr = params.bias; - - const __m512 clamp_max_v = _mm512_set1_ps(params.clamp_max); - const __m512 clamp_min_v = _mm512_set1_ps(params.clamp_min); - - int col = params.start_col; - for (; col <= end_col - 16; col += 16) { - const float* rhs_col_ptr = adj_rhs_col_ptr + col * rhs_stride; - float* dst_col_ptr = adj_dst_col_ptr + col * dst_stride; - - int row = params.start_row; - for (; row <= end_row - 16; row += 16) { - const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride; - float* dst_ptr = dst_col_ptr + row; - const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment; - - // Initialize with bias. - const __m512 initial_accum_data = _mm512_loadu_ps(bias_ptr); - - // Process block in two halves, split by columns. - { - constexpr int mmm = 0; - - __m512 accum_data_v0 = initial_accum_data; - __m512 accum_data_v1 = initial_accum_data; - __m512 accum_data_v2 = initial_accum_data; - __m512 accum_data_v3 = initial_accum_data; - __m512 accum_data_v4 = initial_accum_data; - __m512 accum_data_v5 = initial_accum_data; - __m512 accum_data_v6 = initial_accum_data; - __m512 accum_data_v7 = initial_accum_data; - - const float* lhs_ptr = lhs_col_ptr; - const float* rhs_ptr = rhs_col_ptr + 8 * mmm; - for (int d = 0; d < (params.depth - 1); ++d) { - const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr); - // In this version RHS values are loaded individually rather than - // first loading together and then extract with broadcasting. This is - // because AVX flavours and instrinsics and compilers in combination - // do not handle this pattern of extraction very well. 
- const float* rhs_data = rhs_ptr; - lhs_ptr += 16; - rhs_ptr += 16; - - { - const __m512 dup_rhs_element_j0 = _mm512_set1_ps(rhs_data[0]); - accum_data_v0 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j0, accum_data_v0); - const __m512 dup_rhs_element_j1 = _mm512_set1_ps(rhs_data[1]); - accum_data_v1 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j1, accum_data_v1); - const __m512 dup_rhs_element_j2 = _mm512_set1_ps(rhs_data[2]); - accum_data_v2 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j2, accum_data_v2); - const __m512 dup_rhs_element_j3 = _mm512_set1_ps(rhs_data[3]); - accum_data_v3 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j3, accum_data_v3); - const __m512 dup_rhs_element_j4 = _mm512_set1_ps(rhs_data[4]); - accum_data_v4 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j4, accum_data_v4); - const __m512 dup_rhs_element_j5 = _mm512_set1_ps(rhs_data[5]); - accum_data_v5 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j5, accum_data_v5); - const __m512 dup_rhs_element_j6 = _mm512_set1_ps(rhs_data[6]); - accum_data_v6 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j6, accum_data_v6); - const __m512 dup_rhs_element_j7 = _mm512_set1_ps(rhs_data[7]); - accum_data_v7 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j7, accum_data_v7); - } - } - { - const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr); - const float* rhs_data = rhs_ptr; - { - const __m512 dup_rhs_element_j0 = _mm512_set1_ps(rhs_data[0]); - accum_data_v0 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j0, accum_data_v0); - const __m512 dup_rhs_element_j1 = _mm512_set1_ps(rhs_data[1]); - accum_data_v1 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j1, accum_data_v1); - const __m512 dup_rhs_element_j2 = _mm512_set1_ps(rhs_data[2]); - accum_data_v2 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j2, accum_data_v2); - const __m512 dup_rhs_element_j3 = _mm512_set1_ps(rhs_data[3]); - accum_data_v3 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j3, accum_data_v3); - const __m512 dup_rhs_element_j4 = _mm512_set1_ps(rhs_data[4]); - accum_data_v4 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j4, accum_data_v4); - const __m512 dup_rhs_element_j5 = _mm512_set1_ps(rhs_data[5]); - accum_data_v5 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j5, accum_data_v5); - const __m512 dup_rhs_element_j6 = _mm512_set1_ps(rhs_data[6]); - accum_data_v6 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j6, accum_data_v6); - const __m512 dup_rhs_element_j7 = _mm512_set1_ps(rhs_data[7]); - accum_data_v7 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j7, accum_data_v7); - } - { - float* block_ptr = dst_ptr + (mmm * 8 + 0) * dst_stride; - accum_data_v0 = _mm512_min_ps(accum_data_v0, clamp_max_v); - accum_data_v0 = _mm512_max_ps(accum_data_v0, clamp_min_v); - _mm512_storeu_ps(block_ptr + 0 * dst_stride, accum_data_v0); - accum_data_v1 = _mm512_min_ps(accum_data_v1, clamp_max_v); - accum_data_v1 = _mm512_max_ps(accum_data_v1, clamp_min_v); - _mm512_storeu_ps(block_ptr + 1 * dst_stride, accum_data_v1); - accum_data_v2 = _mm512_min_ps(accum_data_v2, clamp_max_v); - accum_data_v2 = _mm512_max_ps(accum_data_v2, clamp_min_v); - _mm512_storeu_ps(block_ptr + 2 * dst_stride, accum_data_v2); - accum_data_v3 = _mm512_min_ps(accum_data_v3, clamp_max_v); - accum_data_v3 = _mm512_max_ps(accum_data_v3, clamp_min_v); - _mm512_storeu_ps(block_ptr + 3 * dst_stride, accum_data_v3); - accum_data_v4 = _mm512_min_ps(accum_data_v4, clamp_max_v); - accum_data_v4 = _mm512_max_ps(accum_data_v4, clamp_min_v); - _mm512_storeu_ps(block_ptr + 4 * dst_stride, accum_data_v4); - accum_data_v5 = 
_mm512_min_ps(accum_data_v5, clamp_max_v); - accum_data_v5 = _mm512_max_ps(accum_data_v5, clamp_min_v); - _mm512_storeu_ps(block_ptr + 5 * dst_stride, accum_data_v5); - accum_data_v6 = _mm512_min_ps(accum_data_v6, clamp_max_v); - accum_data_v6 = _mm512_max_ps(accum_data_v6, clamp_min_v); - _mm512_storeu_ps(block_ptr + 6 * dst_stride, accum_data_v6); - accum_data_v7 = _mm512_min_ps(accum_data_v7, clamp_max_v); - accum_data_v7 = _mm512_max_ps(accum_data_v7, clamp_min_v); - _mm512_storeu_ps(block_ptr + 7 * dst_stride, accum_data_v7); - } - } - } // Inner half-block loop, unrolled, first iteration. - { - constexpr int mmm = 1; - - __m512 accum_data_v0 = initial_accum_data; - __m512 accum_data_v1 = initial_accum_data; - __m512 accum_data_v2 = initial_accum_data; - __m512 accum_data_v3 = initial_accum_data; - __m512 accum_data_v4 = initial_accum_data; - __m512 accum_data_v5 = initial_accum_data; - __m512 accum_data_v6 = initial_accum_data; - __m512 accum_data_v7 = initial_accum_data; - - const float* lhs_ptr = lhs_col_ptr; - const float* rhs_ptr = rhs_col_ptr + 8 * mmm; - for (int d = 0; d < (params.depth - 1); ++d) { - const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr); - const float* rhs_data = rhs_ptr; - lhs_ptr += 16; - rhs_ptr += 16; - { - const __m512 dup_rhs_element_j0 = _mm512_set1_ps(rhs_data[0]); - accum_data_v0 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j0, accum_data_v0); - const __m512 dup_rhs_element_j1 = _mm512_set1_ps(rhs_data[1]); - accum_data_v1 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j1, accum_data_v1); - const __m512 dup_rhs_element_j2 = _mm512_set1_ps(rhs_data[2]); - accum_data_v2 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j2, accum_data_v2); - const __m512 dup_rhs_element_j3 = _mm512_set1_ps(rhs_data[3]); - accum_data_v3 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j3, accum_data_v3); - const __m512 dup_rhs_element_j4 = _mm512_set1_ps(rhs_data[4]); - accum_data_v4 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j4, accum_data_v4); - const __m512 dup_rhs_element_j5 = _mm512_set1_ps(rhs_data[5]); - accum_data_v5 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j5, accum_data_v5); - const __m512 dup_rhs_element_j6 = _mm512_set1_ps(rhs_data[6]); - accum_data_v6 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j6, accum_data_v6); - const __m512 dup_rhs_element_j7 = _mm512_set1_ps(rhs_data[7]); - accum_data_v7 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j7, accum_data_v7); - } - } - { - const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr); - const float* rhs_data = rhs_ptr; - { - const __m512 dup_rhs_element_j0 = _mm512_set1_ps(rhs_data[0]); - accum_data_v0 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j0, accum_data_v0); - const __m512 dup_rhs_element_j1 = _mm512_set1_ps(rhs_data[1]); - accum_data_v1 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j1, accum_data_v1); - const __m512 dup_rhs_element_j2 = _mm512_set1_ps(rhs_data[2]); - accum_data_v2 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j2, accum_data_v2); - const __m512 dup_rhs_element_j3 = _mm512_set1_ps(rhs_data[3]); - accum_data_v3 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j3, accum_data_v3); - const __m512 dup_rhs_element_j4 = _mm512_set1_ps(rhs_data[4]); - accum_data_v4 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j4, accum_data_v4); - const __m512 dup_rhs_element_j5 = _mm512_set1_ps(rhs_data[5]); - accum_data_v5 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j5, accum_data_v5); - const __m512 dup_rhs_element_j6 = _mm512_set1_ps(rhs_data[6]); - accum_data_v6 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j6, 
accum_data_v6); - const __m512 dup_rhs_element_j7 = _mm512_set1_ps(rhs_data[7]); - accum_data_v7 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j7, accum_data_v7); - } - { - float* block_ptr = dst_ptr + (mmm * 8 + 0) * dst_stride; - accum_data_v0 = _mm512_min_ps(accum_data_v0, clamp_max_v); - accum_data_v0 = _mm512_max_ps(accum_data_v0, clamp_min_v); - _mm512_storeu_ps(block_ptr + 0 * dst_stride, accum_data_v0); - accum_data_v1 = _mm512_min_ps(accum_data_v1, clamp_max_v); - accum_data_v1 = _mm512_max_ps(accum_data_v1, clamp_min_v); - _mm512_storeu_ps(block_ptr + 1 * dst_stride, accum_data_v1); - accum_data_v2 = _mm512_min_ps(accum_data_v2, clamp_max_v); - accum_data_v2 = _mm512_max_ps(accum_data_v2, clamp_min_v); - _mm512_storeu_ps(block_ptr + 2 * dst_stride, accum_data_v2); - accum_data_v3 = _mm512_min_ps(accum_data_v3, clamp_max_v); - accum_data_v3 = _mm512_max_ps(accum_data_v3, clamp_min_v); - _mm512_storeu_ps(block_ptr + 3 * dst_stride, accum_data_v3); - accum_data_v4 = _mm512_min_ps(accum_data_v4, clamp_max_v); - accum_data_v4 = _mm512_max_ps(accum_data_v4, clamp_min_v); - _mm512_storeu_ps(block_ptr + 4 * dst_stride, accum_data_v4); - accum_data_v5 = _mm512_min_ps(accum_data_v5, clamp_max_v); - accum_data_v5 = _mm512_max_ps(accum_data_v5, clamp_min_v); - _mm512_storeu_ps(block_ptr + 5 * dst_stride, accum_data_v5); - accum_data_v6 = _mm512_min_ps(accum_data_v6, clamp_max_v); - accum_data_v6 = _mm512_max_ps(accum_data_v6, clamp_min_v); - _mm512_storeu_ps(block_ptr + 6 * dst_stride, accum_data_v6); - accum_data_v7 = _mm512_min_ps(accum_data_v7, clamp_max_v); - accum_data_v7 = _mm512_max_ps(accum_data_v7, clamp_min_v); - _mm512_storeu_ps(block_ptr + 7 * dst_stride, accum_data_v7); - } - } - } // Inner half-block loop, unrolled, second iteration. - } // End row-block loop. - - // The unrolling within this conditional may be somewhat pointless. It - // depends on the kinds of models. - if (row < end_row) { - const int residual_rows = end_row - row; - - const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride; - float* dst_ptr = dst_col_ptr + row; - const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment; - - // Initialize with bias. - const __mmask16 row_mask = - (static_cast(1) << residual_rows) - 1; - const __m512 initial_accum_data = - _mm512_maskz_loadu_ps(row_mask, bias_ptr); - - // Process block in two halves, split by columns. 
- for (int mmm = 0; mmm < 2; ++mmm) { - __m512 accum_data_v0 = initial_accum_data; - __m512 accum_data_v1 = initial_accum_data; - __m512 accum_data_v2 = initial_accum_data; - __m512 accum_data_v3 = initial_accum_data; - __m512 accum_data_v4 = initial_accum_data; - __m512 accum_data_v5 = initial_accum_data; - __m512 accum_data_v6 = initial_accum_data; - __m512 accum_data_v7 = initial_accum_data; - - const float* lhs_ptr = lhs_col_ptr; - const float* rhs_ptr = rhs_col_ptr + 8 * mmm; - for (int d = 0; d < (params.depth - 1); ++d) { - const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr); - const float* rhs_data = rhs_ptr; - lhs_ptr += 16; - rhs_ptr += 16; - { - const __m512 dup_rhs_element_j0 = _mm512_set1_ps(rhs_data[0]); - accum_data_v0 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j0, accum_data_v0); - const __m512 dup_rhs_element_j1 = _mm512_set1_ps(rhs_data[1]); - accum_data_v1 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j1, accum_data_v1); - const __m512 dup_rhs_element_j2 = _mm512_set1_ps(rhs_data[2]); - accum_data_v2 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j2, accum_data_v2); - const __m512 dup_rhs_element_j3 = _mm512_set1_ps(rhs_data[3]); - accum_data_v3 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j3, accum_data_v3); - const __m512 dup_rhs_element_j4 = _mm512_set1_ps(rhs_data[4]); - accum_data_v4 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j4, accum_data_v4); - const __m512 dup_rhs_element_j5 = _mm512_set1_ps(rhs_data[5]); - accum_data_v5 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j5, accum_data_v5); - const __m512 dup_rhs_element_j6 = _mm512_set1_ps(rhs_data[6]); - accum_data_v6 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j6, accum_data_v6); - const __m512 dup_rhs_element_j7 = _mm512_set1_ps(rhs_data[7]); - accum_data_v7 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j7, accum_data_v7); - } - } - { - const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr); - const float* rhs_data = rhs_ptr; - { - const __m512 dup_rhs_element_j0 = _mm512_set1_ps(rhs_data[0]); - accum_data_v0 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j0, accum_data_v0); - const __m512 dup_rhs_element_j1 = _mm512_set1_ps(rhs_data[1]); - accum_data_v1 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j1, accum_data_v1); - const __m512 dup_rhs_element_j2 = _mm512_set1_ps(rhs_data[2]); - accum_data_v2 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j2, accum_data_v2); - const __m512 dup_rhs_element_j3 = _mm512_set1_ps(rhs_data[3]); - accum_data_v3 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j3, accum_data_v3); - const __m512 dup_rhs_element_j4 = _mm512_set1_ps(rhs_data[4]); - accum_data_v4 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j4, accum_data_v4); - const __m512 dup_rhs_element_j5 = _mm512_set1_ps(rhs_data[5]); - accum_data_v5 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j5, accum_data_v5); - const __m512 dup_rhs_element_j6 = _mm512_set1_ps(rhs_data[6]); - accum_data_v6 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j6, accum_data_v6); - const __m512 dup_rhs_element_j7 = _mm512_set1_ps(rhs_data[7]); - accum_data_v7 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j7, accum_data_v7); - } - { - float* block_ptr = dst_ptr + (mmm * 8 + 0) * dst_stride; - accum_data_v0 = _mm512_min_ps(accum_data_v0, clamp_max_v); - accum_data_v0 = _mm512_max_ps(accum_data_v0, clamp_min_v); - _mm512_mask_storeu_ps(block_ptr + 0 * dst_stride, row_mask, - accum_data_v0); - accum_data_v1 = _mm512_min_ps(accum_data_v1, clamp_max_v); - accum_data_v1 = _mm512_max_ps(accum_data_v1, clamp_min_v); - _mm512_mask_storeu_ps(block_ptr + 1 * 
dst_stride, row_mask, - accum_data_v1); - accum_data_v2 = _mm512_min_ps(accum_data_v2, clamp_max_v); - accum_data_v2 = _mm512_max_ps(accum_data_v2, clamp_min_v); - _mm512_mask_storeu_ps(block_ptr + 2 * dst_stride, row_mask, - accum_data_v2); - accum_data_v3 = _mm512_min_ps(accum_data_v3, clamp_max_v); - accum_data_v3 = _mm512_max_ps(accum_data_v3, clamp_min_v); - _mm512_mask_storeu_ps(block_ptr + 3 * dst_stride, row_mask, - accum_data_v3); - accum_data_v4 = _mm512_min_ps(accum_data_v4, clamp_max_v); - accum_data_v4 = _mm512_max_ps(accum_data_v4, clamp_min_v); - _mm512_mask_storeu_ps(block_ptr + 4 * dst_stride, row_mask, - accum_data_v4); - accum_data_v5 = _mm512_min_ps(accum_data_v5, clamp_max_v); - accum_data_v5 = _mm512_max_ps(accum_data_v5, clamp_min_v); - _mm512_mask_storeu_ps(block_ptr + 5 * dst_stride, row_mask, - accum_data_v5); - accum_data_v6 = _mm512_min_ps(accum_data_v6, clamp_max_v); - accum_data_v6 = _mm512_max_ps(accum_data_v6, clamp_min_v); - _mm512_mask_storeu_ps(block_ptr + 6 * dst_stride, row_mask, - accum_data_v6); - accum_data_v7 = _mm512_min_ps(accum_data_v7, clamp_max_v); - accum_data_v7 = _mm512_max_ps(accum_data_v7, clamp_min_v); - _mm512_mask_storeu_ps(block_ptr + 7 * dst_stride, row_mask, - accum_data_v7); - } - } - } // Inner half-block loop. - } // Residual rows, main col-block loop. - } // End col-block loop. - - if (col < end_col) { - RUY_DCHECK_GE(end_col - col, 0); - RUY_DCHECK_LT(end_col - col, 16); - - __m512 accum_data_v[8]; - - const float* rhs_col_ptr = adj_rhs_col_ptr + col * rhs_stride; - float* dst_col_ptr = adj_dst_col_ptr + col * dst_stride; - - for (int row = params.start_row; row < end_row; row += 16) { - const int residual_rows = std::min(end_row - row, 16); - - const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride; - float* dst_ptr = dst_col_ptr + row; - const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment; - - // Initialize with bias. - const __mmask16 row_mask = - (static_cast(1) << residual_rows) - 1; - const __m512 initial_accum_data = - _mm512_maskz_loadu_ps(row_mask, bias_ptr); - - // Process block in two halves, split by columns. 
- for (int mmm = 0; mmm < 2; ++mmm) { - for (int j = 0; j < 8; ++j) { - accum_data_v[j] = initial_accum_data; - } - - const float* lhs_ptr = lhs_col_ptr; - const float* rhs_ptr = rhs_col_ptr + 8 * mmm; - for (int d = 0; d < params.depth; ++d) { - const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr); - const float* rhs_data = rhs_ptr; - - for (int j = 0; j < 8; ++j) { - const __m512 dup_rhs_element_j = _mm512_set1_ps(rhs_data[j]); - accum_data_v[j] = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j, accum_data_v[j]); - } - lhs_ptr += 16; - rhs_ptr += 16; - } - - const int residual_cols = std::min(end_col - col - 8 * mmm, 8); - - if (residual_rows == 16) { - if (residual_cols == 8) { - for (int j = 0; j < 8; ++j) { - float* block_ptr = dst_ptr + (mmm * 8 + j) * dst_stride; - accum_data_v[j] = _mm512_min_ps(accum_data_v[j], clamp_max_v); - accum_data_v[j] = _mm512_max_ps(accum_data_v[j], clamp_min_v); - _mm512_storeu_ps(block_ptr, accum_data_v[j]); - } - } else { - for (int j = 0; j < residual_cols; ++j) { - float* block_ptr = dst_ptr + (mmm * 8 + j) * dst_stride; - accum_data_v[j] = _mm512_min_ps(accum_data_v[j], clamp_max_v); - accum_data_v[j] = _mm512_max_ps(accum_data_v[j], clamp_min_v); - _mm512_storeu_ps(block_ptr, accum_data_v[j]); - } - } - } else { - for (int j = 0; j < residual_cols; ++j) { - float* block_ptr = dst_ptr + (mmm * 8 + j) * dst_stride; - accum_data_v[j] = _mm512_min_ps(accum_data_v[j], clamp_max_v); - accum_data_v[j] = _mm512_max_ps(accum_data_v[j], clamp_min_v); - _mm512_mask_storeu_ps(block_ptr, row_mask, accum_data_v[j]); - } - } - } // Inner half-block loop. - } // End row-block loop. - } // Residual cols. -} - -void KernelFloatAvx512SingleCol(const KernelParamsFloat<16, 16>& params) { - profiler::ScopeLabel label("Kernel kAvx512 float GEMV"); - - RUY_DCHECK_EQ(params.dst_cols, 1); - RUY_DCHECK_EQ(params.last_col, 0); - RUY_DCHECK_EQ(params.start_col, 0); - - // As parameters are defined, we need to scale by sizeof(float). - const std::int64_t lhs_stride = params.lhs_stride >> 2; - - int bias_ptr_block_increment = params.flags & RUY_ASM_FLAG_HAS_BIAS ? 1 : 0; - const int end_row = std::min(params.dst_rows, params.last_row + 16); - - float* adj_dst_col_ptr = params.dst_base_ptr - params.start_row; - const float* adj_lhs_col_ptr = - params.lhs_base_ptr - params.start_row * lhs_stride; - const float* bias_col_ptr = params.bias; - - const __m512 clamp_max_v = _mm512_set1_ps(params.clamp_max); - const __m512 clamp_min_v = _mm512_set1_ps(params.clamp_min); - - __m512 accum_data_v; - - const float* rhs_col_ptr = params.rhs_base_ptr; - float* dst_col_ptr = adj_dst_col_ptr; - - int row = params.start_row; - for (; row <= end_row - 16; row += 16) { - const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride; - float* dst_ptr = dst_col_ptr + row; - const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment; - - // Initialize with bias. - accum_data_v = _mm512_loadu_ps(bias_ptr); - - const float* lhs_ptr = lhs_col_ptr; - const float* rhs_ptr = rhs_col_ptr; - for (int d = 0; d < params.depth; ++d) { - const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr); - const float rhs_data = *rhs_ptr; - - const __m512 dup_rhs_element_j = _mm512_set1_ps(rhs_data); - accum_data_v = _mm512_fmadd_ps(lhs_data, dup_rhs_element_j, accum_data_v); - lhs_ptr += 16; - rhs_ptr += 16; - } - - accum_data_v = _mm512_min_ps(accum_data_v, clamp_max_v); - accum_data_v = _mm512_max_ps(accum_data_v, clamp_min_v); - _mm512_storeu_ps(dst_ptr, accum_data_v); - } // End row-block loop. 
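The GEMV row-block loop above reduces, per 16-row block, to the following scalar reference. The function and parameter names are illustrative only, and the packed-operand layout is the one implied by the pointer advances above: 16 LHS rows per depth step, and a 16-float stride per depth step in the packed RHS block even though only its first column is read.

#include <algorithm>

// Hypothetical scalar equivalent of one 16-row block of the float GEMV above:
// dst[r] = clamp(bias[r] + sum_d lhs[r, d] * rhs[d]).
void GemvBlockReference(const float* packed_lhs, const float* packed_rhs,
                        const float* bias, int depth, float clamp_min,
                        float clamp_max, float* dst) {
  for (int r = 0; r < 16; ++r) {
    float acc = bias[r];
    for (int d = 0; d < depth; ++d) {
      // The packed LHS stores the 16 rows of each depth step contiguously
      // (hence lhs_ptr += 16 above); the packed RHS keeps a 16-float stride
      // per depth step (hence rhs_ptr += 16), of which only element 0 is
      // used in this single-column case.
      acc += packed_lhs[d * 16 + r] * packed_rhs[d * 16];
    }
    dst[r] = std::max(clamp_min, std::min(clamp_max, acc));
  }
}

The vector loop above computes the same result with one _mm512_loadu_ps per depth step and a broadcast FMA (_mm512_set1_ps plus _mm512_fmadd_ps) in place of the inner scalar products.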
- - if (row < end_row) { - const int residual_rows = end_row - row; - RUY_CHECK_GE(residual_rows, 1); - RUY_CHECK_LT(residual_rows, 16); - - const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride; - float* dst_ptr = dst_col_ptr + row; - const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment; - - // Initialize with bias. - const __mmask16 row_mask = - (static_cast(1) << residual_rows) - 1; - accum_data_v = _mm512_maskz_loadu_ps(row_mask, bias_ptr); - - const float* lhs_ptr = lhs_col_ptr; - const float* rhs_ptr = rhs_col_ptr; - for (int d = 0; d < params.depth; ++d) { - const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr); - const float rhs_data = *rhs_ptr; - - const __m512 dup_rhs_element_j = _mm512_set1_ps(rhs_data); - accum_data_v = _mm512_fmadd_ps(lhs_data, dup_rhs_element_j, accum_data_v); - lhs_ptr += 16; - rhs_ptr += 16; - } - - accum_data_v = _mm512_min_ps(accum_data_v, clamp_max_v); - accum_data_v = _mm512_max_ps(accum_data_v, clamp_min_v); - _mm512_mask_storeu_ps(dst_ptr, row_mask, accum_data_v); - } // End handling of residual rows. -} - -#endif // RUY_PLATFORM(AVX512) && RUY_OPT_ENABLED(RUY_OPT_ASM) - -} // namespace ruy diff --git a/tensorflow/lite/experimental/ruy/ruy/kernel_avxvnni.cc b/tensorflow/lite/experimental/ruy/ruy/kernel_avxvnni.cc deleted file mode 100644 index c868c00957b..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/kernel_avxvnni.cc +++ /dev/null @@ -1,435 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include - -#include "tensorflow/lite/experimental/ruy/ruy/check_macros.h" -#include "tensorflow/lite/experimental/ruy/ruy/kernel.h" -#include "tensorflow/lite/experimental/ruy/ruy/opt_set.h" -#include "tensorflow/lite/experimental/ruy/ruy/platform.h" -#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" - -#if RUY_PLATFORM(AVX_VNNI) && RUY_OPT_ENABLED(RUY_OPT_ASM) -#include // IWYU pragma: keep -#endif - -namespace ruy { - -#if !(RUY_PLATFORM(AVX_VNNI) && RUY_OPT_ENABLED(RUY_OPT_ASM)) - -void Kernel8bitAvxVnni(const KernelParams8bit<16, 16>& params) { - // CPU-ID-based checks should disable the path that would reach this point. - RUY_DCHECK(false); -} - -void KernelFloatAvxVnni(const KernelParamsFloat<16, 16>& params) { - // CPU-ID-based checks should disable the path that would reach this point. - RUY_DCHECK(false); -} - -#else // RUY_PLATFORM(AVX_VNNI) && RUY_OPT_ENABLED(RUY_OPT_ASM) - -static constexpr int kAvxFloatBlockSize = 16; -static constexpr int kAvx8bitBlockSize = 16; -static constexpr int kAvx8bitInnerSize = 4; - -// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. -// Optimization is not finished. In particular the dimensions of the kernel -// blocks can be changed as desired. -// -// When removing this comment, update profiling label below. 
-void Kernel8bitAvxVnni(const KernelParams8bit<16, 16>& params) { - profiler::ScopeLabel label("Kernel kAvxVnni 8-bit (UNFINISHED)"); - - std::int32_t accum_data[kAvx8bitBlockSize][kAvx8bitBlockSize]; - - int bias_ptr_block_increment = - params.flags & RUY_ASM_FLAG_HAS_BIAS ? kAvx8bitBlockSize : 0; - - const std::int8_t* rhs_col_ptr = params.rhs_base_ptr; - void* dst_col_ptr = params.dst_base_ptr; - const std::int32_t* bias_col_ptr = params.bias; - if (params.flags & RUY_ASM_FLAG_HAS_BIAS) { - bias_col_ptr += params.start_row; - } - - for (int col = params.start_col; col <= params.last_col; - col += kAvx8bitBlockSize) { - const std::int8_t* lhs_col_ptr = params.lhs_base_ptr; - void* dst_ptr = dst_col_ptr; - const std::int32_t* bias_ptr = bias_col_ptr; - - for (int row = params.start_row; row <= params.last_row; - row += kAvx8bitBlockSize) { - const int residual_rows = - std::min(params.dst_rows - row, kAvx8bitBlockSize); - const int residual_cols = - std::min(params.dst_cols - col, kAvx8bitBlockSize); - - // Initialize with bias. - std::int32_t initial_accum_data[kAvx8bitBlockSize]; - for (int i = 0; i < kAvx8bitBlockSize; ++i) { - initial_accum_data[i] = 0; - } - for (int i = 0; i < residual_rows; ++i) { - initial_accum_data[i] = bias_ptr[i]; - } - - for (int j = 0; j < kAvx8bitBlockSize; ++j) { - for (int i = 0; i < kAvx8bitBlockSize; ++i) { - accum_data[j][i] = initial_accum_data[i]; - } - } - bias_ptr += bias_ptr_block_increment; - - std::int8_t lhs_data[kAvx8bitBlockSize][kAvx8bitInnerSize]; - std::int8_t rhs_data[kAvx8bitBlockSize][kAvx8bitInnerSize]; - const std::int8_t* lhs_ptr = lhs_col_ptr; - const std::int8_t* rhs_ptr = rhs_col_ptr; - for (int d = 0; d < params.depth; d += kAvx8bitInnerSize) { - for (int i = 0; i < kAvx8bitBlockSize; ++i) { - for (int x = 0; x < kAvx8bitInnerSize; ++x) { - lhs_data[i][x] = lhs_ptr[i * kAvx8bitInnerSize + x]; - rhs_data[i][x] = rhs_ptr[i * kAvx8bitInnerSize + x]; - } - } - - for (int j = 0; j < kAvx8bitBlockSize; ++j) { - for (int i = 0; i < kAvx8bitBlockSize; ++i) { - for (int x = 0; x < kAvx8bitInnerSize; ++x) { - accum_data[j][i] += lhs_data[i][x] * rhs_data[j][x]; - } - } - } - - lhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize; - rhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize; - } - - if ((params.flags & RUY_ASM_FLAG_HAS_LHS_SUMS) && params.rhs_zero_point) { - for (int j = 0; j < kAvx8bitBlockSize; ++j) { - for (int i = 0; i < kAvx8bitBlockSize; ++i) { - accum_data[j][i] -= - params.rhs_zero_point * params.lhs_sums[row + i]; - } - } - } - if ((params.flags & RUY_ASM_FLAG_HAS_RHS_SUMS) && params.lhs_zero_point) { - for (int j = 0; j < kAvx8bitBlockSize; ++j) { - for (int i = 0; i < kAvx8bitBlockSize; ++i) { - accum_data[j][i] -= - params.lhs_zero_point * params.rhs_sums[col + j]; - } - } - } - if (params.lhs_zero_point && params.rhs_zero_point) { - for (int j = 0; j < kAvx8bitBlockSize; ++j) { - for (int i = 0; i < kAvx8bitBlockSize; ++i) { - accum_data[j][i] += params.prod_zp_depth; - } - } - } - - if (params.dst_type_id != DstTypeId::kValue) { - std::int32_t m_vector[kAvx8bitBlockSize]; - std::int32_t e_vector[kAvx8bitBlockSize]; - // Does not make use of RUY_ASM_FLAG_NEEDS_LEFT_SHIFT. 
- if (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) { - int i = 0; - for (; i < residual_rows; ++i) { - m_vector[i] = params.multiplier_fixedpoint[row + i]; - e_vector[i] = params.multiplier_exponent[row + i]; - } - for (; i < kAvx8bitBlockSize; ++i) { - m_vector[i] = m_vector[0]; - e_vector[i] = e_vector[0]; - } - } else { - // These arrays have size LhsCols, and are pre-filled. - for (int i = 0; i < kAvx8bitBlockSize; ++i) { - m_vector[i] = params.multiplier_fixedpoint[i]; - e_vector[i] = params.multiplier_exponent[i]; - } - } - - for (int j = 0; j < kAvx8bitBlockSize; ++j) { - for (int i = 0; i < kAvx8bitBlockSize; ++i) { - accum_data[j][i] = MultiplyByQuantizedMultiplier( - accum_data[j][i], m_vector[i], e_vector[i]); - } - } - - if (params.dst_zero_point) { - for (int j = 0; j < kAvx8bitBlockSize; ++j) { - for (int i = 0; i < kAvx8bitBlockSize; ++i) { - accum_data[j][i] += params.dst_zero_point; - } - } - } - - for (int j = 0; j < kAvx8bitBlockSize; ++j) { - for (int i = 0; i < kAvx8bitBlockSize; ++i) { - accum_data[j][i] = - std::min(accum_data[j][i], params.clamp_max); - accum_data[j][i] = - std::max(accum_data[j][i], params.clamp_min); - } - } - } - - const bool store_full_block = (residual_rows == kAvx8bitBlockSize) && - (residual_cols == kAvx8bitBlockSize); - - if (params.dst_type_id == DstTypeId::kValue) { - std::int8_t* tmp_ptr = - store_full_block - ? static_cast(dst_ptr) - : const_cast( - reinterpret_cast(params.dst_tmp_buf)); - const int block_col_offset = - store_full_block ? params.dst_stride / sizeof(std::int8_t) - : kAvx8bitBlockSize; - for (int j = 0; j < kAvx8bitBlockSize; ++j) { - for (int i = 0; i < kAvx8bitBlockSize; ++i) { - tmp_ptr[i] = accum_data[j][i]; - } - tmp_ptr += block_col_offset; - } - - if (!store_full_block) { - const std::int8_t* block_ptr = - reinterpret_cast(params.dst_tmp_buf); - for (int j = 0; j < residual_cols; ++j) { - for (int i = 0; i < residual_rows; ++i) { - static_cast( - dst_ptr)[j * params.dst_stride / sizeof(std::int8_t) + i] = - block_ptr[i]; - } - block_ptr += kAvx8bitBlockSize; - } - } - dst_ptr = static_cast(static_cast(dst_ptr) + - kAvx8bitBlockSize); - } else if (params.dst_type_id == DstTypeId::kValue) { - std::uint8_t* tmp_ptr = store_full_block - ? static_cast(dst_ptr) - : const_cast( - reinterpret_cast( - params.dst_tmp_buf)); - const int block_col_offset = - store_full_block ? 
params.dst_stride : kAvx8bitBlockSize; - for (int j = 0; j < kAvx8bitBlockSize; ++j) { - for (int i = 0; i < kAvx8bitBlockSize; ++i) { - tmp_ptr[i] = accum_data[j][i]; - } - tmp_ptr += block_col_offset; - } - - if (!store_full_block) { - const std::uint8_t* block_ptr = - reinterpret_cast(params.dst_tmp_buf); - for (int j = 0; j < residual_cols; ++j) { - for (int i = 0; i < residual_rows; ++i) { - static_cast( - dst_ptr)[j * params.dst_stride / sizeof(std::uint8_t) + i] = - block_ptr[i]; - } - block_ptr += kAvx8bitBlockSize; - } - } - dst_ptr = static_cast(static_cast(dst_ptr) + - kAvx8bitBlockSize); - } else if (params.dst_type_id == DstTypeId::kValue) { - if (store_full_block) { - std::int16_t* tmp_ptr = static_cast(dst_ptr); - const int block_col_offset = params.dst_stride / sizeof(std::int16_t); - for (int j = 0; j < kAvx8bitBlockSize; ++j) { - for (int i = 0; i < kAvx8bitBlockSize; ++i) { - tmp_ptr[i] = accum_data[j][i]; - } - tmp_ptr += block_col_offset; - } - } else { - std::int16_t* tmp_ptr = const_cast( - reinterpret_cast(params.dst_tmp_buf)); - const int block_col_offset = kAvx8bitBlockSize; - for (int j = 0; j < kAvx8bitBlockSize; ++j) { - for (int i = 0; i < kAvx8bitBlockSize; ++i) { - tmp_ptr[i] = accum_data[j][i]; - } - tmp_ptr += block_col_offset; - } - const std::int16_t* block_ptr = - reinterpret_cast(params.dst_tmp_buf); - std::int16_t* dst_block_ptr = static_cast(dst_ptr); - for (int j = 0; j < residual_cols; ++j) { - for (int i = 0; i < residual_rows; ++i) { - dst_block_ptr[i] = block_ptr[i]; - } - dst_block_ptr += params.dst_stride / sizeof(std::int16_t); - block_ptr += kAvx8bitBlockSize; - } - } - dst_ptr = static_cast(static_cast(dst_ptr) + - kAvx8bitBlockSize); - } else if (params.dst_type_id == DstTypeId::kValue) { - if (store_full_block) { - std::int32_t* tmp_ptr = static_cast(dst_ptr); - const int block_col_offset = params.dst_stride / sizeof(std::int32_t); - for (int j = 0; j < kAvx8bitBlockSize; ++j) { - for (int i = 0; i < kAvx8bitBlockSize; ++i) { - tmp_ptr[i] = accum_data[j][i]; - } - tmp_ptr += block_col_offset; - } - } else { - std::int32_t* dst_block_ptr = static_cast(dst_ptr); - for (int j = 0; j < residual_cols; ++j) { - for (int i = 0; i < residual_rows; ++i) { - dst_block_ptr[i] = accum_data[j][i]; - } - dst_block_ptr += params.dst_stride / sizeof(std::int32_t); - } - } - dst_ptr = static_cast(static_cast(dst_ptr) + - kAvx8bitBlockSize); - } else { - RUY_DCHECK(false); - } - - lhs_col_ptr += kAvx8bitBlockSize * params.lhs_stride; - } // End row-block loop. - - dst_col_ptr = static_cast(static_cast(dst_col_ptr) + - kAvx8bitBlockSize * params.dst_stride); - rhs_col_ptr += kAvx8bitBlockSize * params.rhs_stride; - } // End col-block loop. -} // NOLINT(readability/fn_size) - -// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. -// Optimization is not finished. In particular the dimensions of the kernel -// blocks can be changed as desired. -// -// When removing this comment, update profiling label below. -void KernelFloatAvxVnni(const KernelParamsFloat<16, 16>& params) { - profiler::ScopeLabel label("Kernel kAvxVnni float (UNFINISHED)"); - - float lhs_data[kAvxFloatBlockSize]; - float rhs_data[kAvxFloatBlockSize]; - float accum_data[kAvxFloatBlockSize][kAvxFloatBlockSize]; - int bias_ptr_block_increment = - params.flags & RUY_ASM_FLAG_HAS_BIAS ? 
kAvxFloatBlockSize : 0; - - const float* rhs_col_ptr = params.rhs_base_ptr; - float* dst_col_ptr = params.dst_base_ptr; - const float* bias_col_ptr = params.bias; - if (params.flags & RUY_ASM_FLAG_HAS_BIAS) { - bias_col_ptr += params.start_row; - } - - for (int col = params.start_col; col <= params.last_col; - col += kAvxFloatBlockSize) { - const float* lhs_col_ptr = params.lhs_base_ptr; - float* dst_ptr = dst_col_ptr; - const float* bias_ptr = bias_col_ptr; - - for (int row = params.start_row; row <= params.last_row; - row += kAvxFloatBlockSize) { - const int residual_rows = - std::min(params.dst_rows - row, kAvxFloatBlockSize); - const int residual_cols = - std::min(params.dst_cols - col, kAvxFloatBlockSize); - - // Initialize with bias. - float initial_accum_data[kAvxFloatBlockSize]; - for (int i = 0; i < kAvxFloatBlockSize; ++i) { - initial_accum_data[i] = 0.0f; - } - for (int i = 0; i < residual_rows; ++i) { - initial_accum_data[i] = bias_ptr[i]; - } - for (int j = 0; j < kAvxFloatBlockSize; ++j) { - for (int i = 0; i < kAvxFloatBlockSize; ++i) { - accum_data[j][i] = initial_accum_data[i]; - } - } - bias_ptr += bias_ptr_block_increment; - - const float* lhs_ptr = lhs_col_ptr; - const float* rhs_ptr = rhs_col_ptr; - for (int d = 0; d < params.depth; ++d) { - for (int i = 0; i < kAvxFloatBlockSize; ++i) { - lhs_data[i] = lhs_ptr[i]; - rhs_data[i] = rhs_ptr[i]; - } - - for (int j = 0; j < kAvxFloatBlockSize; ++j) { - for (int i = 0; i < kAvxFloatBlockSize; ++i) { - accum_data[j][i] += lhs_data[i] * rhs_data[j]; - } - } - - lhs_ptr += kAvxFloatBlockSize; - rhs_ptr += kAvxFloatBlockSize; - } - - for (int j = 0; j < kAvxFloatBlockSize; ++j) { - for (int i = 0; i < kAvxFloatBlockSize; ++i) { - accum_data[j][i] = - std::min(accum_data[j][i], params.clamp_max); - accum_data[j][i] = - std::max(accum_data[j][i], params.clamp_min); - } - } - - const bool store_full_block = (residual_rows == kAvxFloatBlockSize) && - (residual_cols == kAvxFloatBlockSize); - - { - float* block_ptr = - store_full_block ? dst_ptr : const_cast(params.dst_tmp_buf); - const int block_col_offset = store_full_block - ? params.dst_stride / sizeof(float) - : kAvxFloatBlockSize; - for (int j = 0; j < kAvxFloatBlockSize; ++j) { - for (int i = 0; i < kAvxFloatBlockSize; ++i) { - block_ptr[i] = accum_data[j][i]; - } - block_ptr += block_col_offset; - } - } - if (!store_full_block) { - const float* block_ptr = params.dst_tmp_buf; - for (int j = 0; j < residual_cols; ++j) { - for (int i = 0; i < residual_rows; ++i) { - dst_ptr[j * params.dst_stride / sizeof(float) + i] = block_ptr[i]; - } - block_ptr += kAvxFloatBlockSize; - } - } - - lhs_col_ptr += kAvxFloatBlockSize * params.lhs_stride / sizeof(float); - dst_ptr += kAvxFloatBlockSize; - } // End row-block loop. - - dst_col_ptr += kAvxFloatBlockSize * params.dst_stride / sizeof(float); - rhs_col_ptr += kAvxFloatBlockSize * params.rhs_stride / sizeof(float); - } // End col-block loop. -} - -#endif // RUY_PLATFORM(AVX_VNNI) && RUY_OPT_ENABLED(RUY_OPT_ASM) - -} // namespace ruy diff --git a/tensorflow/lite/experimental/ruy/ruy/kernel_common.h b/tensorflow/lite/experimental/ruy/ruy/kernel_common.h deleted file mode 100644 index c1721b81869..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/kernel_common.h +++ /dev/null @@ -1,481 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_KERNEL_COMMON_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_KERNEL_COMMON_H_ - -#include -#include -#include - -#include "tensorflow/lite/experimental/ruy/ruy/check_macros.h" -#include "tensorflow/lite/experimental/ruy/ruy/common.h" -#include "tensorflow/lite/experimental/ruy/ruy/internal_matrix.h" -#include "tensorflow/lite/experimental/ruy/ruy/matrix.h" -#include "tensorflow/lite/experimental/ruy/ruy/opt_set.h" -#include "tensorflow/lite/experimental/ruy/ruy/path.h" -#include "tensorflow/lite/experimental/ruy/ruy/platform.h" -#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" -#include "tensorflow/lite/experimental/ruy/ruy/side_pair.h" -#include "tensorflow/lite/experimental/ruy/ruy/size_util.h" -#include "tensorflow/lite/experimental/ruy/ruy/spec.h" -#include "tensorflow/lite/experimental/ruy/ruy/tune.h" - -namespace ruy { - -template -struct Kernel {}; - -template -void RunKernelTyped(Tuning tuning, const PackedMatrix& lhs, - const PackedMatrix& rhs, const Spec& spec, - int start_row, int start_col, int end_row, int end_col, - Matrix* dst) { - using Kernel = Kernel; - Kernel kernel(tuning); -#if !defined(NDEBUG) || !RUY_OPT_ENABLED(RUY_OPT_FAT_KERNEL) - using LhsLayout = typename Kernel::LhsLayout; - using RhsLayout = typename Kernel::RhsLayout; -#endif - // end_row and end_col may be larger than dst dimensions. - // that is because kernels write directly to the destination matrix, whose - // dimensions may not be a multiple of the kernel dimensions, and we try to - // keep this annoyance localized as an implementation detail in kernels, - // by allowing to pass rounded-up values down as far as possible. - // These assertions encode the contract. - RUY_DCHECK_LE(0, start_row); - RUY_DCHECK_LE(start_row, end_row); - RUY_DCHECK_LT(end_row, dst->layout.rows + LhsLayout::kCols); - RUY_DCHECK_EQ((end_row - start_row) % LhsLayout::kCols, 0); - RUY_DCHECK_LE(0, start_col); - RUY_DCHECK_LE(start_col, end_col); - RUY_DCHECK_LT(end_col, dst->layout.cols + RhsLayout::kCols); - RUY_DCHECK_EQ((end_col - start_col) % RhsLayout::kCols, 0); -#if RUY_OPT_ENABLED(RUY_OPT_FAT_KERNEL) - kernel.Run(lhs, rhs, spec, start_row, start_col, end_row, end_col, dst); -#else - for (int col = start_col; col < end_col; col += RhsLayout::kCols) { - int block_end_col = std::min(col + RhsLayout::kCols, end_col); - for (int row = start_row; row < end_row; row += LhsLayout::kCols) { - int block_end_row = std::min(row + LhsLayout::kCols, end_row); - kernel.Run(lhs, rhs, spec, row, col, block_end_row, block_end_col, dst); - } - } -#endif -} - -// Main entry point for kernels. 
-template -void RunKernel(Tuning tuning, const SidePair& src, void* spec, - const SidePair& start, const SidePair& end, - DMatrix* dst) { - Matrix mdst = ToMatrix(*dst); - RunKernelTyped( - tuning, ToPackedMatrix(src[Side::kLhs]), - ToPackedMatrix(src[Side::kRhs]), - *static_cast(spec), start[Side::kLhs], start[Side::kRhs], - end[Side::kLhs], end[Side::kRhs], &mdst); -} - -// Copied from gemmlowp/fixedpoint. -inline std::int32_t SaturatingRoundingDoublingHighMul(std::int32_t a, - std::int32_t b) { - bool overflow = a == b && a == std::numeric_limits::min(); - std::int64_t a_64(a); - std::int64_t b_64(b); - std::int64_t ab_64 = a_64 * b_64; - std::int32_t nudge = ab_64 >= 0 ? (1 << 30) : (1 - (1 << 30)); - std::int32_t ab_x2_high32 = - static_cast((ab_64 + nudge) / (1ll << 31)); - return overflow ? std::numeric_limits::max() : ab_x2_high32; -} - -inline std::int32_t RoundingDivideByPOT(std::int32_t numerator, int exponent) { - std::int32_t sign = numerator >= 0 ? 1 : -1; - std::int32_t abs_numerator = std::abs(numerator); - std::int32_t mask = (1LL << exponent) - 1; - std::int32_t remainder = abs_numerator & mask; - std::int32_t threshold = mask >> 1; - std::int32_t abs_result = - (abs_numerator >> exponent) + (remainder > threshold ? 1 : 0); - return sign * abs_result; -} - -// Copied from TF Lite code. -inline std::int32_t MultiplyByQuantizedMultiplier( - std::int32_t x, std::int32_t quantized_multiplier, int shift) { - int left_shift = shift > 0 ? shift : 0; - int right_shift = shift > 0 ? 0 : -shift; - return RoundingDivideByPOT(SaturatingRoundingDoublingHighMul( - x * (1 << left_shift), quantized_multiplier), - right_shift); -} - -// Helper to apply a fixed-point multiplier. Only 'applicable' if AccumScalar -// is int32 (i.e. in all cases except floating-point) and if the destination is -// not int32 (i.e. unless the user wants to get raw accumulators). -template ::value && - !std::is_same::value> -struct ApplyMultiplierImpl {}; - -// Specialization in non-applicable case: do nothing, just check that values -// are default. -template -struct ApplyMultiplierImpl { - using AccumScalar = typename Spec::AccumScalar; - using DstScalar = typename Spec::DstScalar; - static void Run(const Spec& spec, int row, AccumScalar* accum) { - RUY_DCHECK_EQ(spec.multiplier_fixedpoint, 0); - RUY_DCHECK_EQ(spec.multiplier_exponent, 0); - } -}; - -template -struct ApplyMultiplierImpl { - using AccumScalar = typename Spec::AccumScalar; - using DstScalar = typename Spec::DstScalar; - static void Run(const Spec& spec, int row, AccumScalar* accum) { - AccumScalar m = spec.multiplier_fixedpoint_perchannel - ? spec.multiplier_fixedpoint_perchannel[row] - : spec.multiplier_fixedpoint; - int e = spec.multiplier_exponent_perchannel - ? spec.multiplier_exponent_perchannel[row] - : spec.multiplier_exponent; - *accum = MultiplyByQuantizedMultiplier(*accum, m, e); - } -}; - -template -void ApplyMultiplier(const Spec& spec, int row, - typename Spec::AccumScalar* accum) { - ApplyMultiplierImpl::Run(spec, row, accum); -} - -template -struct Kernel { - using AccumScalar = typename Spec::AccumScalar; - using LhsLayout = typename Spec::StandardCppKernelLhsLayout; - using RhsLayout = typename Spec::StandardCppKernelRhsLayout; - explicit Kernel(Tuning) {} - void Run(const PackedMatrix& lhs, - const PackedMatrix& rhs, const Spec& spec, int start_row, - int start_col, int end_row, int end_col, - Matrix* dst) const { - // See the comment in RunKernelTyped. end_row may be larger than - // dst->layout.rows. 
It's the responsibility of the kernel to avoid - // overrunning dst boundaries, which we do here by computing - // clamped_end_row. - int clamped_end_row = std::min(end_row, dst->layout.rows); - int clamped_end_col = std::min(end_col, dst->layout.cols); - RUY_DCHECK_LE(0, start_row); - RUY_DCHECK_LE(start_row, clamped_end_row); - RUY_DCHECK_LE(clamped_end_row, dst->layout.rows); - RUY_DCHECK_LE(clamped_end_row, end_row); - RUY_DCHECK_LE(end_row - clamped_end_row, LhsLayout::kCols); - RUY_DCHECK_LE(0, start_col); - RUY_DCHECK_LE(start_col, clamped_end_col); - RUY_DCHECK_LE(clamped_end_col, dst->layout.cols); - RUY_DCHECK_LE(clamped_end_col, end_col); - RUY_DCHECK_LE(end_col - clamped_end_col, RhsLayout::kCols); - profiler::ScopeLabel label("Kernel (Standard Cpp)"); - const int depth = lhs.layout.rows; - for (int i = start_row; i < clamped_end_row; i++) { - for (int j = start_col; j < clamped_end_col; j++) { - using AccumScalar = typename Spec::AccumScalar; - AccumScalar accum = 0; - for (int k = 0; k < depth; k++) { - AccumScalar lhs_val = Element(lhs, k, i); - AccumScalar rhs_val = Element(rhs, k, j); - accum += lhs_val * rhs_val; - } - if (spec.bias) { - accum += spec.bias[i]; - } - if (lhs.zero_point) { - accum -= lhs.zero_point * rhs.sums[j]; - } - if (rhs.zero_point) { - accum -= rhs.zero_point * lhs.sums[i]; - } - if (lhs.zero_point && rhs.zero_point) { - accum += lhs.zero_point * rhs.zero_point * depth; - } - ApplyMultiplier(spec, i, &accum); - accum += dst->zero_point; - accum = std::min(accum, spec.clamp_max); - accum = std::max(accum, spec.clamp_min); - *ElementPtr(dst, i, j) = static_cast(accum); - } - } - } -}; - -#define RUY_INHERIT_KERNEL(PARENT, CHILD) \ - template \ - struct Kernel \ - : Kernel { \ - explicit Kernel(Tuning tuning) \ - : Kernel(tuning) {} \ - }; - -#if RUY_PLATFORM(NEON) -RUY_INHERIT_KERNEL(Path::kStandardCpp, Path::kNeon) -RUY_INHERIT_KERNEL(Path::kNeon, Path::kNeonDotprod) -#elif RUY_PLATFORM(X86) -RUY_INHERIT_KERNEL(Path::kStandardCpp, Path::kSse42) -RUY_INHERIT_KERNEL(Path::kSse42, Path::kAvx2) -RUY_INHERIT_KERNEL(Path::kAvx2, Path::kAvx512) -RUY_INHERIT_KERNEL(Path::kAvx512, Path::kAvxVnni) -#endif - -// KernelParams are shared across 32-bit and 64-bit NEON code, and x86 code. -// -// In other cases, we still define (empty) versions, so that dummy kernels -// can use the classes in function signatures. 
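The zero-point handling in the standard C++ kernel above (and in the AVX-512 kernels earlier in this file) relies on expanding the product of offset values: the sum over depth of (lhs - lhs_zero_point) * (rhs - rhs_zero_point) equals the raw int8 dot product, minus lhs_zero_point * rhs_sums[col], minus rhs_zero_point * lhs_sums[row], plus lhs_zero_point * rhs_zero_point * depth. The sketch below recomputes the sums inline for clarity; in ruy they are precomputed during packing (lhs.sums / rhs.sums), and the helper name here is hypothetical.

#include <cstdint>

// Hypothetical helper (not ruy API) demonstrating the zero-point correction:
// accumulate raw int8 products plus running sums, then apply the corrections
// exactly as the kernels above do.
std::int32_t CorrectedDot(const std::int8_t* lhs_col, const std::int8_t* rhs_col,
                          int depth, std::int32_t lhs_zero_point,
                          std::int32_t rhs_zero_point) {
  std::int32_t raw = 0;
  std::int32_t lhs_sum = 0;
  std::int32_t rhs_sum = 0;
  for (int k = 0; k < depth; ++k) {
    raw += static_cast<std::int32_t>(lhs_col[k]) * rhs_col[k];
    lhs_sum += lhs_col[k];
    rhs_sum += rhs_col[k];
  }
  // Same result as accumulating (lhs - lhs_zero_point) * (rhs - rhs_zero_point).
  return raw - lhs_zero_point * rhs_sum - rhs_zero_point * lhs_sum +
         lhs_zero_point * rhs_zero_point * depth;
}

This is why the assembly-path kernels above only apply each correction when the corresponding zero point is nonzero and the matching sums are present (RUY_ASM_FLAG_HAS_LHS_SUMS / RUY_ASM_FLAG_HAS_RHS_SUMS).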
-#if ((RUY_PLATFORM(NEON_64) || RUY_PLATFORM(NEON_32)) && \ - RUY_OPT_ENABLED(RUY_OPT_ASM)) || \ - RUY_PLATFORM(X86) - -#define RUY_ASM_FLAG_HAS_BIAS 0x1 -#define RUY_ASM_FLAG_HAS_LHS_SUMS 0x2 -#define RUY_ASM_FLAG_HAS_RHS_SUMS 0x4 -#define RUY_ASM_FLAG_HAS_PERCHANNEL 0x8 -#define RUY_ASM_FLAG_NEEDS_LEFT_SHIFT 0x10 - -#define RUY_ASM_TYPE_ID_UINT8 1 -#define RUY_ASM_TYPE_ID_INT8 2 -#define RUY_ASM_TYPE_ID_INT16 3 -#define RUY_ASM_TYPE_ID_INT32 4 - -template -struct DstTypeId {}; - -template <> -struct DstTypeId { - static constexpr int kValue = RUY_ASM_TYPE_ID_UINT8; -}; - -template <> -struct DstTypeId { - static constexpr int kValue = RUY_ASM_TYPE_ID_INT8; -}; - -template <> -struct DstTypeId { - static constexpr int kValue = RUY_ASM_TYPE_ID_INT16; -}; - -template <> -struct DstTypeId { - static constexpr int kValue = RUY_ASM_TYPE_ID_INT32; -}; - -template -struct KernelParams8bit { - static constexpr int kMaxDstTypeSize = 4; - - const std::int32_t* bias; - const std::int32_t* lhs_sums; - const std::int32_t* rhs_sums; - const std::int8_t* lhs_base_ptr; - const std::int32_t* multiplier_fixedpoint; - const std::int32_t* multiplier_exponent; - const std::int8_t* rhs_base_ptr; - void* dst_base_ptr; - std::int32_t lhs_zero_point; - std::int32_t rhs_zero_point; - std::int32_t dst_zero_point; - std::int32_t prod_zp_depth; - std::int32_t start_row; - std::int32_t start_col; - std::int32_t last_row; - std::int32_t last_col; - std::int32_t dst_rows; - std::int32_t dst_cols; - std::int32_t lhs_stride; - std::int32_t rhs_stride; - std::int32_t dst_stride; - std::int32_t depth; - std::int32_t clamp_min; - std::int32_t clamp_max; - std::uint8_t flags; - std::uint8_t dst_type_id; - const std::int32_t zero_data[LhsCols] = {0}; - std::uint8_t dst_tmp_buf[LhsCols * RhsCols * kMaxDstTypeSize]; - std::int32_t multiplier_fixedpoint_buf[LhsCols]; - std::int32_t multiplier_exponent_buf[LhsCols]; -}; - -template -void MakeKernelParams8bit(const PackedMatrix& lhs, - const PackedMatrix& rhs, - const BasicSpec& spec, - int start_row, int start_col, int end_row, - int end_col, Matrix* dst, - KernelParams8bit* params) { - using Params = KernelParams8bit; - - static_assert(sizeof(DstScalar) <= Params::kMaxDstTypeSize, ""); - - const int depth = lhs.layout.rows; - RUY_DCHECK_EQ(start_row % LhsCols, 0); - RUY_DCHECK_EQ(start_col % RhsCols, 0); - RUY_DCHECK_EQ(end_row % LhsCols, 0); - RUY_DCHECK_EQ(end_col % RhsCols, 0); - - params->lhs_base_ptr = lhs.data + start_row * lhs.layout.stride; - params->rhs_base_ptr = rhs.data + start_col * rhs.layout.stride; - params->flags = 0; - params->bias = params->zero_data; - if (spec.bias) { - params->bias = spec.bias; - params->flags |= RUY_ASM_FLAG_HAS_BIAS; - } - if (lhs.sums) { - params->lhs_sums = lhs.sums; - params->flags |= RUY_ASM_FLAG_HAS_LHS_SUMS; - } - if (rhs.sums) { - params->rhs_sums = rhs.sums; - params->flags |= RUY_ASM_FLAG_HAS_RHS_SUMS; - } - params->start_row = start_row; - params->start_col = start_col; - params->last_row = end_row - LhsCols; - params->last_col = end_col - RhsCols; - params->lhs_stride = lhs.layout.stride; - params->rhs_stride = rhs.layout.stride; - params->dst_stride = sizeof(DstScalar) * dst->layout.stride; - params->lhs_zero_point = lhs.zero_point; - params->rhs_zero_point = rhs.zero_point; - params->dst_zero_point = dst->zero_point; - params->depth = depth; - params->prod_zp_depth = lhs.zero_point * rhs.zero_point * depth; - if (spec.multiplier_fixedpoint_perchannel) { - params->flags |= RUY_ASM_FLAG_NEEDS_LEFT_SHIFT; - params->flags |= 
RUY_ASM_FLAG_HAS_PERCHANNEL; - params->multiplier_fixedpoint = spec.multiplier_fixedpoint_perchannel; - params->multiplier_exponent = spec.multiplier_exponent_perchannel; - } else { - if (spec.multiplier_exponent > 0) { - params->flags |= RUY_ASM_FLAG_NEEDS_LEFT_SHIFT; - } - params->multiplier_fixedpoint = params->multiplier_fixedpoint_buf; - params->multiplier_exponent = params->multiplier_exponent_buf; - for (int i = 0; i < LhsCols; i++) { - params->multiplier_fixedpoint_buf[i] = spec.multiplier_fixedpoint; - params->multiplier_exponent_buf[i] = spec.multiplier_exponent; - } - } - params->clamp_min = spec.clamp_min; - params->clamp_max = spec.clamp_max; - params->dst_rows = dst->layout.rows; - params->dst_cols = dst->layout.cols; - - RUY_DCHECK_LT(params->last_row, params->dst_rows); - RUY_DCHECK_LT(params->last_col, params->dst_cols); - - params->dst_type_id = DstTypeId::kValue; - params->dst_base_ptr = - dst->data.get() + start_col * dst->layout.stride + start_row; -} - -template -struct KernelParamsFloat { - const float* lhs_base_ptr; - const float* rhs_base_ptr; - float* dst_base_ptr; - const float* bias; - std::int32_t start_row; - std::int32_t start_col; - std::int32_t last_row; - std::int32_t last_col; - std::int32_t dst_rows; - std::int32_t dst_cols; - std::int32_t lhs_stride; - std::int32_t rhs_stride; - std::int32_t dst_stride; - std::int32_t depth; - float clamp_min; - float clamp_max; - std::uint8_t flags; - const float zero_data[LhsCols] = {0}; - float dst_tmp_buf[LhsCols * RhsCols]; -}; - -template -inline void MakeKernelParamsFloat(const PackedMatrix& lhs, - const PackedMatrix& rhs, - const BasicSpec& spec, - int start_row, int start_col, int end_row, - int end_col, Matrix* dst, - KernelParamsFloat* params) { - const int depth = lhs.layout.rows; - RUY_DCHECK_EQ(start_row % LhsCols, 0); - RUY_DCHECK_EQ(start_col % RhsCols, 0); - RUY_DCHECK_EQ(end_row % LhsCols, 0); - RUY_DCHECK_EQ(end_col % RhsCols, 0); - - params->lhs_base_ptr = lhs.data + start_row * lhs.layout.stride; - params->rhs_base_ptr = rhs.data + start_col * rhs.layout.stride; - params->dst_base_ptr = - dst->data.get() + start_col * dst->layout.stride + start_row; - - std::uint8_t flags = 0; - params->bias = params->zero_data; - if (spec.bias) { - params->bias = spec.bias; - flags |= RUY_ASM_FLAG_HAS_BIAS; - } - params->flags = flags; - params->start_row = start_row; - params->start_col = start_col; - params->last_row = end_row - LhsCols; - params->last_col = end_col - RhsCols; - params->lhs_stride = sizeof(float) * lhs.layout.stride; - params->rhs_stride = sizeof(float) * rhs.layout.stride; - params->dst_stride = sizeof(float) * dst->layout.stride; - params->depth = depth; - params->clamp_min = spec.clamp_min; - params->clamp_max = spec.clamp_max; - params->dst_rows = dst->layout.rows; - params->dst_cols = dst->layout.cols; - - RUY_DCHECK_LT(params->last_row, params->dst_rows); - RUY_DCHECK_LT(params->last_col, params->dst_cols); -} - -#else // ((RUY_PLATFORM(NEON_64) || RUY_PLATFORM(NEON_32)) && - // RUY_OPT_ENABLED(RUY_OPT_ASM)) || RUY_PLATFORM(X86) - -template -struct KernelParams8bit {}; - -template -struct KernelParamsFloat {}; - -#endif // ((RUY_PLATFORM(NEON_64) || RUY_PLATFORM(NEON_32)) && - // RUY_OPT_ENABLED(RUY_OPT_ASM)) || RUY_PLATFORM(X86) - -} // namespace ruy - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_KERNEL_COMMON_H_ diff --git a/tensorflow/lite/experimental/ruy/ruy/kernel_sse42.cc b/tensorflow/lite/experimental/ruy/ruy/kernel_sse42.cc deleted file mode 100644 index 46a6d045e6a..00000000000 
--- a/tensorflow/lite/experimental/ruy/ruy/kernel_sse42.cc +++ /dev/null @@ -1,428 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include - -#include "tensorflow/lite/experimental/ruy/ruy/check_macros.h" -#include "tensorflow/lite/experimental/ruy/ruy/kernel.h" -#include "tensorflow/lite/experimental/ruy/ruy/opt_set.h" -#include "tensorflow/lite/experimental/ruy/ruy/platform.h" -#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" - -#if RUY_PLATFORM(SSE42) && RUY_OPT_ENABLED(RUY_OPT_ASM) -#include // IWYU pragma: keep -#endif - -namespace ruy { - -#if !(RUY_PLATFORM(SSE42) && RUY_OPT_ENABLED(RUY_OPT_ASM)) - -void Kernel8bitSse42(const KernelParams8bit<8, 8>& params) { - // CPU-ID-based checks should disable the path that would reach this point. - RUY_DCHECK(false); -} - -void KernelFloatSse42(const KernelParamsFloat<8, 8>& params) { - // CPU-ID-based checks should disable the path that would reach this point. - RUY_DCHECK(false); -} - -#else // RUY_PLATFORM(SSE42) && RUY_OPT_ENABLED(RUY_OPT_ASM) - -static constexpr int kAvxFloatBlockSize = 8; -static constexpr int kAvx8bitBlockSize = 8; -static constexpr int kAvx8bitInnerSize = 4; - -// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. -// Optimization is not finished. In particular the dimensions of the kernel -// blocks can be changed as desired. -// -// When removing this comment, update profiling label below. -void Kernel8bitSse42(const KernelParams8bit<8, 8>& params) { - profiler::ScopeLabel label("Kernel kSse42 8-bit (UNFINISHED)"); - std::int32_t accum_data[kAvx8bitBlockSize][kAvx8bitBlockSize]; - - int bias_ptr_block_increment = - params.flags & RUY_ASM_FLAG_HAS_BIAS ? kAvx8bitBlockSize : 0; - - const std::int8_t* rhs_col_ptr = params.rhs_base_ptr; - void* dst_col_ptr = params.dst_base_ptr; - const std::int32_t* bias_col_ptr = params.bias; - if (params.flags & RUY_ASM_FLAG_HAS_BIAS) { - bias_col_ptr += params.start_row; - } - - for (int col = params.start_col; col <= params.last_col; - col += kAvx8bitBlockSize) { - const std::int8_t* lhs_col_ptr = params.lhs_base_ptr; - void* dst_ptr = dst_col_ptr; - const std::int32_t* bias_ptr = bias_col_ptr; - - for (int row = params.start_row; row <= params.last_row; - row += kAvx8bitBlockSize) { - const int residual_rows = - std::min(params.dst_rows - row, kAvx8bitBlockSize); - const int residual_cols = - std::min(params.dst_cols - col, kAvx8bitBlockSize); - - // Initialize with bias. 
- std::int32_t initial_accum_data[kAvx8bitBlockSize]; - for (int i = 0; i < kAvx8bitBlockSize; ++i) { - initial_accum_data[i] = 0; - } - for (int i = 0; i < residual_rows; ++i) { - initial_accum_data[i] = bias_ptr[i]; - } - for (int j = 0; j < kAvx8bitBlockSize; ++j) { - for (int i = 0; i < kAvx8bitBlockSize; ++i) { - accum_data[j][i] = initial_accum_data[i]; - } - } - bias_ptr += bias_ptr_block_increment; - - std::int8_t lhs_data[kAvx8bitBlockSize][kAvx8bitInnerSize]; - std::int8_t rhs_data[kAvx8bitBlockSize][kAvx8bitInnerSize]; - const std::int8_t* lhs_ptr = lhs_col_ptr; - const std::int8_t* rhs_ptr = rhs_col_ptr; - for (int d = 0; d < params.depth; d += kAvx8bitInnerSize) { - for (int i = 0; i < kAvx8bitBlockSize; ++i) { - for (int x = 0; x < kAvx8bitInnerSize; ++x) { - lhs_data[i][x] = lhs_ptr[i * kAvx8bitInnerSize + x]; - rhs_data[i][x] = rhs_ptr[i * kAvx8bitInnerSize + x]; - } - } - for (int j = 0; j < kAvx8bitBlockSize; ++j) { - for (int i = 0; i < kAvx8bitBlockSize; ++i) { - for (int x = 0; x < kAvx8bitInnerSize; ++x) { - accum_data[j][i] += lhs_data[i][x] * rhs_data[j][x]; - } - } - } - lhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize; - rhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize; - } - - if ((params.flags & RUY_ASM_FLAG_HAS_LHS_SUMS) && params.rhs_zero_point) { - for (int j = 0; j < kAvx8bitBlockSize; ++j) { - for (int i = 0; i < kAvx8bitBlockSize; ++i) { - accum_data[j][i] -= - params.rhs_zero_point * params.lhs_sums[row + i]; - } - } - } - if ((params.flags & RUY_ASM_FLAG_HAS_RHS_SUMS) && params.lhs_zero_point) { - for (int j = 0; j < kAvx8bitBlockSize; ++j) { - for (int i = 0; i < kAvx8bitBlockSize; ++i) { - accum_data[j][i] -= - params.lhs_zero_point * params.rhs_sums[col + j]; - } - } - } - if (params.lhs_zero_point && params.rhs_zero_point) { - for (int j = 0; j < kAvx8bitBlockSize; ++j) { - for (int i = 0; i < kAvx8bitBlockSize; ++i) { - accum_data[j][i] += params.prod_zp_depth; - } - } - } - - if (params.dst_type_id != DstTypeId::kValue) { - std::int32_t m_vector[kAvx8bitBlockSize]; - std::int32_t e_vector[kAvx8bitBlockSize]; - // Does not make use of RUY_ASM_FLAG_NEEDS_LEFT_SHIFT. - if (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) { - int i = 0; - for (; i < residual_rows; ++i) { - m_vector[i] = params.multiplier_fixedpoint[row + i]; - e_vector[i] = params.multiplier_exponent[row + i]; - } - for (; i < kAvx8bitBlockSize; ++i) { - m_vector[i] = m_vector[0]; - e_vector[i] = e_vector[0]; - } - } else { - // These arrays have size LhsCols, and are pre-filled. - for (int i = 0; i < kAvx8bitBlockSize; ++i) { - m_vector[i] = params.multiplier_fixedpoint[i]; - e_vector[i] = params.multiplier_exponent[i]; - } - } - for (int j = 0; j < kAvx8bitBlockSize; ++j) { - for (int i = 0; i < kAvx8bitBlockSize; ++i) { - accum_data[j][i] = MultiplyByQuantizedMultiplier( - accum_data[j][i], m_vector[i], e_vector[i]); - } - } - - if (params.dst_zero_point) { - for (int j = 0; j < kAvx8bitBlockSize; ++j) { - for (int i = 0; i < kAvx8bitBlockSize; ++i) { - accum_data[j][i] += params.dst_zero_point; - } - } - } - - for (int j = 0; j < kAvx8bitBlockSize; ++j) { - for (int i = 0; i < kAvx8bitBlockSize; ++i) { - accum_data[j][i] = - std::min(accum_data[j][i], params.clamp_max); - accum_data[j][i] = - std::max(accum_data[j][i], params.clamp_min); - } - } - } - - const bool store_full_block = (residual_rows == kAvx8bitBlockSize) && - (residual_cols == kAvx8bitBlockSize); - - if (params.dst_type_id == DstTypeId::kValue) { - std::int8_t* tmp_ptr = - store_full_block - ? 
static_cast(dst_ptr) - : const_cast( - reinterpret_cast(params.dst_tmp_buf)); - const int block_col_offset = - store_full_block ? params.dst_stride / sizeof(std::int8_t) - : kAvx8bitBlockSize; - for (int j = 0; j < kAvx8bitBlockSize; ++j) { - for (int i = 0; i < kAvx8bitBlockSize; ++i) { - tmp_ptr[i] = accum_data[j][i]; - } - tmp_ptr += block_col_offset; - } - - if (!store_full_block) { - const std::int8_t* block_ptr = - reinterpret_cast(params.dst_tmp_buf); - for (int j = 0; j < residual_cols; ++j) { - for (int i = 0; i < residual_rows; ++i) { - static_cast( - dst_ptr)[j * params.dst_stride / sizeof(std::int8_t) + i] = - block_ptr[i]; - } - block_ptr += kAvx8bitBlockSize; - } - } - dst_ptr = static_cast(static_cast(dst_ptr) + - kAvx8bitBlockSize); - } else if (params.dst_type_id == DstTypeId::kValue) { - std::uint8_t* tmp_ptr = store_full_block - ? static_cast(dst_ptr) - : const_cast( - reinterpret_cast( - params.dst_tmp_buf)); - const int block_col_offset = - store_full_block ? params.dst_stride : kAvx8bitBlockSize; - for (int j = 0; j < kAvx8bitBlockSize; ++j) { - for (int i = 0; i < kAvx8bitBlockSize; ++i) { - tmp_ptr[i] = accum_data[j][i]; - } - tmp_ptr += block_col_offset; - } - - if (!store_full_block) { - const std::uint8_t* block_ptr = - reinterpret_cast(params.dst_tmp_buf); - for (int j = 0; j < residual_cols; ++j) { - for (int i = 0; i < residual_rows; ++i) { - static_cast( - dst_ptr)[j * params.dst_stride / sizeof(std::uint8_t) + i] = - block_ptr[i]; - } - block_ptr += kAvx8bitBlockSize; - } - } - dst_ptr = static_cast(static_cast(dst_ptr) + - kAvx8bitBlockSize); - } else if (params.dst_type_id == DstTypeId::kValue) { - if (store_full_block) { - std::int16_t* tmp_ptr = static_cast(dst_ptr); - const int block_col_offset = params.dst_stride / sizeof(std::int16_t); - for (int j = 0; j < kAvx8bitBlockSize; ++j) { - for (int i = 0; i < kAvx8bitBlockSize; ++i) { - tmp_ptr[i] = accum_data[j][i]; - } - tmp_ptr += block_col_offset; - } - } else { - std::int16_t* tmp_ptr = const_cast( - reinterpret_cast(params.dst_tmp_buf)); - const int block_col_offset = kAvx8bitBlockSize; - for (int j = 0; j < kAvx8bitBlockSize; ++j) { - for (int i = 0; i < kAvx8bitBlockSize; ++i) { - tmp_ptr[i] = accum_data[j][i]; - } - tmp_ptr += block_col_offset; - } - const std::int16_t* block_ptr = - reinterpret_cast(params.dst_tmp_buf); - std::int16_t* dst_block_ptr = static_cast(dst_ptr); - for (int j = 0; j < residual_cols; ++j) { - for (int i = 0; i < residual_rows; ++i) { - dst_block_ptr[i] = block_ptr[i]; - } - dst_block_ptr += params.dst_stride / sizeof(std::int16_t); - block_ptr += kAvx8bitBlockSize; - } - } - dst_ptr = static_cast(static_cast(dst_ptr) + - kAvx8bitBlockSize); - } else if (params.dst_type_id == DstTypeId::kValue) { - if (store_full_block) { - std::int32_t* tmp_ptr = static_cast(dst_ptr); - const int block_col_offset = params.dst_stride / sizeof(std::int32_t); - for (int j = 0; j < kAvx8bitBlockSize; ++j) { - for (int i = 0; i < kAvx8bitBlockSize; ++i) { - tmp_ptr[i] = accum_data[j][i]; - } - tmp_ptr += block_col_offset; - } - } else { - std::int32_t* dst_block_ptr = static_cast(dst_ptr); - for (int j = 0; j < residual_cols; ++j) { - for (int i = 0; i < residual_rows; ++i) { - dst_block_ptr[i] = accum_data[j][i]; - } - dst_block_ptr += params.dst_stride / sizeof(std::int32_t); - } - } - dst_ptr = static_cast(static_cast(dst_ptr) + - kAvx8bitBlockSize); - } else { - RUY_DCHECK(false); - } - - lhs_col_ptr += kAvx8bitBlockSize * params.lhs_stride; - } // End row-block loop. 
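// Editorial note (added comment, not in the original file): the block just above writes
// one 8x8 tile of results. When the tile lies fully inside the destination
// (store_full_block), the accumulators are written straight to dst_ptr using the real
// dst_stride; otherwise they are staged in params.dst_tmp_buf and only the
// residual_rows x residual_cols corner is copied out. The int32 destination path skips
// the requantization, zero-point and clamping steps applied to the narrower
// destination types.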
- - dst_col_ptr = static_cast(static_cast(dst_col_ptr) + - kAvx8bitBlockSize * params.dst_stride); - rhs_col_ptr += kAvx8bitBlockSize * params.rhs_stride; - } // End col-block loop. -} // NOLINT(readability/fn_size) - -// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. -// Optimization is not finished. In particular the dimensions of the kernel -// blocks can be changed as desired. -// -// When removing this comment, update profiling label below. -void KernelFloatSse42(const KernelParamsFloat<8, 8>& params) { - profiler::ScopeLabel label("Kernel kSse42 float (UNFINISHED)"); - - float lhs_data[kAvxFloatBlockSize]; - float rhs_data[kAvxFloatBlockSize]; - float accum_data[kAvxFloatBlockSize][kAvxFloatBlockSize]; - int bias_ptr_block_increment = - params.flags & RUY_ASM_FLAG_HAS_BIAS ? kAvxFloatBlockSize : 0; - - const float* rhs_col_ptr = params.rhs_base_ptr; - float* dst_col_ptr = params.dst_base_ptr; - const float* bias_col_ptr = params.bias; - if (params.flags & RUY_ASM_FLAG_HAS_BIAS) { - bias_col_ptr += params.start_row; - } - - for (int col = params.start_col; col <= params.last_col; - col += kAvxFloatBlockSize) { - const float* lhs_col_ptr = params.lhs_base_ptr; - float* dst_ptr = dst_col_ptr; - const float* bias_ptr = bias_col_ptr; - - for (int row = params.start_row; row <= params.last_row; - row += kAvxFloatBlockSize) { - const int residual_rows = - std::min(params.dst_rows - row, kAvxFloatBlockSize); - const int residual_cols = - std::min(params.dst_cols - col, kAvxFloatBlockSize); - - // Initialize with bias. - float initial_accum_data[kAvxFloatBlockSize]; - for (int i = 0; i < kAvxFloatBlockSize; ++i) { - initial_accum_data[i] = 0.0f; - } - for (int i = 0; i < residual_rows; ++i) { - initial_accum_data[i] = bias_ptr[i]; - } - for (int j = 0; j < kAvxFloatBlockSize; ++j) { - for (int i = 0; i < kAvxFloatBlockSize; ++i) { - accum_data[j][i] = initial_accum_data[i]; - } - } - bias_ptr += bias_ptr_block_increment; - - const float* lhs_ptr = lhs_col_ptr; - const float* rhs_ptr = rhs_col_ptr; - for (int d = 0; d < params.depth; ++d) { - for (int i = 0; i < kAvxFloatBlockSize; ++i) { - lhs_data[i] = lhs_ptr[i]; - rhs_data[i] = rhs_ptr[i]; - } - for (int j = 0; j < kAvxFloatBlockSize; ++j) { - for (int i = 0; i < kAvxFloatBlockSize; ++i) { - accum_data[j][i] += lhs_data[i] * rhs_data[j]; - } - } - lhs_ptr += kAvxFloatBlockSize; - rhs_ptr += kAvxFloatBlockSize; - } - - for (int j = 0; j < kAvxFloatBlockSize; ++j) { - for (int i = 0; i < kAvxFloatBlockSize; ++i) { - accum_data[j][i] = - std::min(accum_data[j][i], params.clamp_max); - accum_data[j][i] = - std::max(accum_data[j][i], params.clamp_min); - } - } - - const bool store_full_block = (residual_rows == kAvxFloatBlockSize) && - (residual_cols == kAvxFloatBlockSize); - - { - float* block_ptr = - store_full_block ? dst_ptr : const_cast(params.dst_tmp_buf); - const int block_col_offset = store_full_block - ? 
params.dst_stride / sizeof(float) - : kAvxFloatBlockSize; - for (int j = 0; j < kAvxFloatBlockSize; ++j) { - for (int i = 0; i < kAvxFloatBlockSize; ++i) { - block_ptr[i] = accum_data[j][i]; - } - block_ptr += block_col_offset; - } - } - if (!store_full_block) { - const float* block_ptr = params.dst_tmp_buf; - for (int j = 0; j < residual_cols; ++j) { - for (int i = 0; i < residual_rows; ++i) { - dst_ptr[j * params.dst_stride / sizeof(float) + i] = block_ptr[i]; - } - block_ptr += kAvxFloatBlockSize; - } - } - - lhs_col_ptr += kAvxFloatBlockSize * params.lhs_stride / sizeof(float); - dst_ptr += kAvxFloatBlockSize; - } // End row-block loop. - - dst_col_ptr += kAvxFloatBlockSize * params.dst_stride / sizeof(float); - rhs_col_ptr += kAvxFloatBlockSize * params.rhs_stride / sizeof(float); - } // End col-block loop. -} - -#endif // RUY_PLATFORM(SSE42) && RUY_OPT_ENABLED(RUY_OPT_ASM) - -} // namespace ruy diff --git a/tensorflow/lite/experimental/ruy/ruy/kernel_x86.h b/tensorflow/lite/experimental/ruy/ruy/kernel_x86.h deleted file mode 100644 index f79f70ab88c..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/kernel_x86.h +++ /dev/null @@ -1,222 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_KERNEL_X86_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_KERNEL_X86_H_ - -#include - -#include "tensorflow/lite/experimental/ruy/ruy/common.h" -#include "tensorflow/lite/experimental/ruy/ruy/internal_matrix.h" -#include "tensorflow/lite/experimental/ruy/ruy/kernel_common.h" -#include "tensorflow/lite/experimental/ruy/ruy/matrix.h" -#include "tensorflow/lite/experimental/ruy/ruy/opt_set.h" -#include "tensorflow/lite/experimental/ruy/ruy/path.h" -#include "tensorflow/lite/experimental/ruy/ruy/platform.h" -#include "tensorflow/lite/experimental/ruy/ruy/spec.h" -#include "tensorflow/lite/experimental/ruy/ruy/tune.h" - -namespace ruy { - -#if RUY_PLATFORM(X86) -// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. -// Optimization is not finished. In particular the dimensions of the kernel -// blocks can be changed as desired. 
-// -void Kernel8bitSse42(const KernelParams8bit<8, 8>& params); - -template -struct Kernel> { - Tuning tuning = Tuning::kAuto; - using LhsLayout = FixedKernelLayout; - using RhsLayout = FixedKernelLayout; - explicit Kernel(Tuning tuning_) : tuning(tuning_) {} - void Run(const PackedMatrix& lhs, - const PackedMatrix& rhs, - const BasicSpec& spec, int start_row, - int start_col, int end_row, int end_col, - Matrix* dst) const { - KernelParams8bit params; - MakeKernelParams8bit(lhs, rhs, spec, start_row, start_col, end_row, end_col, - dst, ¶ms); - Kernel8bitSse42(params); - } -}; - -void KernelFloatSse42(const KernelParamsFloat<8, 8>& params); - -template <> -struct Kernel> { - Tuning tuning = Tuning::kAuto; - using LhsLayout = FixedKernelLayout; - using RhsLayout = FixedKernelLayout; - explicit Kernel(Tuning tuning_) : tuning(tuning_) {} - void Run(const PackedMatrix& lhs, const PackedMatrix& rhs, - const BasicSpec& spec, int start_row, int start_col, - int end_row, int end_col, Matrix* dst) const { - KernelParamsFloat params; - MakeKernelParamsFloat(lhs, rhs, spec, start_row, start_col, end_row, - end_col, dst, ¶ms); - KernelFloatSse42(params); - } -}; - -void Kernel8bitAvx512(const KernelParams8bit<16, 16>& params); -void Kernel8bitAvx512SingleCol(const KernelParams8bit<16, 16>& params); - -template -struct Kernel> { - Tuning tuning = Tuning::kAuto; - using LhsLayout = FixedKernelLayout; - using RhsLayout = FixedKernelLayout; - explicit Kernel(Tuning tuning_) : tuning(tuning_) {} - void Run(const PackedMatrix& lhs, - const PackedMatrix& rhs, - const BasicSpec& spec, int start_row, - int start_col, int end_row, int end_col, - Matrix* dst) const { - KernelParams8bit params; - MakeKernelParams8bit(lhs, rhs, spec, start_row, start_col, end_row, end_col, - dst, ¶ms); - if (dst->layout.cols == 1) { - Kernel8bitAvx512SingleCol(params); - } else { - Kernel8bitAvx512(params); - } - } -}; - -void KernelFloatAvx512(const KernelParamsFloat<16, 16>& params); -void KernelFloatAvx512SingleCol(const KernelParamsFloat<16, 16>& param); - -template <> -struct Kernel> { - Tuning tuning = Tuning::kAuto; - using LhsLayout = FixedKernelLayout; - using RhsLayout = FixedKernelLayout; - explicit Kernel(Tuning tuning_) : tuning(tuning_) {} - void Run(const PackedMatrix& lhs, const PackedMatrix& rhs, - const BasicSpec& spec, int start_row, int start_col, - int end_row, int end_col, Matrix* dst) const { - KernelParamsFloat params; - MakeKernelParamsFloat(lhs, rhs, spec, start_row, start_col, end_row, - end_col, dst, ¶ms); - if (dst->layout.cols == 1) { - KernelFloatAvx512SingleCol(params); - } else { - KernelFloatAvx512(params); - } - } -}; - -void Kernel8bitAvx2(const KernelParams8bit<8, 8>& params); -void Kernel8bitAvx2SingleCol(const KernelParams8bit<8, 8>& params); - -template -struct Kernel> { - Tuning tuning = Tuning::kAuto; - using LhsLayout = FixedKernelLayout; - using RhsLayout = FixedKernelLayout; - explicit Kernel(Tuning tuning_) : tuning(tuning_) {} - void Run(const PackedMatrix& lhs, - const PackedMatrix& rhs, - const BasicSpec& spec, int start_row, - int start_col, int end_row, int end_col, - Matrix* dst) const { - KernelParams8bit params; - MakeKernelParams8bit(lhs, rhs, spec, start_row, start_col, end_row, end_col, - dst, ¶ms); - if (dst->layout.cols == 1) { - Kernel8bitAvx2SingleCol(params); - } else { - Kernel8bitAvx2(params); - } - } -}; - -void KernelFloatAvx2(const KernelParamsFloat<8, 8>& params); -void KernelFloatAvx2SingleCol(const KernelParamsFloat<8, 8>& params); - -template <> -struct 
Kernel> { - Tuning tuning = Tuning::kAuto; - using LhsLayout = FixedKernelLayout; - using RhsLayout = FixedKernelLayout; - explicit Kernel(Tuning tuning_) : tuning(tuning_) {} - void Run(const PackedMatrix& lhs, const PackedMatrix& rhs, - const BasicSpec& spec, int start_row, int start_col, - int end_row, int end_col, Matrix* dst) const { - KernelParamsFloat params; - MakeKernelParamsFloat(lhs, rhs, spec, start_row, start_col, end_row, - end_col, dst, ¶ms); - if (dst->layout.cols == 1) { - KernelFloatAvx2SingleCol(params); - } else { - KernelFloatAvx2(params); - } - } -}; - -// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. -// Optimization is not finished. In particular the dimensions of the kernel -// blocks can be changed as desired. -// -void Kernel8bitAvxVnni(const KernelParams8bit<16, 16>& params); - -template -struct Kernel> { - Tuning tuning = Tuning::kAuto; - using LhsLayout = FixedKernelLayout; - using RhsLayout = FixedKernelLayout; - explicit Kernel(Tuning tuning_) : tuning(tuning_) {} - void Run(const PackedMatrix& lhs, - const PackedMatrix& rhs, - const BasicSpec& spec, int start_row, - int start_col, int end_row, int end_col, - Matrix* dst) const { - KernelParams8bit params; - MakeKernelParams8bit(lhs, rhs, spec, start_row, start_col, end_row, end_col, - dst, ¶ms); - Kernel8bitAvxVnni(params); - } -}; - -void KernelFloatAvxVnni(const KernelParamsFloat<16, 16>& params); - -template <> -struct Kernel> { - Tuning tuning = Tuning::kAuto; - using LhsLayout = FixedKernelLayout; - using RhsLayout = FixedKernelLayout; - explicit Kernel(Tuning tuning_) : tuning(tuning_) {} - void Run(const PackedMatrix& lhs, const PackedMatrix& rhs, - const BasicSpec& spec, int start_row, int start_col, - int end_row, int end_col, Matrix* dst) const { - KernelParamsFloat params; - MakeKernelParamsFloat(lhs, rhs, spec, start_row, start_col, end_row, - end_col, dst, ¶ms); - KernelFloatAvxVnni(params); - } -}; - -#endif // RUY_PLATFORM(X86) - -} // namespace ruy - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_KERNEL_X86_H_ diff --git a/tensorflow/lite/experimental/ruy/ruy/matrix.h b/tensorflow/lite/experimental/ruy/ruy/matrix.h deleted file mode 100644 index a76f32136c6..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/matrix.h +++ /dev/null @@ -1,182 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_MATRIX_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_MATRIX_H_ - -#include -#include // IWYU pragma: keep -#include - -#include "tensorflow/lite/experimental/ruy/ruy/check_macros.h" - -namespace ruy { - -// Layout storage order. Here and elsewhere, 'col' is short for 'column'. -// 'column-major' means that each column is contiguous in memory. -enum class Order : std::uint8_t { kColMajor, kRowMajor }; - -// Describes the shape and storage layout of a matrix. 
-struct Layout final { - std::int32_t rows = 0; - std::int32_t cols = 0; - // Stride is the offset between two adjacent matrix elements - // in the non-contiguous direction. - std::int32_t stride = 0; - Order order = Order::kColMajor; -}; - -namespace detail { - -// Thin wrapper around a pointer that tracks its constness dynamically. -// -// This is our take on the C++ problem of enforcing constness of data -// wrapped in a containers class: it's not worth the hassle of trying to -// make it fully work at compile-time. -// Instead, we only enforce constness at runtime, and to make it -// zero-overhead, we only enforce it in debug builds. -template -class ConstCheckingPtr final { - public: - using element_type = T; - - // Convenience methods. Most `set` calls go through these. - ConstCheckingPtr& operator=(T* ptr) { - set(ptr); - return *this; - } - ConstCheckingPtr& operator=(const T* ptr) { - set(ptr); - return *this; - } - ConstCheckingPtr& operator=(std::nullptr_t) { - set(static_cast(nullptr)); - return *this; - } - - // Core accessors. These encapsulate the main logic: - // - for `set`, the constness of the argument determines whether internal - // pointer should be tracked as const/mutable. - // - for `get`, the constness of `this` determines whether the call - // counts as a const or mutable use of the internal pointer. - void set(T* ptr) { - ptr_ = ptr; - set_mutable(true); - } - void set(const T* ptr) { - ptr_ = ptr; - set_mutable(false); - } - T* get() /* NOT const */ { - assert_mutable(); - return const_cast(ptr_); - } - const T* get() const { return ptr_; } - - private: - static_assert(!std::is_const::value, ""); - const T* ptr_ = nullptr; -#ifndef NDEBUG - bool is_mutable_ = true; - void set_mutable(bool val) { is_mutable_ = val; } - void assert_mutable() { RUY_DCHECK(is_mutable_); } -#else - void set_mutable(bool) {} - void assert_mutable() {} -#endif -}; - -} // namespace detail - -// A Matrix is really what Eigen and gemmlowp would have called a 'matrix map': -// it merely wraps existing data as a matrix. It doesn't own any buffer. -// Scalar may be any floating-point or integral type. When integral, it may be -// signed or unsigned. -template -struct Matrix final { - Matrix& operator=(const Matrix& other) { - data = other.data; - cacheable = other.cacheable; - layout = other.layout; - zero_point = other.zero_point; - return *this; - } - - // The underlying buffer wrapped by this matrix. - detail::ConstCheckingPtr data; - // The shape and data layout of this matrix. - Layout layout; - // The zero_point, i.e. which Scalar value is to be interpreted as zero. - // When Scalar is floating-point, this must be 0. - Scalar zero_point = 0; - // Clients of Ruy must set this flag to enable any caching behavior. Doesn't - // impact numerical results, but caching can impact observable metrics like - // latency, memory usage, power, etc. - bool cacheable = false; -}; - -inline void MakeSimpleLayout(int rows, int cols, Order order, Layout* layout) { - layout->rows = rows; - layout->cols = cols; - layout->order = order; - layout->stride = order == Order::kColMajor ? rows : cols; -} - -// Opaque data structure representing a pre-packed matrix, as obtained from -// Ruy's advanced API. 
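// Editorial usage sketch (not part of the original header), showing how a caller wraps
// an existing buffer with Matrix and MakeSimpleLayout defined above; names are
// illustrative only:
//
//   float buffer[6] = {1, 2, 3, 4, 5, 6};
//   ruy::Matrix<float> m;
//   ruy::MakeSimpleLayout(2, 3, ruy::Order::kRowMajor, &m.layout);  // stride = 3
//   m.data = buffer;  // assigning a non-const pointer keeps the data mutable;
//                     // assigning a const pointer would make a later mutable get()
//                     // trip RUY_DCHECK, in debug builds only.
//
// Matrix does not own the buffer, and zero_point must stay 0 for floating-point types.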
-struct PrepackedMatrix { - void* data = nullptr; - std::size_t data_size = 0; - void* sums = nullptr; - std::size_t sums_size = 0; -}; - -template -StreamType& operator<<(StreamType& stream, const Matrix& mat) { - for (int row = 0; row < mat.layout.rows; row++) { - for (int col = 0; col < mat.layout.cols; col++) { - stream << static_cast(Element(mat, row, col)) << " "; - } - stream << "\n"; - } - return stream; -} - -// Compile-time version of KernelLayout, used to declare kernel layouts in a -// way that can be consumed by compile-time logic. -// See how partial specializations of Kernel use it to declare their layouts. -// The only reason why this is currently part of the public API is to -// allow testing various layouts for the Path::kStandardCpp kernel, as a -// testing-only feature. See Spec::StandardCppKernelLhsLayout. -template -struct FixedKernelLayout { - static constexpr Order kOrder = tOrder; - static constexpr int kRows = tRows; - static constexpr int kCols = tCols; -}; - -#if (__cplusplus < 201703L) -// A static constexpr data member is automatically inline and should not require -// redeclaration without an initializer. This is actually deprecated from C++17 -// onwards. Clang with -O0 without this can fail to link. -template -constexpr int FixedKernelLayout::kCols; -template -constexpr int FixedKernelLayout::kRows; -#endif - -} // namespace ruy - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_MATRIX_H_ diff --git a/tensorflow/lite/experimental/ruy/ruy/opt_set.h b/tensorflow/lite/experimental/ruy/ruy/opt_set.h deleted file mode 100644 index fef0107ed01..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/opt_set.h +++ /dev/null @@ -1,51 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_OPT_SET_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_OPT_SET_H_ - -// RUY_OPT_SET is a compile-time API that Ruy provides for enabling/disabling -// certain optimizations. It should be used by defining that macro on the -// compiler command line. -// -// Each bit in RUY_OPT_SET controls a particular optimization done in Ruy. -#define RUY_OPT_INTRINSICS 0x1 -#define RUY_OPT_ASM 0x2 -#define RUY_OPT_TUNING 0x4 -#define RUY_OPT_FAT_KERNEL 0x8 -#define RUY_OPT_NATIVE_ROUNDING 0x10 -#define RUY_OPT_AVOID_ALIASING 0x20 -#define RUY_OPT_MAX_STREAMING 0x40 -#define RUY_OPT_PACK_AHEAD 0x80 -#define RUY_OPT_PREFETCH_LOAD 0x100 -#define RUY_OPT_PREFETCH_STORE 0x200 -#define RUY_OPT_FRACTAL_Z 0x400 -#define RUY_OPT_FRACTAL_U 0x800 -#define RUY_OPT_FRACTAL_HILBERT 0x1000 - -#if !defined(RUY_OPT_SET) -#ifdef RUY_OPTIMIZE_FOR_MATMUL_BENCHMARK -// Load prefetching is detrimental in matrix multiplication benchmarks. -// Store prefetching is not. -#define RUY_OPT_SET (~RUY_OPT_PREFETCH_LOAD) -#else -// Default to all optimizations. 
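// Editorial example (added comment): RUY_OPT_SET can also be set on the compiler
// command line, e.g. -DRUY_OPT_SET=0x2 keeps only RUY_OPT_ASM enabled, while the
// default below keeps every optimization bit set.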
-#define RUY_OPT_SET (~0) -#endif -#endif - -#define RUY_OPT_ENABLED(ruy_opt) ((RUY_OPT_SET & ruy_opt) != 0) - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_OPT_SET_H_ diff --git a/tensorflow/lite/experimental/ruy/ruy/pack.h b/tensorflow/lite/experimental/ruy/ruy/pack.h deleted file mode 100644 index 96040aa1039..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/pack.h +++ /dev/null @@ -1,98 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// # What is "packing"? -// -// Before feeding data to the gemm kernels (the parts of Ruy that do lots -// of multiply-add operations), Ruy first performs a data transformation (which -// we call "packing") on the input matrices. This transformation has two main -// goals: -// - rearrange data into blocks that are a convenient size/layout for the gemm -// kernels to consume. This helps make the memory access pattern of the gemm -// kernel simpler and more contiguous, and puts the data in a layout most -// convenient for specific arithmetic instructions in the gemm kernel. -// - compute row/column sums needed for handling quantization with non-symmetric -// zero points. -// -// # Simplified algorithmic analysis of packing -// -// Packing is a relatively simple transformation which does a small constant -// amount of work on each element of an input matrix, and hence for an NxM -// matrix performs O(N*M) work. If N and M are of the same order, then this is -// O(N^2) work. -// -// A NxKxM matrix multiplication requires N*K*M multiply-accumulate operations. -// Note that if N, K, and M are all the same order, then the number of -// multiply-accumulate operations is O(N^3). -// -// Thus, the O(N^2) cost of packing is small compared to the O(N^3) work, in the -// case of all dimensions being roughly the same order. -// -// # Packing cost can be significant -// -// When matrix * matrix multiplications begin to look more like matrix * vector -// multiplications, packing cost can become significant. We sometimes call these -// cases "gemv-like". -// -// Continuing the algorithmic analysis above, if we consider a case where an -// NxKxM matrix multiplication has either N = O(1) or M = O(1), then the -// situation is different. In this case, the multiply-accumulate work is only -// quadratic, so the quadratic cost of packing can be come significant. -// -// Another way to say this is that the cost of packing an input matrix (either -// the LHS or RHS) is amortized across the non-depth dimension of the opposite -// input matrix. Thus, when the LHS has very few rows or the RHS has very few -// columns, the cost of packing the opposite input matrix can become -// significant. -// -// As a rough rule of thumb, the cost of packing starts to become significant -// when either N or M is below 32 (and other dimensions are hundreds), with very -// significant packing costs at 8 or below. 
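// Editorial worked example (added comment): with N = K = 1024 and M = 8, the kernel
// performs N*K*M = 8,388,608 multiply-accumulates, while packing the LHS alone touches
// N*K = 1,048,576 elements -- already about 1/8 of the arithmetic work, so the packing
// cost is no longer negligible in such gemv-like shapes.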
This varies by data type, Path, and -// tuning, so these numbers are only rough guides. -// -// One practical use case that is affected by this is inference of -// fully connected neural network layers with a low batch size. The weight -// matrix (which is a constant for inference) is the one affected by significant -// packing cost. -// -// Ruy provides an API in ruy_advanced.h for advanced users to pre-pack -// input matrices that are affected by significant packing costs. -// -// # Implementation notes -// -// Ruy's packing routines always operate on a range of columns and can be -// applied to either the LHS or RHS. This is possible because Ruy internally -// implements a TrMul, so the accumulation along depth is done along columns of -// both the LHS and RHS (whereas for a normal Mul the accumulation along depth -// for the LHS is along rows). As another example, we are always computing -// column sums for quantization (and never row sums, since the LHS is -// transposed). - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PACK_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PACK_H_ - -#include "tensorflow/lite/experimental/ruy/ruy/platform.h" - -// IWYU pragma: begin_exports -#if RUY_PLATFORM(NEON) -#include "tensorflow/lite/experimental/ruy/ruy/pack_arm.h" -#elif RUY_PLATFORM(X86) -#include "tensorflow/lite/experimental/ruy/ruy/pack_x86.h" -#else -#include "tensorflow/lite/experimental/ruy/ruy/pack_common.h" -#endif -// IWYU pragma: end_exports - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PACK_H_ diff --git a/tensorflow/lite/experimental/ruy/ruy/pack_arm.cc b/tensorflow/lite/experimental/ruy/ruy/pack_arm.cc deleted file mode 100644 index 52b55a57cc6..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/pack_arm.cc +++ /dev/null @@ -1,1936 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ -#include - -#include "tensorflow/lite/experimental/ruy/ruy/common.h" -#include "tensorflow/lite/experimental/ruy/ruy/opt_set.h" -#include "tensorflow/lite/experimental/ruy/ruy/pack.h" -#include "tensorflow/lite/experimental/ruy/ruy/platform.h" -#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" - -namespace ruy { - -#if RUY_PLATFORM(NEON_64) && RUY_OPT_ENABLED(RUY_OPT_ASM) - -void Pack8bitNeonOutOfOrder(const void* src_ptr0, const void* src_ptr1, - const void* src_ptr2, const void* src_ptr3, - int src_inc0, int src_inc1, int src_inc2, - int src_inc3, int src_rows, int src_zero_point, - std::int8_t* packed_ptr, int start_col, int end_col, - std::int32_t* sums_ptr, int input_xor) { - profiler::ScopeLabel label("Pack (kNeon, optimized for out-of-order cores)"); - asm volatile( - // clang-format off - "dup v26.16b, %w[input_xor]\n" - "mov w1, #0\n" - "dup v28.4s, wzr\n" - "dup v29.4s, wzr\n" - "dup v30.4s, wzr\n" - "dup v31.4s, wzr\n" - - "and w2, %w[rows], #-16\n" - "cmp w1, w2\n" - "beq 3f\n" - - "add w1, w1, #16\n" - "ld1 {v0.16b}, [%[src_ptr0]], %[src_inc0]\n" - "ld1 {v1.16b}, [%[src_ptr1]], %[src_inc1]\n" - "cmp w1, w2\n" - "ld1 {v2.16b}, [%[src_ptr2]], %[src_inc2]\n" - "ld1 {v3.16b}, [%[src_ptr3]], %[src_inc3]\n" - "beq 2f\n" - - "1:\n" - - "add w1, w1, #16\n" - "eor v4.16b, v0.16b, v26.16b\n" - "ld1 {v0.16b}, [%[src_ptr0]], %[src_inc0]\n" - "eor v5.16b, v1.16b, v26.16b\n" - "ld1 {v1.16b}, [%[src_ptr1]], %[src_inc1]\n" - "eor v6.16b, v2.16b, v26.16b\n" - "ld1 {v2.16b}, [%[src_ptr2]], %[src_inc2]\n" - "eor v7.16b, v3.16b, v26.16b\n" - "ld1 {v3.16b}, [%[src_ptr3]], %[src_inc3]\n" - - "saddlp v16.8h, v4.16b\n" - "str q4, [%[packed_ptr], #0]\n" - "saddlp v17.8h, v5.16b\n" - "str q5, [%[packed_ptr], #16]\n" - "saddlp v18.8h, v6.16b\n" - "str q6, [%[packed_ptr], #32]\n" - "saddlp v19.8h, v7.16b\n" - "str q7, [%[packed_ptr], #48]\n" - "sadalp v28.4s, v16.8h\n" - "cmp w1, w2\n" - "sadalp v29.4s, v17.8h\n" - "add %[packed_ptr], %[packed_ptr], #64\n" - "sadalp v30.4s, v18.8h\n" - "sadalp v31.4s, v19.8h\n" - - "bne 1b\n" - - "2:\n" - - "eor v4.16b, v0.16b, v26.16b\n" - "eor v5.16b, v1.16b, v26.16b\n" - "eor v6.16b, v2.16b, v26.16b\n" - "eor v7.16b, v3.16b, v26.16b\n" - - "saddlp v16.8h, v4.16b\n" - "str q4, [%[packed_ptr], #0]\n" - "saddlp v17.8h, v5.16b\n" - "str q5, [%[packed_ptr], #16]\n" - "saddlp v18.8h, v6.16b\n" - "str q6, [%[packed_ptr], #32]\n" - "saddlp v19.8h, v7.16b\n" - "str q7, [%[packed_ptr], #48]\n" - "sadalp v28.4s, v16.8h\n" - "sadalp v29.4s, v17.8h\n" - "sadalp v30.4s, v18.8h\n" - "sadalp v31.4s, v19.8h\n" - - "add %[packed_ptr], %[packed_ptr], #64\n" - - "3:\n" - - "ands w2, %w[rows], #15\n" - "beq 4f\n" - "dup v0.16b, %w[src_zero_point]\n" - "dup v1.16b, %w[src_zero_point]\n" - "dup v2.16b, %w[src_zero_point]\n" - "dup v3.16b, %w[src_zero_point]\n" -#define RUY_LOAD_ONE_ROW(R) \ - "cmp w2, #" #R "\n" \ - "beq 5f\n" \ - "ld1 { v0.b }[" #R "], [%[src_ptr0]], #1\n" \ - "ld1 { v1.b }[" #R "], [%[src_ptr1]], #1\n" \ - "ld1 { v2.b }[" #R "], [%[src_ptr2]], #1\n" \ - "ld1 { v3.b }[" #R "], [%[src_ptr3]], #1\n" - - RUY_LOAD_ONE_ROW(0) - RUY_LOAD_ONE_ROW(1) - RUY_LOAD_ONE_ROW(2) - RUY_LOAD_ONE_ROW(3) - RUY_LOAD_ONE_ROW(4) - RUY_LOAD_ONE_ROW(5) - RUY_LOAD_ONE_ROW(6) - RUY_LOAD_ONE_ROW(7) - RUY_LOAD_ONE_ROW(8) - RUY_LOAD_ONE_ROW(9) - RUY_LOAD_ONE_ROW(10) - RUY_LOAD_ONE_ROW(11) - RUY_LOAD_ONE_ROW(12) - RUY_LOAD_ONE_ROW(13) - RUY_LOAD_ONE_ROW(14) - RUY_LOAD_ONE_ROW(15) -#undef 
RUY_LOAD_ONE_ROW - "5:\n" - - "eor v4.16b, v0.16b, v26.16b\n" - "eor v5.16b, v1.16b, v26.16b\n" - "eor v6.16b, v2.16b, v26.16b\n" - "eor v7.16b, v3.16b, v26.16b\n" - - "saddlp v16.8h, v4.16b\n" - "saddlp v17.8h, v5.16b\n" - "saddlp v18.8h, v6.16b\n" - "saddlp v19.8h, v7.16b\n" - "sadalp v28.4s, v16.8h\n" - "sadalp v29.4s, v17.8h\n" - "sadalp v30.4s, v18.8h\n" - "sadalp v31.4s, v19.8h\n" - - "str q4, [%[packed_ptr], #0]\n" - "str q5, [%[packed_ptr], #16]\n" - "str q6, [%[packed_ptr], #32]\n" - "str q7, [%[packed_ptr], #48]\n" - "add %[packed_ptr], %[packed_ptr], #64\n" - - "4:\n" - - "addp v28.4s, v28.4s, v29.4s\n" - "addp v30.4s, v30.4s, v31.4s\n" - "addp v28.4s, v28.4s, v30.4s\n" - - "cmp %[sums_ptr], #0\n" - "beq 6f\n" - "st1 {v28.4s}, [%[sums_ptr]], #16\n" - "6:\n" - // clang-format on - - : [ src_ptr0 ] "+r"(src_ptr0), [ src_ptr1 ] "+r"(src_ptr1), - [ src_ptr2 ] "+r"(src_ptr2), [ src_ptr3 ] "+r"(src_ptr3), - [ packed_ptr ] "+r"(packed_ptr), [ sums_ptr ] "+r"(sums_ptr) - : [ src_inc0 ] "r"(static_cast(src_inc0)), - [ src_inc1 ] "r"(static_cast(src_inc1)), - [ src_inc2 ] "r"(static_cast(src_inc2)), - [ src_inc3 ] "r"(static_cast(src_inc3)), - [ rows ] "r"(src_rows), [ src_zero_point ] "r"(src_zero_point), - [ input_xor ] "r"(input_xor) - : "cc", "memory", "x1", "x2", "v0", "v1", "v2", "v3", "v4", "v5", "v6", - "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", - "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", - "v27", "v28", "v29", "v30", "v31"); -} -#endif - -#if RUY_PLATFORM(NEON_32) && RUY_OPT_ENABLED(RUY_OPT_ASM) - -#define RUY_OFFSET_SRC_PTR0 0 -#define RUY_OFFSET_SRC_PTR1 4 -#define RUY_OFFSET_SRC_PTR2 8 -#define RUY_OFFSET_SRC_PTR3 12 -#define RUY_OFFSET_SUMS_PTR 16 -#define RUY_OFFSET_PACKED_PTR 20 -#define RUY_OFFSET_SRC_INC0 24 -#define RUY_OFFSET_SRC_INC1 28 -#define RUY_OFFSET_SRC_INC2 32 -#define RUY_OFFSET_SRC_INC3 36 -#define RUY_OFFSET_SRC_ROWS 40 -#define RUY_OFFSET_SRC_ZERO_POINT 44 -#define RUY_OFFSET_INPUT_XOR 48 - -template -void CheckOffsetsInPackParams8bit(const Params&) { - static_assert(offsetof(Params, src_ptr0) == RUY_OFFSET_SRC_PTR0, ""); - static_assert(offsetof(Params, src_ptr1) == RUY_OFFSET_SRC_PTR1, ""); - static_assert(offsetof(Params, src_ptr2) == RUY_OFFSET_SRC_PTR2, ""); - static_assert(offsetof(Params, src_ptr3) == RUY_OFFSET_SRC_PTR3, ""); - static_assert(offsetof(Params, sums_ptr) == RUY_OFFSET_SUMS_PTR, ""); - static_assert(offsetof(Params, packed_ptr) == RUY_OFFSET_PACKED_PTR, ""); - static_assert(offsetof(Params, src_inc0) == RUY_OFFSET_SRC_INC0, ""); - static_assert(offsetof(Params, src_inc1) == RUY_OFFSET_SRC_INC1, ""); - static_assert(offsetof(Params, src_inc2) == RUY_OFFSET_SRC_INC2, ""); - static_assert(offsetof(Params, src_inc3) == RUY_OFFSET_SRC_INC3, ""); - static_assert(offsetof(Params, src_rows) == RUY_OFFSET_SRC_ROWS, ""); - static_assert(offsetof(Params, src_zero_point) == RUY_OFFSET_SRC_ZERO_POINT, - ""); - static_assert(offsetof(Params, input_xor) == RUY_OFFSET_INPUT_XOR, ""); -} - -// Packing code for out-of-order ARMv7 CPUs like the Krait 400 or A9. -// No attempt made at making this code efficient on in-order cores yet. 
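// Editorial note (added comment): a scalar sketch of what these 8-bit packing routines
// compute per source column, with illustrative names:
//
//   for (int r = 0; r < src_rows; ++r) {
//     std::int8_t v = static_cast<std::int8_t>(src_col[r] ^ input_xor);  // uint8 -> int8 flip
//     *packed_ptr++ = v;   // packed, kernel-friendly layout
//     col_sum += v;        // column sum, used later for zero-point corrections
//   }
//
// Rows past src_rows are padded with src_zero_point (also XOR-ed). The asm below does
// this 16 rows at a time, widening the sums with vpaddl.s8 / vpadal.s16 (saddlp/sadalp
// in the A64 version above).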
-void Pack8bitNeonOutOfOrder4Cols(const PackParams8bit& params) { - CheckOffsetsInPackParams8bit(params); - profiler::ScopeLabel label("Pack (kNeon, optimized for out-of-order cores)"); - const void* src_ptr0 = params.src_ptr0; - const void* src_ptr1 = params.src_ptr1; - const void* src_ptr2 = params.src_ptr2; - const void* src_ptr3 = params.src_ptr3; - const int src_inc0 = params.src_inc0; - const int src_inc1 = params.src_inc1; - const int src_inc2 = params.src_inc2; - const int src_inc3 = params.src_inc3; - const std::int8_t* packed_ptr = params.packed_ptr; - - asm volatile( - // clang-format off - - "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_INPUT_XOR) "]\n" - "vdup.8 q11, r2\n" - "mov r1, #0\n" - // Zero-out the accumulators - "vmov.i32 q12, #0\n" - "vmov.i32 q13, #0\n" - "vmov.i32 q14, #0\n" - "vmov.i32 q15, #0\n" - - // Round down src_rows to nearest multiple of 16. - "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_SRC_ROWS) "]\n" - "and r2, r3, #-16\n" - "cmp r1, r2\n" - "beq 3f\n" - - "1:\n" - "add r1, r1, #16\n" - /* Load q0 */ - "vld1.8 {d0, d1}, [%[src_ptr0]]\n" - "add %[src_ptr0], %[src_ptr0], %[src_inc0]\n" - RUY_PREFETCH_LOAD("pld [%[src_ptr0]]\n") - - /* Load q1 */ - "vld1.8 {d2, d3}, [%[src_ptr1]]\n" - "add %[src_ptr1], %[src_ptr1], %[src_inc1]\n" - RUY_PREFETCH_LOAD("pld [%[src_ptr1]]\n") - - "veor.8 q4, q0, q11\n" - "veor.8 q5, q1, q11\n" - - // Pairwise add in to 16b accumulators. - "vpaddl.s8 q8, q4\n" - "vpaddl.s8 q9, q5\n" - - "vst1.32 {q4}, [%[packed_ptr]]!\n" - "vst1.32 {q5}, [%[packed_ptr]]!\n" - - // Pairwise add accumulate into 32b accumulators. - // q12 and q13 contain 4x32b accumulators - "vpadal.s16 q12, q8\n" - "vpadal.s16 q13, q9\n" - - // Now do the same for src_ptr2 and src_ptr3. - "vld1.8 {d0, d1}, [%[src_ptr2]]\n" - "add %[src_ptr2], %[src_ptr2], %[src_inc2]\n" - RUY_PREFETCH_LOAD("pld [%[src_ptr2]]\n") - - "vld1.8 {d2, d3}, [%[src_ptr3]]\n" - "add %[src_ptr3], %[src_ptr3], %[src_inc3]\n" - RUY_PREFETCH_LOAD("pld [%[src_ptr3]]\n") - - "veor.8 q4, q0, q11\n" - "veor.8 q5, q1, q11\n" - - "vpaddl.s8 q8, q4\n" - "vpaddl.s8 q9, q5\n" - - "vst1.32 {q4}, [%[packed_ptr]]!\n" - "vst1.32 {q5}, [%[packed_ptr]]!\n" - - // Pairwise add accumulate into 32b accumulators. - // q14 and q15 contain 4x32b accumulators - "vpadal.s16 q14, q8\n" - "vpadal.s16 q15, q9\n" - - "cmp r1, r2\n" - "bne 1b\n" - - "3:\n" - - // Now pack the last (num_rows % 16) rows. - "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_SRC_ROWS) "]\n" - "ands r2, r3, #15\n" - "beq 4f\n" - "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_SRC_ZERO_POINT) "]\n" - "vdup.8 q0, r3\n" - "vdup.8 q1, r3\n" - -// First, read/accumulate/write for src_ptr0 and src_ptr1. 
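// Editorial note (added comment): the RUY_LOAD_ONE_ROW1 / RUY_LOAD_ONE_ROW2 macros
// below load the leftover (src_rows % 16) bytes one row at a time into successive
// lanes of d0..d3, branching to label 5 once the residual count in r2 is reached;
// lanes that are never written keep the src_zero_point value dupped just above.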
-#define RUY_LOAD_ONE_ROW1(I, R) \ - "cmp r2, #" #I "\n" \ - "beq 5f\n" \ - "vld1.8 { d0[" #R "]}, [%[src_ptr0]]!\n" \ - "vld1.8 { d2[" #R "]}, [%[src_ptr1]]!\n" \ - - RUY_LOAD_ONE_ROW1(0, 0) - RUY_LOAD_ONE_ROW1(1, 1) - RUY_LOAD_ONE_ROW1(2, 2) - RUY_LOAD_ONE_ROW1(3, 3) - RUY_LOAD_ONE_ROW1(4, 4) - RUY_LOAD_ONE_ROW1(5, 5) - RUY_LOAD_ONE_ROW1(6, 6) - RUY_LOAD_ONE_ROW1(7, 7) -#undef RUY_LOAD_ONE_ROW1 - -#define RUY_LOAD_ONE_ROW2(I, R) \ - "cmp r2, #" #I "\n" \ - "beq 5f\n" \ - "vld1.8 { d1[" #R "]}, [%[src_ptr0]]!\n" \ - "vld1.8 { d3[" #R "]}, [%[src_ptr1]]!\n" \ - - RUY_LOAD_ONE_ROW2(8, 0) - RUY_LOAD_ONE_ROW2(9, 1) - RUY_LOAD_ONE_ROW2(10, 2) - RUY_LOAD_ONE_ROW2(11, 3) - RUY_LOAD_ONE_ROW2(12, 4) - RUY_LOAD_ONE_ROW2(13, 5) - RUY_LOAD_ONE_ROW2(14, 6) - RUY_LOAD_ONE_ROW2(15, 7) -#undef RUY_LOAD_ONE_ROW2 - - "5:\n" - - "veor.16 q4, q0, q11\n" - "veor.16 q5, q1, q11\n" - - "vpaddl.s8 q8, q4\n" - "vpaddl.s8 q9, q5\n" - - // Pairwise add accumulate to 4x32b accumulators. - "vpadal.s16 q12, q8\n" - "vpadal.s16 q13, q9\n" - - "vst1.32 {q4}, [%[packed_ptr]]!\n" - "vst1.32 {q5}, [%[packed_ptr]]!\n" - - // Reset to src_zero for src_ptr2 and src_ptr3. - "vdup.8 q0, r3\n" - "vdup.8 q1, r3\n" - -// Next, read/accumulate/write for src_ptr2 and src_ptr3. -#define RUY_LOAD_ONE_ROW1(I, R) \ - "cmp r2, #" #I "\n" \ - "beq 5f\n" \ - "vld1.8 { d0[" #R "]}, [%[src_ptr2]]!\n" \ - "vld1.8 { d2[" #R "]}, [%[src_ptr3]]!\n" \ - - RUY_LOAD_ONE_ROW1(0, 0) - RUY_LOAD_ONE_ROW1(1, 1) - RUY_LOAD_ONE_ROW1(2, 2) - RUY_LOAD_ONE_ROW1(3, 3) - RUY_LOAD_ONE_ROW1(4, 4) - RUY_LOAD_ONE_ROW1(5, 5) - RUY_LOAD_ONE_ROW1(6, 6) - RUY_LOAD_ONE_ROW1(7, 7) -#undef RUY_LOAD_ONE_ROW1 - -#define RUY_LOAD_ONE_ROW2(I, R) \ - "cmp r2, #" #I "\n" \ - "beq 5f\n" \ - "vld1.8 { d1[" #R "]}, [%[src_ptr2]]!\n" \ - "vld1.8 { d3[" #R "]}, [%[src_ptr3]]!\n" \ - - RUY_LOAD_ONE_ROW2(8, 0) - RUY_LOAD_ONE_ROW2(9, 1) - RUY_LOAD_ONE_ROW2(10, 2) - RUY_LOAD_ONE_ROW2(11, 3) - RUY_LOAD_ONE_ROW2(12, 4) - RUY_LOAD_ONE_ROW2(13, 5) - RUY_LOAD_ONE_ROW2(14, 6) - RUY_LOAD_ONE_ROW2(15, 7) -#undef RUY_LOAD_ONE_ROW2 - - "5:\n" - - "veor.16 q4, q0, q11\n" - "veor.16 q5, q1, q11\n" - - "vpaddl.s8 q8, q4\n" - "vpaddl.s8 q9, q5\n" - - // Pairwise add accumulate to 4x32b accumulators. - "vpadal.s16 q14, q8\n" - "vpadal.s16 q15, q9\n" - - "vst1.32 {q4}, [%[packed_ptr]]!\n" - "vst1.32 {q5}, [%[packed_ptr]]!\n" - - "4:\n" - // Pairwise add 32-bit accumulators - "vpadd.i32 d24, d24, d25\n" - "vpadd.i32 d26, d26, d27\n" - "vpadd.i32 d28, d28, d29\n" - "vpadd.i32 d30, d30, d31\n" - // Final 32-bit values per row - "vpadd.i32 d25, d24, d26\n" - "vpadd.i32 d27, d28, d30\n" - - "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_SUMS_PTR) "]\n" - "cmp r3, #0\n" - "beq 6f\n" - "vst1.32 {d25}, [r3]!\n" - "vst1.32 {d27}, [r3]!\n" - "6:\n" - // clang-format on - - : [ src_ptr0 ] "+r"(src_ptr0), [ src_ptr1 ] "+r"(src_ptr1), - [ src_ptr2 ] "+r"(src_ptr2), [ src_ptr3 ] "+r"(src_ptr3) - : [ src_inc0 ] "r"(src_inc0), [ src_inc1 ] "r"(src_inc1), - [ src_inc2 ] "r"(src_inc2), [ src_inc3 ] "r"(src_inc3), - [ packed_ptr ] "r"(packed_ptr), [ params ] "r"(¶ms) - : "cc", "memory", "r1", "r2", "r3", "q0", "q1", "q2", "q3", - "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13"); -} - -// Packing code for out-of-order ARMv7 CPUs like the Krait 400 or A9. -// No attempt made at making this code efficient on in-order cores yet. -// This version differs from the above in that we only handle two columns -// at a time. 
-void Pack8bitNeonOutOfOrder2Cols(const PackParams8bit& params) { - CheckOffsetsInPackParams8bit(params); - profiler::ScopeLabel label("Pack (kNeon, optimized for out-of-order cores)"); - const void* src_ptr0 = params.src_ptr0; - const void* src_ptr1 = params.src_ptr1; - const int src_inc0 = params.src_inc0; - const int src_inc1 = params.src_inc1; - const std::int8_t* packed_ptr = params.packed_ptr; - - asm volatile( - // clang-format off - - "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_INPUT_XOR) "]\n" - "vdup.8 q11, r2\n" - "mov r1, #0\n" - // Zero-out the accumulators - "vmov.i32 q12, #0\n" - "vmov.i32 q13, #0\n" - - // Round down src_rows to nearest multiple of 16. - "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_SRC_ROWS) "]\n" - "and r2, r3, #-16\n" - "cmp r1, r2\n" - "beq 3f\n" - - "1:\n" - "add r1, r1, #16\n" - /* Load q0 */ - "vld1.8 {d0, d1}, [%[src_ptr0]]\n" - "add %[src_ptr0], %[src_ptr0], %[src_inc0]\n" - - /* Load q1 */ - "vld1.8 {d2, d3}, [%[src_ptr1]]\n" - "add %[src_ptr1], %[src_ptr1], %[src_inc1]\n" - - "veor.8 q4, q0, q11\n" - "veor.8 q5, q1, q11\n" - - // Pairwise add in to 16b accumulators. - "vpaddl.s8 q8, q4\n" - "vpaddl.s8 q9, q5\n" - - "vst1.32 {q4}, [%[packed_ptr]]!\n" - "vst1.32 {q5}, [%[packed_ptr]]!\n" - - // Pairwise add accumulate into 32b accumulators. - // q12 and q13 contain 4x32b accumulators - "vpadal.s16 q12, q8\n" - "vpadal.s16 q13, q9\n" - - "cmp r1, r2\n" - - "bne 1b\n" - - "3:\n" - - // Now pack the last (num_rows % 16) rows. - "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_SRC_ROWS) "]\n" - "ands r2, r3, #15\n" - "beq 4f\n" - "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_SRC_ZERO_POINT) "]\n" - "vdup.8 q0, r3\n" - "vdup.8 q1, r3\n" - -// Read/accumulate/write for src_ptr0 and src_ptr1. -#define RUY_LOAD_ONE_ROW1(I, R) \ - "cmp r2, #" #I "\n" \ - "beq 5f\n" \ - "vld1.8 { d0[" #R "]}, [%[src_ptr0]]!\n" \ - "vld1.8 { d2[" #R "]}, [%[src_ptr1]]!\n" \ - - RUY_LOAD_ONE_ROW1(0, 0) - RUY_LOAD_ONE_ROW1(1, 1) - RUY_LOAD_ONE_ROW1(2, 2) - RUY_LOAD_ONE_ROW1(3, 3) - RUY_LOAD_ONE_ROW1(4, 4) - RUY_LOAD_ONE_ROW1(5, 5) - RUY_LOAD_ONE_ROW1(6, 6) - RUY_LOAD_ONE_ROW1(7, 7) -#undef RUY_LOAD_ONE_ROW1 - -#define RUY_LOAD_ONE_ROW2(I, R) \ - "cmp r2, #" #I "\n" \ - "beq 5f\n" \ - "vld1.8 { d1[" #R "]}, [%[src_ptr0]]!\n" \ - "vld1.8 { d3[" #R "]}, [%[src_ptr1]]!\n" \ - - RUY_LOAD_ONE_ROW2(8, 0) - RUY_LOAD_ONE_ROW2(9, 1) - RUY_LOAD_ONE_ROW2(10, 2) - RUY_LOAD_ONE_ROW2(11, 3) - RUY_LOAD_ONE_ROW2(12, 4) - RUY_LOAD_ONE_ROW2(13, 5) - RUY_LOAD_ONE_ROW2(14, 6) - RUY_LOAD_ONE_ROW2(15, 7) -#undef RUY_LOAD_ONE_ROW2 - - "5:\n" - - "veor.16 q4, q0, q11\n" - "veor.16 q5, q1, q11\n" - - "vpaddl.s8 q8, q4\n" - "vpaddl.s8 q9, q5\n" - - - // Pairwise add accumulate to 4x32b accumulators. 
- "vpadal.s16 q12, q8\n" - "vpadal.s16 q13, q9\n" - - "vst1.32 {q4}, [%[packed_ptr]]!\n" - "vst1.32 {q5}, [%[packed_ptr]]!\n" - - "4:\n" - - // Pairwise add 32-bit accumulators - "vpadd.i32 d24, d24, d25\n" - "vpadd.i32 d26, d26, d27\n" - // Final 32-bit values per row - "vpadd.i32 d25, d24, d26\n" - - "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_SUMS_PTR) "]\n" - "cmp r3, #0\n" - "beq 6f\n" - "vst1.32 {d25}, [r3]!\n" - "6:\n" - // clang-format on - - : [ src_ptr0 ] "+r"(src_ptr0), [ src_ptr1 ] "+r"(src_ptr1) - : [ src_inc0 ] "r"(src_inc0), [ src_inc1 ] "r"(src_inc1), - [ packed_ptr ] "r"(packed_ptr), [ params ] "r"(¶ms) - : "cc", "memory", "r1", "r2", "r3", "q0", "q1", "q2", "q3", - "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13"); -} - -#undef RUY_OFFSET_SRC_PTR0 -#undef RUY_OFFSET_SRC_PTR1 -#undef RUY_OFFSET_SRC_PTR2 -#undef RUY_OFFSET_SRC_PTR32 -#undef RUY_OFFSET_SUMS_PTR -#undef RUY_OFFSET_PACKED_PTR0 -#undef RUY_OFFSET_SRC_INC0 -#undef RUY_OFFSET_SRC_INC1 -#undef RUY_OFFSET_SRC_INC2 -#undef RUY_OFFSET_SRC_INC3 -#undef RUY_OFFSET_SRC_ROWS -#undef RUY_OFFSET_SRC_ZERO_POINT -#undef RUY_OFFSET_INPUT_XOR - -#endif // RUY_PLATFORM(NEON_32) && RUY_OPT_ENABLED(RUY_OPT_ASM) - -#if RUY_PLATFORM(NEON_64) && RUY_OPT_ENABLED(RUY_OPT_ASM) - -void Pack8bitNeonInOrder(const void* src_ptr0, const void* src_ptr1, - const void* src_ptr2, const void* src_ptr3, - int src_inc0, int src_inc1, int src_inc2, int src_inc3, - int src_rows, int src_zero_point, - std::int8_t* packed_ptr, int start_col, int end_col, - std::int32_t* sums_ptr, int input_xor) { - profiler::ScopeLabel label("Pack (kNeon, optimized for in-order cores)"); - asm volatile( - // clang-format off - "dup v26.16b, %w[input_xor]\n" - "mov w1, #0\n" - "dup v28.4s, wzr\n" - "dup v29.4s, wzr\n" - "dup v30.4s, wzr\n" - "dup v31.4s, wzr\n" - - "and w2, %w[rows], #-16\n" - "cmp w1, w2\n" - "beq 3f\n" - "ldr x10, [%[src_ptr0], #8]\n" - "ld1 {v0.8b}, [%[src_ptr0]], %[src_inc0]\n" - "ldr x11, [%[src_ptr1], #8]\n" - "ld1 {v1.8b}, [%[src_ptr1]], %[src_inc1]\n" - "ldr x12, [%[src_ptr2], #8]\n" - "ld1 {v2.8b}, [%[src_ptr2]], %[src_inc2]\n" - "ldr x13, [%[src_ptr3], #8]\n" - "ld1 {v3.8b}, [%[src_ptr3]], %[src_inc3]\n" - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #64]\n") - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #64]\n") - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #64]\n") - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #64]\n") - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #128]\n") - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #128]\n") - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #128]\n") - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #128]\n") - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #192]\n") - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #192]\n") - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #192]\n") - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #192]\n") - "add w1, w1, #16\n" - "cmp w1, w2\n" - - "beq 2f\n" - - "1:\n" - "add w1, w1, #16\n" - "ins v0.d[1], x10\n" - "ldr x10, [%[src_ptr0], #8]\n" - "ins v1.d[1], x11\n" - "ldr x11, [%[src_ptr1], #8]\n" - "ins v2.d[1], x12\n" - "ldr x12, [%[src_ptr2], #8]\n" - "ins v3.d[1], x13\n" - "ldr x13, [%[src_ptr3], #8]\n" - "eor v4.16b, v0.16b, v26.16b\n" - "ld1 {v0.8b}, [%[src_ptr0]], %[src_inc0]\n" - "eor v5.16b, v1.16b, v26.16b\n" - "ld1 {v1.8b}, [%[src_ptr1]], %[src_inc1]\n" - "eor v6.16b, v2.16b, v26.16b\n" - "ld1 {v2.8b}, [%[src_ptr2]], %[src_inc2]\n" - "eor v7.16b, v3.16b, v26.16b\n" - "ld1 {v3.8b}, [%[src_ptr3]], 
%[src_inc3]\n" - "saddlp v16.8h, v4.16b\n" - "str q4, [%[packed_ptr], #0]\n" - "saddlp v17.8h, v5.16b\n" - "str q5, [%[packed_ptr], #16]\n" - "saddlp v18.8h, v6.16b\n" - "str q6, [%[packed_ptr], #32]\n" - "saddlp v19.8h, v7.16b\n" - "str q7, [%[packed_ptr], #48]\n" - "sadalp v28.4s, v16.8h\n" - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #240]\n") - "cmp w1, w2\n" - "sadalp v29.4s, v17.8h\n" - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #240]\n") - "add %[packed_ptr], %[packed_ptr], #64\n" - "sadalp v30.4s, v18.8h\n" - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #240]\n") - "sadalp v31.4s, v19.8h\n" - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #240]\n") - - "bne 1b\n" - - "2:\n" - "ins v0.d[1], x10\n" - "ins v1.d[1], x11\n" - "ins v2.d[1], x12\n" - "ins v3.d[1], x13\n" - "eor v4.16b, v0.16b, v26.16b\n" - "eor v5.16b, v1.16b, v26.16b\n" - "eor v6.16b, v2.16b, v26.16b\n" - "eor v7.16b, v3.16b, v26.16b\n" - - "saddlp v16.8h, v4.16b\n" - "str q4, [%[packed_ptr], #0]\n" - "saddlp v17.8h, v5.16b\n" - "str q5, [%[packed_ptr], #16]\n" - "saddlp v18.8h, v6.16b\n" - "str q6, [%[packed_ptr], #32]\n" - "saddlp v19.8h, v7.16b\n" - "str q7, [%[packed_ptr], #48]\n" - "sadalp v28.4s, v16.8h\n" - "sadalp v29.4s, v17.8h\n" - "sadalp v30.4s, v18.8h\n" - "sadalp v31.4s, v19.8h\n" - - "add %[packed_ptr], %[packed_ptr], #64\n" - - "3:\n" - - "ands w2, %w[rows], #15\n" - "beq 4f\n" - "dup v0.16b, %w[src_zero_point]\n" - "dup v1.16b, %w[src_zero_point]\n" - "dup v2.16b, %w[src_zero_point]\n" - "dup v3.16b, %w[src_zero_point]\n" -#define RUY_LOAD_ONE_ROW(R) \ - "cmp w2, #" #R "\n" \ - "beq 5f\n" \ - "ld1 { v0.b }[" #R "], [%[src_ptr0]], #1\n" \ - "ld1 { v1.b }[" #R "], [%[src_ptr1]], #1\n" \ - "ld1 { v2.b }[" #R "], [%[src_ptr2]], #1\n" \ - "ld1 { v3.b }[" #R "], [%[src_ptr3]], #1\n" - - RUY_LOAD_ONE_ROW(0) - RUY_LOAD_ONE_ROW(1) - RUY_LOAD_ONE_ROW(2) - RUY_LOAD_ONE_ROW(3) - RUY_LOAD_ONE_ROW(4) - RUY_LOAD_ONE_ROW(5) - RUY_LOAD_ONE_ROW(6) - RUY_LOAD_ONE_ROW(7) - RUY_LOAD_ONE_ROW(8) - RUY_LOAD_ONE_ROW(9) - RUY_LOAD_ONE_ROW(10) - RUY_LOAD_ONE_ROW(11) - RUY_LOAD_ONE_ROW(12) - RUY_LOAD_ONE_ROW(13) - RUY_LOAD_ONE_ROW(14) - RUY_LOAD_ONE_ROW(15) -#undef RUY_LOAD_ONE_ROW - "5:\n" - - "eor v4.16b, v0.16b, v26.16b\n" - "eor v5.16b, v1.16b, v26.16b\n" - "eor v6.16b, v2.16b, v26.16b\n" - "eor v7.16b, v3.16b, v26.16b\n" - - "saddlp v16.8h, v4.16b\n" - "saddlp v17.8h, v5.16b\n" - "saddlp v18.8h, v6.16b\n" - "saddlp v19.8h, v7.16b\n" - "sadalp v28.4s, v16.8h\n" - "sadalp v29.4s, v17.8h\n" - "sadalp v30.4s, v18.8h\n" - "sadalp v31.4s, v19.8h\n" - - "str q4, [%[packed_ptr], #0]\n" - "str q5, [%[packed_ptr], #16]\n" - "str q6, [%[packed_ptr], #32]\n" - "str q7, [%[packed_ptr], #48]\n" - "add %[packed_ptr], %[packed_ptr], #64\n" - - "4:\n" - - "addp v28.4s, v28.4s, v29.4s\n" - "addp v30.4s, v30.4s, v31.4s\n" - "addp v28.4s, v28.4s, v30.4s\n" - - "cmp %[sums_ptr], #0\n" - "beq 6f\n" - "st1 {v28.4s}, [%[sums_ptr]], #16\n" - "6:\n" - // clang-format on - - : [ src_ptr0 ] "+r"(src_ptr0), [ src_ptr1 ] "+r"(src_ptr1), - [ src_ptr2 ] "+r"(src_ptr2), [ src_ptr3 ] "+r"(src_ptr3), - [ packed_ptr ] "+r"(packed_ptr), [ sums_ptr ] "+r"(sums_ptr) - : [ src_inc0 ] "r"(static_cast(src_inc0)), [ src_inc1 ] "r"(static_cast(src_inc1)), - [ src_inc2 ] "r"(static_cast(src_inc2)), [ src_inc3 ] "r"(static_cast(src_inc3)), - [ rows ] "r"(src_rows), - [ src_zero_point ] "r"(src_zero_point), - [input_xor] "r"(input_xor) - : "cc", "memory", "x1", "x2", "x10", "x11", "x12", "x13", "v0", "v1", "v2", "v3", "v4", "v5", - "v6", "v7", "v8", "v9", 
"v10", "v11", "v12", "v13", "v14", "v15", - "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", - "v25", "v26", "v27", "v28", "v29", "v30", "v31"); -} - -void Pack8bitNeonDotprodInOrder(const void* src_ptr0, const void* src_ptr1, - const void* src_ptr2, const void* src_ptr3, - int src_inc0, int src_inc1, int src_inc2, - int src_inc3, int src_rows, int src_zero_point, - std::int8_t* packed_ptr, int start_col, - int end_col, std::int32_t* sums_ptr, - int input_xor) { - profiler::ScopeLabel label( - "Pack (kNeonDotprod, optimized for in-order cores)"); - asm volatile( - // clang-format off - "dup v26.16b, %w[input_xor]\n" - "mov w1, #1\n" - "dup v27.16b, w1\n" - "mov w1, #0\n" - "dup v28.4s, wzr\n" - "dup v29.4s, wzr\n" - "dup v30.4s, wzr\n" - "dup v31.4s, wzr\n" - - "and w2, %w[rows], #-16\n" - "cmp w1, w2\n" - "beq 3f\n" - "ldr x10, [%[src_ptr0], #8]\n" - "ld1 {v0.8b}, [%[src_ptr0]], %[src_inc0]\n" - "ldr x11, [%[src_ptr1], #8]\n" - "ld1 {v1.8b}, [%[src_ptr1]], %[src_inc1]\n" - "ldr x12, [%[src_ptr2], #8]\n" - "ld1 {v2.8b}, [%[src_ptr2]], %[src_inc2]\n" - "ldr x13, [%[src_ptr3], #8]\n" - "ld1 {v3.8b}, [%[src_ptr3]], %[src_inc3]\n" - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #64]\n") - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #64]\n") - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #64]\n") - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #64]\n") - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #128]\n") - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #128]\n") - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #128]\n") - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #128]\n") - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #192]\n") - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #192]\n") - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #192]\n") - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #192]\n") - "add w1, w1, #16\n" - "cmp w1, w2\n" - - "beq 2f\n" - - "1:\n" - "add w1, w1, #16\n" - "ins v0.d[1], x10\n" - "ldr x10, [%[src_ptr0], #8]\n" - "ins v1.d[1], x11\n" - "ldr x11, [%[src_ptr1], #8]\n" - "ins v2.d[1], x12\n" - "ldr x12, [%[src_ptr2], #8]\n" - "ins v3.d[1], x13\n" - "ldr x13, [%[src_ptr3], #8]\n" - - "eor v4.16b, v0.16b, v26.16b\n" - "ld1 {v0.8b}, [%[src_ptr0]], %[src_inc0]\n" - "eor v5.16b, v1.16b, v26.16b\n" - "ld1 {v1.8b}, [%[src_ptr1]], %[src_inc1]\n" - "eor v6.16b, v2.16b, v26.16b\n" - "ld1 {v2.8b}, [%[src_ptr2]], %[src_inc2]\n" - "eor v7.16b, v3.16b, v26.16b\n" - "ld1 {v3.8b}, [%[src_ptr3]], %[src_inc3]\n" - - "trn1 v16.4s, v4.4s, v5.4s\n" - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #240]\n") - "trn2 v17.4s, v4.4s, v5.4s\n" - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #240]\n") - "trn1 v18.4s, v6.4s, v7.4s\n" - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #240]\n") - "trn2 v19.4s, v6.4s, v7.4s\n" - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #240]\n") - - "trn1 v20.2d, v16.2d, v18.2d\n" - "trn2 v22.2d, v16.2d, v18.2d\n" - "trn1 v21.2d, v17.2d, v19.2d\n" - "trn2 v23.2d, v17.2d, v19.2d\n" - "cmp w1, w2\n" - - ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n" - "str q20, [%[packed_ptr], #0]\n" - ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n" - "str q21, [%[packed_ptr], #32]\n" - ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n" - "str q22, [%[packed_ptr], #64]\n" - ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n" - "str q23, [%[packed_ptr], #96]\n" - - "add %[packed_ptr], %[packed_ptr], #128\n" - - "bne 1b\n" - - "2:\n" - "ins v0.d[1], x10\n" - "ins v1.d[1], x11\n" - "ins v2.d[1], 
x12\n" - "ins v3.d[1], x13\n" - "eor v0.16b, v0.16b, v26.16b\n" - "eor v1.16b, v1.16b, v26.16b\n" - "eor v2.16b, v2.16b, v26.16b\n" - "eor v3.16b, v3.16b, v26.16b\n" - - "trn1 v16.4s, v0.4s, v1.4s\n" - "trn2 v17.4s, v0.4s, v1.4s\n" - "trn1 v18.4s, v2.4s, v3.4s\n" - "trn2 v19.4s, v2.4s, v3.4s\n" - - "trn1 v20.2d, v16.2d, v18.2d\n" - "trn2 v22.2d, v16.2d, v18.2d\n" - "trn1 v21.2d, v17.2d, v19.2d\n" - "trn2 v23.2d, v17.2d, v19.2d\n" - - ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n" - "str q20, [%[packed_ptr], #0]\n" - ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n" - "str q21, [%[packed_ptr], #32]\n" - ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n" - "str q22, [%[packed_ptr], #64]\n" - ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n" - "str q23, [%[packed_ptr], #96]\n" - "add %[packed_ptr], %[packed_ptr], #128\n" - - "3:\n" - - "ands w2, %w[rows], #15\n" - "beq 4f\n" - "dup v0.16b, %w[src_zero_point]\n" - "dup v1.16b, %w[src_zero_point]\n" - "dup v2.16b, %w[src_zero_point]\n" - "dup v3.16b, %w[src_zero_point]\n" -#define RUY_LOAD_ONE_ROW(R) \ - "cmp w2, #" #R "\n" \ - "beq 5f\n" \ - "ld1 { v0.b }[" #R "], [%[src_ptr0]], #1\n" \ - "ld1 { v1.b }[" #R "], [%[src_ptr1]], #1\n" \ - "ld1 { v2.b }[" #R "], [%[src_ptr2]], #1\n" \ - "ld1 { v3.b }[" #R "], [%[src_ptr3]], #1\n" - - RUY_LOAD_ONE_ROW(0) - RUY_LOAD_ONE_ROW(1) - RUY_LOAD_ONE_ROW(2) - RUY_LOAD_ONE_ROW(3) - RUY_LOAD_ONE_ROW(4) - RUY_LOAD_ONE_ROW(5) - RUY_LOAD_ONE_ROW(6) - RUY_LOAD_ONE_ROW(7) - RUY_LOAD_ONE_ROW(8) - RUY_LOAD_ONE_ROW(9) - RUY_LOAD_ONE_ROW(10) - RUY_LOAD_ONE_ROW(11) - RUY_LOAD_ONE_ROW(12) - RUY_LOAD_ONE_ROW(13) - RUY_LOAD_ONE_ROW(14) - RUY_LOAD_ONE_ROW(15) -#undef RUY_LOAD_ONE_ROW - "5:\n" - - "eor v0.16b, v0.16b, v26.16b\n" - "eor v1.16b, v1.16b, v26.16b\n" - "eor v2.16b, v2.16b, v26.16b\n" - "eor v3.16b, v3.16b, v26.16b\n" - - "trn1 v16.4s, v0.4s, v1.4s\n" - "trn2 v17.4s, v0.4s, v1.4s\n" - "trn1 v18.4s, v2.4s, v3.4s\n" - "trn2 v19.4s, v2.4s, v3.4s\n" - - "trn1 v20.2d, v16.2d, v18.2d\n" - "trn2 v22.2d, v16.2d, v18.2d\n" - "trn1 v21.2d, v17.2d, v19.2d\n" - "trn2 v23.2d, v17.2d, v19.2d\n" - - ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n" - "str q20, [%[packed_ptr], #0]\n" - "cmp w2, #4\n" - "ble 4f\n" - ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n" - "str q21, [%[packed_ptr], #32]\n" - "cmp w2, #8\n" - "ble 4f\n" - ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n" - "str q22, [%[packed_ptr], #64]\n" - "cmp w2, #12\n" - "ble 4f\n" - ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n" - "str q23, [%[packed_ptr], #96]\n" - "add %[packed_ptr], %[packed_ptr], #128\n" - - "4:\n" - - "add v28.4s, v28.4s, v29.4s\n" - "add v30.4s, v30.4s, v31.4s\n" - "add v28.4s, v28.4s, v30.4s\n" - - "cmp %[sums_ptr], #0\n" - "beq 6f\n" - "st1 {v28.4s}, [%[sums_ptr]], #16\n" - "6:\n" - // clang-format on - - : [ src_ptr0 ] "+r"(src_ptr0), [src_ptr1] "+r"(src_ptr1), [src_ptr2] "+r"(src_ptr2), - [src_ptr3] "+r"(src_ptr3), [packed_ptr] "+r"(packed_ptr), [sums_ptr] "+r"(sums_ptr) - : [ src_inc0 ] "r"(static_cast(src_inc0)), [ src_inc1 ] "r"(static_cast(src_inc1)), - [ src_inc2 ] "r"(static_cast(src_inc2)), [ src_inc3 ] "r"(static_cast(src_inc3)), - [rows] "r"(src_rows), - [src_zero_point] "r"(static_cast(src_zero_point)), - [input_xor] "r"(input_xor) - : "cc", "memory", "x1", "x2", "x10", "x11", "x12", "x13", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", - "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", - "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", 
"v30", "v31"); -} - -void Pack8bitNeonDotprodOutOfOrder(const void* src_ptr0, const void* src_ptr1, - const void* src_ptr2, const void* src_ptr3, - int src_inc0, int src_inc1, int src_inc2, - int src_inc3, int src_rows, - int src_zero_point, std::int8_t* packed_ptr, - int start_col, int end_col, - std::int32_t* sums_ptr, int input_xor) { - profiler::ScopeLabel label( - "Pack (kNeonDotprod, optimized for out-of-order cores)"); - asm volatile( - // clang-format off - "dup v26.16b, %w[input_xor]\n" - "mov w1, #1\n" - "dup v27.16b, w1\n" - "mov w1, #0\n" - "dup v28.4s, wzr\n" - "dup v29.4s, wzr\n" - "dup v30.4s, wzr\n" - "dup v31.4s, wzr\n" - -#if RUY_OPT_ENABLED(RUY_OPT_MAX_STREAMING) - "and w2, %w[rows], #-64\n" - "cmp w1, w2\n" - "beq 9f\n" - - "ld1 {v0.16b}, [%[src_ptr0]], %[src_inc0]\n" - "ld1 {v1.16b}, [%[src_ptr1]], %[src_inc1]\n" - "ld1 {v2.16b}, [%[src_ptr2]], %[src_inc2]\n" - "ld1 {v3.16b}, [%[src_ptr3]], %[src_inc3]\n" - "ld1 {v4.16b}, [%[src_ptr0]], %[src_inc0]\n" - "ld1 {v5.16b}, [%[src_ptr1]], %[src_inc1]\n" - "ld1 {v6.16b}, [%[src_ptr2]], %[src_inc2]\n" - "ld1 {v7.16b}, [%[src_ptr3]], %[src_inc3]\n" - "ld1 {v8.16b}, [%[src_ptr0]], %[src_inc0]\n" - "ld1 {v9.16b}, [%[src_ptr1]], %[src_inc1]\n" - "ld1 {v10.16b}, [%[src_ptr2]], %[src_inc2]\n" - "ld1 {v11.16b}, [%[src_ptr3]], %[src_inc3]\n" - "ld1 {v12.16b}, [%[src_ptr0]], %[src_inc0]\n" - "ld1 {v13.16b}, [%[src_ptr1]], %[src_inc1]\n" - "ld1 {v14.16b}, [%[src_ptr2]], %[src_inc2]\n" - "ld1 {v15.16b}, [%[src_ptr3]], %[src_inc3]\n" - "add w1, w1, #64\n" - "cmp w1, w2\n" - "beq 8f\n" - - "7:\n" - "eor v0.16b, v0.16b, v26.16b\n" - "eor v1.16b, v1.16b, v26.16b\n" - "eor v2.16b, v2.16b, v26.16b\n" - "eor v3.16b, v3.16b, v26.16b\n" - - "trn1 v16.4s, v0.4s, v1.4s\n" - "trn2 v17.4s, v0.4s, v1.4s\n" - "trn1 v18.4s, v2.4s, v3.4s\n" - "trn2 v19.4s, v2.4s, v3.4s\n" - - "ld1 {v0.16b}, [%[src_ptr0]], %[src_inc0]\n" - "ld1 {v1.16b}, [%[src_ptr1]], %[src_inc1]\n" - "ld1 {v2.16b}, [%[src_ptr2]], %[src_inc2]\n" - "ld1 {v3.16b}, [%[src_ptr3]], %[src_inc3]\n" - "add w1, w1, #16\n" - - "trn1 v20.2d, v16.2d, v18.2d\n" - "trn2 v22.2d, v16.2d, v18.2d\n" - "trn1 v21.2d, v17.2d, v19.2d\n" - "trn2 v23.2d, v17.2d, v19.2d\n" - - ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n" - ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n" - ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n" - ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n" - - "str q20, [%[packed_ptr], #0]\n" - "str q21, [%[packed_ptr], #32]\n" - "str q22, [%[packed_ptr], #64]\n" - "str q23, [%[packed_ptr], #96]\n" - "add %[packed_ptr], %[packed_ptr], #128\n" - - "eor v4.16b, v4.16b, v26.16b\n" - "eor v5.16b, v5.16b, v26.16b\n" - "eor v6.16b, v6.16b, v26.16b\n" - "eor v7.16b, v7.16b, v26.16b\n" - - "trn1 v16.4s, v4.4s, v5.4s\n" - "trn2 v17.4s, v4.4s, v5.4s\n" - "trn1 v18.4s, v6.4s, v7.4s\n" - "trn2 v19.4s, v6.4s, v7.4s\n" - - "ld1 {v4.16b}, [%[src_ptr0]], %[src_inc0]\n" - "ld1 {v5.16b}, [%[src_ptr1]], %[src_inc1]\n" - "ld1 {v6.16b}, [%[src_ptr2]], %[src_inc2]\n" - "ld1 {v7.16b}, [%[src_ptr3]], %[src_inc3]\n" - "add w1, w1, #16\n" - - "trn1 v20.2d, v16.2d, v18.2d\n" - "trn2 v22.2d, v16.2d, v18.2d\n" - "trn1 v21.2d, v17.2d, v19.2d\n" - "trn2 v23.2d, v17.2d, v19.2d\n" - - ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n" - ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n" - ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n" - ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n" - - "str q20, [%[packed_ptr], #0]\n" - "str q21, [%[packed_ptr], #32]\n" - "str q22, [%[packed_ptr], 
#64]\n" - "str q23, [%[packed_ptr], #96]\n" - "add %[packed_ptr], %[packed_ptr], #128\n" - - "eor v8.16b, v8.16b, v26.16b\n" - "eor v9.16b, v9.16b, v26.16b\n" - "eor v10.16b, v10.16b, v26.16b\n" - "eor v11.16b, v11.16b, v26.16b\n" - - "trn1 v16.4s, v8.4s, v9.4s\n" - "trn2 v17.4s, v8.4s, v9.4s\n" - "trn1 v18.4s, v10.4s, v11.4s\n" - "trn2 v19.4s, v10.4s, v11.4s\n" - - "ld1 {v8.16b}, [%[src_ptr0]], %[src_inc0]\n" - "ld1 {v9.16b}, [%[src_ptr1]], %[src_inc1]\n" - "ld1 {v10.16b}, [%[src_ptr2]], %[src_inc2]\n" - "ld1 {v11.16b}, [%[src_ptr3]], %[src_inc3]\n" - "add w1, w1, #16\n" - - "trn1 v20.2d, v16.2d, v18.2d\n" - "trn2 v22.2d, v16.2d, v18.2d\n" - "trn1 v21.2d, v17.2d, v19.2d\n" - "trn2 v23.2d, v17.2d, v19.2d\n" - - ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n" - ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n" - ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n" - ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n" - - "str q20, [%[packed_ptr], #0]\n" - "str q21, [%[packed_ptr], #32]\n" - "str q22, [%[packed_ptr], #64]\n" - "str q23, [%[packed_ptr], #96]\n" - "add %[packed_ptr], %[packed_ptr], #128\n" - - "eor v12.16b, v12.16b, v26.16b\n" - "eor v13.16b, v13.16b, v26.16b\n" - "eor v14.16b, v14.16b, v26.16b\n" - "eor v15.16b, v15.16b, v26.16b\n" - - "trn1 v16.4s, v12.4s, v13.4s\n" - "trn2 v17.4s, v12.4s, v13.4s\n" - "trn1 v18.4s, v14.4s, v15.4s\n" - "trn2 v19.4s, v14.4s, v15.4s\n" - - "ld1 {v12.16b}, [%[src_ptr0]], %[src_inc0]\n" - "ld1 {v13.16b}, [%[src_ptr1]], %[src_inc1]\n" - "ld1 {v14.16b}, [%[src_ptr2]], %[src_inc2]\n" - "ld1 {v15.16b}, [%[src_ptr3]], %[src_inc3]\n" - "add w1, w1, #16\n" - - "trn1 v20.2d, v16.2d, v18.2d\n" - "trn2 v22.2d, v16.2d, v18.2d\n" - "trn1 v21.2d, v17.2d, v19.2d\n" - "trn2 v23.2d, v17.2d, v19.2d\n" - - ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n" - ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n" - ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n" - ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n" - - "str q20, [%[packed_ptr], #0]\n" - "str q21, [%[packed_ptr], #32]\n" - "str q22, [%[packed_ptr], #64]\n" - "str q23, [%[packed_ptr], #96]\n" - "add %[packed_ptr], %[packed_ptr], #128\n" - - "cmp w1, w2\n" - "bne 7b\n" - - "8:\n" - - "eor v0.16b, v0.16b, v26.16b\n" - "eor v1.16b, v1.16b, v26.16b\n" - "eor v2.16b, v2.16b, v26.16b\n" - "eor v3.16b, v3.16b, v26.16b\n" - - "trn1 v16.4s, v0.4s, v1.4s\n" - "trn2 v17.4s, v0.4s, v1.4s\n" - "trn1 v18.4s, v2.4s, v3.4s\n" - "trn2 v19.4s, v2.4s, v3.4s\n" - - "trn1 v20.2d, v16.2d, v18.2d\n" - "trn2 v22.2d, v16.2d, v18.2d\n" - "trn1 v21.2d, v17.2d, v19.2d\n" - "trn2 v23.2d, v17.2d, v19.2d\n" - - ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n" - ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n" - ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n" - ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n" - - "str q20, [%[packed_ptr], #0]\n" - "str q21, [%[packed_ptr], #32]\n" - "str q22, [%[packed_ptr], #64]\n" - "str q23, [%[packed_ptr], #96]\n" - "add %[packed_ptr], %[packed_ptr], #128\n" - - "eor v4.16b, v4.16b, v26.16b\n" - "eor v5.16b, v5.16b, v26.16b\n" - "eor v6.16b, v6.16b, v26.16b\n" - "eor v7.16b, v7.16b, v26.16b\n" - - "trn1 v16.4s, v4.4s, v5.4s\n" - "trn2 v17.4s, v4.4s, v5.4s\n" - "trn1 v18.4s, v6.4s, v7.4s\n" - "trn2 v19.4s, v6.4s, v7.4s\n" - - "trn1 v20.2d, v16.2d, v18.2d\n" - "trn2 v22.2d, v16.2d, v18.2d\n" - "trn1 v21.2d, v17.2d, v19.2d\n" - "trn2 v23.2d, v17.2d, v19.2d\n" - - ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n" - ".word 0x4e9b96dd // sdot v29.4s, 
v22.16b, v27.16b\n" - ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n" - ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n" - - "str q20, [%[packed_ptr], #0]\n" - "str q21, [%[packed_ptr], #32]\n" - "str q22, [%[packed_ptr], #64]\n" - "str q23, [%[packed_ptr], #96]\n" - "add %[packed_ptr], %[packed_ptr], #128\n" - - "eor v8.16b, v8.16b, v26.16b\n" - "eor v9.16b, v9.16b, v26.16b\n" - "eor v10.16b, v10.16b, v26.16b\n" - "eor v11.16b, v11.16b, v26.16b\n" - - "trn1 v16.4s, v8.4s, v9.4s\n" - "trn2 v17.4s, v8.4s, v9.4s\n" - "trn1 v18.4s, v10.4s, v11.4s\n" - "trn2 v19.4s, v10.4s, v11.4s\n" - - "trn1 v20.2d, v16.2d, v18.2d\n" - "trn2 v22.2d, v16.2d, v18.2d\n" - "trn1 v21.2d, v17.2d, v19.2d\n" - "trn2 v23.2d, v17.2d, v19.2d\n" - - ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n" - ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n" - ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n" - ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n" - - "str q20, [%[packed_ptr], #0]\n" - "str q21, [%[packed_ptr], #32]\n" - "str q22, [%[packed_ptr], #64]\n" - "str q23, [%[packed_ptr], #96]\n" - "add %[packed_ptr], %[packed_ptr], #128\n" - - "eor v12.16b, v12.16b, v26.16b\n" - "eor v13.16b, v13.16b, v26.16b\n" - "eor v14.16b, v14.16b, v26.16b\n" - "eor v15.16b, v15.16b, v26.16b\n" - - "trn1 v16.4s, v12.4s, v13.4s\n" - "trn2 v17.4s, v12.4s, v13.4s\n" - "trn1 v18.4s, v14.4s, v15.4s\n" - "trn2 v19.4s, v14.4s, v15.4s\n" - - "trn1 v20.2d, v16.2d, v18.2d\n" - "trn2 v22.2d, v16.2d, v18.2d\n" - "trn1 v21.2d, v17.2d, v19.2d\n" - "trn2 v23.2d, v17.2d, v19.2d\n" - - ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n" - ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n" - ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n" - ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n" - - "str q20, [%[packed_ptr], #0]\n" - "str q21, [%[packed_ptr], #32]\n" - "str q22, [%[packed_ptr], #64]\n" - "str q23, [%[packed_ptr], #96]\n" - "add %[packed_ptr], %[packed_ptr], #128\n" - - "9:\n" -#endif // #if RUY_OPT_ENABLED(RUY_OPT_MAX_STREAMING) - "and w2, %w[rows], #-16\n" - "cmp w1, w2\n" - "beq 3f\n" - - "ld1 {v0.16b}, [%[src_ptr0]], %[src_inc0]\n" - "ld1 {v1.16b}, [%[src_ptr1]], %[src_inc1]\n" - "ld1 {v2.16b}, [%[src_ptr2]], %[src_inc2]\n" - "ld1 {v3.16b}, [%[src_ptr3]], %[src_inc3]\n" - "add w1, w1, #16\n" - "cmp w1, w2\n" - "beq 2f\n" - - "1:\n" - - "eor v0.16b, v0.16b, v26.16b\n" - "eor v1.16b, v1.16b, v26.16b\n" - "eor v2.16b, v2.16b, v26.16b\n" - "eor v3.16b, v3.16b, v26.16b\n" - - "trn1 v16.4s, v0.4s, v1.4s\n" - "trn2 v17.4s, v0.4s, v1.4s\n" - "trn1 v18.4s, v2.4s, v3.4s\n" - "trn2 v19.4s, v2.4s, v3.4s\n" - - "ld1 {v0.16b}, [%[src_ptr0]], %[src_inc0]\n" - "ld1 {v1.16b}, [%[src_ptr1]], %[src_inc1]\n" - "ld1 {v2.16b}, [%[src_ptr2]], %[src_inc2]\n" - "ld1 {v3.16b}, [%[src_ptr3]], %[src_inc3]\n" - "add w1, w1, #16\n" - - "trn1 v20.2d, v16.2d, v18.2d\n" - "trn2 v22.2d, v16.2d, v18.2d\n" - "trn1 v21.2d, v17.2d, v19.2d\n" - "trn2 v23.2d, v17.2d, v19.2d\n" - - ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n" - ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n" - ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n" - ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n" - - "str q20, [%[packed_ptr], #0]\n" - "str q21, [%[packed_ptr], #32]\n" - "str q22, [%[packed_ptr], #64]\n" - "str q23, [%[packed_ptr], #96]\n" - "add %[packed_ptr], %[packed_ptr], #128\n" - - "cmp w1, w2\n" - "bne 1b\n" - - "2:\n" - - "eor v0.16b, v0.16b, v26.16b\n" - "eor v1.16b, v1.16b, v26.16b\n" - "eor v2.16b, v2.16b, v26.16b\n" - "eor 
v3.16b, v3.16b, v26.16b\n" - - "trn1 v16.4s, v0.4s, v1.4s\n" - "trn2 v17.4s, v0.4s, v1.4s\n" - "trn1 v18.4s, v2.4s, v3.4s\n" - "trn2 v19.4s, v2.4s, v3.4s\n" - - "trn1 v20.2d, v16.2d, v18.2d\n" - "trn2 v22.2d, v16.2d, v18.2d\n" - "trn1 v21.2d, v17.2d, v19.2d\n" - "trn2 v23.2d, v17.2d, v19.2d\n" - - ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n" - ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n" - ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n" - ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n" - - "str q20, [%[packed_ptr], #0]\n" - "str q21, [%[packed_ptr], #32]\n" - "str q22, [%[packed_ptr], #64]\n" - "str q23, [%[packed_ptr], #96]\n" - "add %[packed_ptr], %[packed_ptr], #128\n" - - "3:\n" - - "ands w2, %w[rows], #15\n" - "beq 4f\n" - "dup v0.16b, %w[src_zero_point]\n" - "dup v1.16b, %w[src_zero_point]\n" - "dup v2.16b, %w[src_zero_point]\n" - "dup v3.16b, %w[src_zero_point]\n" -#define RUY_LOAD_ONE_ROW(R) \ - "cmp w2, #" #R "\n" \ - "beq 5f\n" \ - "ld1 { v0.b }[" #R "], [%[src_ptr0]], #1\n" \ - "ld1 { v1.b }[" #R "], [%[src_ptr1]], #1\n" \ - "ld1 { v2.b }[" #R "], [%[src_ptr2]], #1\n" \ - "ld1 { v3.b }[" #R "], [%[src_ptr3]], #1\n" - - RUY_LOAD_ONE_ROW(0) - RUY_LOAD_ONE_ROW(1) - RUY_LOAD_ONE_ROW(2) - RUY_LOAD_ONE_ROW(3) - RUY_LOAD_ONE_ROW(4) - RUY_LOAD_ONE_ROW(5) - RUY_LOAD_ONE_ROW(6) - RUY_LOAD_ONE_ROW(7) - RUY_LOAD_ONE_ROW(8) - RUY_LOAD_ONE_ROW(9) - RUY_LOAD_ONE_ROW(10) - RUY_LOAD_ONE_ROW(11) - RUY_LOAD_ONE_ROW(12) - RUY_LOAD_ONE_ROW(13) - RUY_LOAD_ONE_ROW(14) - RUY_LOAD_ONE_ROW(15) -#undef RUY_LOAD_ONE_ROW - "5:\n" - - "eor v0.16b, v0.16b, v26.16b\n" - "eor v1.16b, v1.16b, v26.16b\n" - "eor v2.16b, v2.16b, v26.16b\n" - "eor v3.16b, v3.16b, v26.16b\n" - - "trn1 v16.4s, v0.4s, v1.4s\n" - "trn2 v17.4s, v0.4s, v1.4s\n" - "trn1 v18.4s, v2.4s, v3.4s\n" - "trn2 v19.4s, v2.4s, v3.4s\n" - - "trn1 v20.2d, v16.2d, v18.2d\n" - "trn2 v22.2d, v16.2d, v18.2d\n" - "trn1 v21.2d, v17.2d, v19.2d\n" - "trn2 v23.2d, v17.2d, v19.2d\n" - - ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n" - "str q20, [%[packed_ptr], #0]\n" - "cmp w2, #4\n" - "ble 4f\n" - ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n" - "str q21, [%[packed_ptr], #32]\n" - "cmp w2, #8\n" - "ble 4f\n" - ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n" - "str q22, [%[packed_ptr], #64]\n" - "cmp w2, #12\n" - "ble 4f\n" - ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n" - "str q23, [%[packed_ptr], #96]\n" - "add %[packed_ptr], %[packed_ptr], #128\n" - - "4:\n" - - "add v28.4s, v28.4s, v29.4s\n" - "add v30.4s, v30.4s, v31.4s\n" - "add v28.4s, v28.4s, v30.4s\n" - - "cmp %[sums_ptr], #0\n" - "beq 6f\n" - "st1 {v28.4s}, [%[sums_ptr]], #16\n" - "6:\n" - // clang-format on - - : [ src_ptr0 ] "+r"(src_ptr0), [ src_ptr1 ] "+r"(src_ptr1), - [ src_ptr2 ] "+r"(src_ptr2), [ src_ptr3 ] "+r"(src_ptr3), - [ packed_ptr ] "+r"(packed_ptr), [ sums_ptr ] "+r"(sums_ptr) - : [ src_inc0 ] "r"(static_cast(src_inc0)), - [ src_inc1 ] "r"(static_cast(src_inc1)), - [ src_inc2 ] "r"(static_cast(src_inc2)), - [ src_inc3 ] "r"(static_cast(src_inc3)), - [ rows ] "r"(src_rows), - [ src_zero_point ] "r"(static_cast(src_zero_point)), - [ input_xor ] "r"(input_xor) - : "cc", "memory", "x1", "x2", "v0", "v1", "v2", "v3", "v4", "v5", "v6", - "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", - "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", - "v27", "v28", "v29", "v30", "v31"); -} - -#endif // RUY_PLATFORM(NEON_64) && RUY_OPT_ENABLED(RUY_OPT_ASM) - -#if RUY_PLATFORM(NEON_64) && 
RUY_OPT_ENABLED(RUY_OPT_ASM) -void PackFloatNeonOutOfOrder(const float* src_ptr0, const float* src_ptr1, - const float* src_ptr2, const float* src_ptr3, - int src_inc0, int src_inc1, int src_inc2, - int src_inc3, int src_rows, int src_zero_point, - float* packed_ptr, int start_col, int end_col) { - profiler::ScopeLabel label("Pack (kNeon, optimized for out-of-order cores)"); - asm volatile( - // clang-format off - "mov w1, #0\n" - - "and w2, %w[rows], #-4\n" - "cmp w1, w2\n" - "beq 3f\n" - "ld1 {v0.4s}, [%[src_ptr0]], %[src_inc0]\n" - "ld1 {v1.4s}, [%[src_ptr1]], %[src_inc1]\n" - "ld1 {v2.4s}, [%[src_ptr2]], %[src_inc2]\n" - "ld1 {v3.4s}, [%[src_ptr3]], %[src_inc3]\n" - "add w1, w1, #4\n" - "cmp w1, w2\n" - - "beq 2f\n" - - "1:\n" - "add w1, w1, #4\n" - - "trn1 v16.4s, v0.4s, v1.4s\n" - "trn2 v17.4s, v0.4s, v1.4s\n" - "trn1 v18.4s, v2.4s, v3.4s\n" - "trn2 v19.4s, v2.4s, v3.4s\n" - - "ld1 {v0.4s}, [%[src_ptr0]], %[src_inc0]\n" - "ld1 {v1.4s}, [%[src_ptr1]], %[src_inc1]\n" - "ld1 {v2.4s}, [%[src_ptr2]], %[src_inc2]\n" - "ld1 {v3.4s}, [%[src_ptr3]], %[src_inc3]\n" - - "trn1 v20.2d, v16.2d, v18.2d\n" - "trn2 v22.2d, v16.2d, v18.2d\n" - "trn1 v21.2d, v17.2d, v19.2d\n" - "trn2 v23.2d, v17.2d, v19.2d\n" - "cmp w1, w2\n" - - "str q20, [%[packed_ptr], #0]\n" - "str q21, [%[packed_ptr], #32]\n" - "str q22, [%[packed_ptr], #64]\n" - "str q23, [%[packed_ptr], #96]\n" - - "add %[packed_ptr], %[packed_ptr], #128\n" - - "bne 1b\n" - - "2:\n" - - "trn1 v16.4s, v0.4s, v1.4s\n" - "trn2 v17.4s, v0.4s, v1.4s\n" - "trn1 v18.4s, v2.4s, v3.4s\n" - "trn2 v19.4s, v2.4s, v3.4s\n" - - "trn1 v20.2d, v16.2d, v18.2d\n" - "trn2 v22.2d, v16.2d, v18.2d\n" - "trn1 v21.2d, v17.2d, v19.2d\n" - "trn2 v23.2d, v17.2d, v19.2d\n" - - "str q20, [%[packed_ptr], #0]\n" - "str q21, [%[packed_ptr], #32]\n" - "str q22, [%[packed_ptr], #64]\n" - "str q23, [%[packed_ptr], #96]\n" - "add %[packed_ptr], %[packed_ptr], #128\n" - - "3:\n" - - "ands w2, %w[rows], #3\n" - "beq 4f\n" - "dup v0.16b, wzr\n" - "dup v1.16b, wzr\n" - "dup v2.16b, wzr\n" - "dup v3.16b, wzr\n" -#define RUY_LOAD_ONE_ROW(R) \ - "cmp w2, #" #R "\n" \ - "beq 5f\n" \ - "ld1 { v0.s }[" #R "], [%[src_ptr0]], #4\n" \ - "ld1 { v1.s }[" #R "], [%[src_ptr1]], #4\n" \ - "ld1 { v2.s }[" #R "], [%[src_ptr2]], #4\n" \ - "ld1 { v3.s }[" #R "], [%[src_ptr3]], #4\n" - - RUY_LOAD_ONE_ROW(0) - RUY_LOAD_ONE_ROW(1) - RUY_LOAD_ONE_ROW(2) - RUY_LOAD_ONE_ROW(3) -#undef RUY_LOAD_ONE_ROW - "5:\n" - - "trn1 v16.4s, v0.4s, v1.4s\n" - "trn2 v17.4s, v0.4s, v1.4s\n" - "trn1 v18.4s, v2.4s, v3.4s\n" - "trn2 v19.4s, v2.4s, v3.4s\n" - - "trn1 v20.2d, v16.2d, v18.2d\n" - "trn2 v22.2d, v16.2d, v18.2d\n" - "trn1 v21.2d, v17.2d, v19.2d\n" - "trn2 v23.2d, v17.2d, v19.2d\n" - - "mov x1, #32\n" - -#define RUY_STORE_ONE_ROW(ROW, REGISTER) \ - "cmp w2, #" #ROW "\n" \ - "beq 4f\n" \ - "st1 {" #REGISTER ".4s}, [%[packed_ptr]], x1\n" - - RUY_STORE_ONE_ROW(0, v20) - RUY_STORE_ONE_ROW(1, v21) - RUY_STORE_ONE_ROW(2, v22) - RUY_STORE_ONE_ROW(3, v23) - -#undef RUY_STORE_ONE_ROW - - "4:\n" - - // clang-format on - - : [ src_ptr0 ] "+r"(src_ptr0), [ src_ptr1 ] "+r"(src_ptr1), - [ src_ptr2 ] "+r"(src_ptr2), [ src_ptr3 ] "+r"(src_ptr3), - [ packed_ptr ] "+r"(packed_ptr) - : [ src_inc0 ] "r"(static_cast(src_inc0)), - [ src_inc1 ] "r"(static_cast(src_inc1)), - [ src_inc2 ] "r"(static_cast(src_inc2)), - [ src_inc3 ] "r"(static_cast(src_inc3)), - [ rows ] "r"(src_rows) - : "cc", "memory", "x1", "x2", "x10", "x11", "x12", "x13", "v0", "v1", - "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", - "v13", "v14", 
"v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", - "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"); -} -#endif - -#if RUY_PLATFORM(NEON_32) && RUY_OPT_ENABLED(RUY_OPT_ASM) -void PackFloatNeonOutOfOrder(const float* src_ptr0, const float* src_ptr1, - const float* src_ptr2, const float* src_ptr3, - int src_inc, int src_rows, int src_zero_point, - float* packed_ptr, int start_col, int end_col, - int output_stride) { - profiler::ScopeLabel label("Pack (kNeon, optimized for out-of-order cores)"); - asm volatile( - // clang-format off - "mov r1, #0\n" - "and r2, %[rows], #-4\n" - "cmp r1, r2\n" - "beq 3f\n" -#define RUY_LOAD_FOUR_BY_FOUR() \ - /* Load q0 */ \ - "vld1.32 {d0, d1}, [%[src_ptr0]]\n" \ - /* if src_inc0 != 0, add 16 to src_ptr0 */ \ - "and r3, %[src_inc], #1\n" \ - "add %[src_ptr0], %[src_ptr0], r3, lsl #4\n"\ - /* Load q1 */ \ - "vld1.32 {d2, d3}, [%[src_ptr1]]\n" \ - /* if src_inc1 != 0, add 16 to src_ptr0 */ \ - "and r3, %[src_inc], #2\n" \ - "add %[src_ptr1], %[src_ptr1], r3, lsl #3\n"\ - /* Load q2 */ \ - "vld1.32 {d4, d5}, [%[src_ptr2]]\n" \ - /* if src_inc2 != 0, add 16 to src_ptr0 */ \ - "and r3, %[src_inc], #4\n" \ - "add %[src_ptr2], %[src_ptr2], r3, lsl #2\n"\ - /* Load q3 */ \ - "vld1.32 {d6, d7}, [%[src_ptr3]]\n" \ - /* if src_inc3 != 0, add 16 to src_ptr0 */ \ - "and r3, %[src_inc], #8\n" \ - "add %[src_ptr3], %[src_ptr3], r3, lsl #1\n"\ - - RUY_LOAD_FOUR_BY_FOUR() - "add r1, r1, #4\n" - "cmp r1, r2\n" - - "beq 2f\n" - - "1:\n" - "add r1, r1, #4\n" - - // Transpose 4x4 matrix. - "vzip.32 q0, q1\n" - "vzip.32 q2, q3\n" - - "vtrn.32 q0, q2\n" - "vtrn.32 q1, q3\n" - - "vzip.32 q0, q2\n" - "vzip.32 q1, q3\n" - - "vmov q8, q0\n" - "vmov q9, q1\n" - "vmov q10, q2\n" - "vmov q11, q3\n" - - RUY_LOAD_FOUR_BY_FOUR() -#undef RUY_LOAD_FOUR_BY_FOUR - -#define RUY_STORE_FOUR_BY_FOUR() \ - /* Store q8, q10, q9, q11 */ \ - /* q8 = d16, d17 */ \ - "vst1.32 {d16, d17}, [%[packed_ptr]]\n" \ - /* q10 = d20, d21 */ \ - "add %[packed_ptr], %[packed_ptr], %[stride]\n" \ - "vst1.32 {d20, d21}, [%[packed_ptr]]\n" \ - /* q9 = d18, d19 */ \ - "add %[packed_ptr], %[packed_ptr], %[stride]\n" \ - "vst1.32 {d18, d19}, [%[packed_ptr]]\n" \ - /* q11 = d22, d23 */ \ - "add %[packed_ptr], %[packed_ptr], %[stride]\n" \ - "vst1.32 {d22, d23}, [%[packed_ptr]]\n" \ - "add %[packed_ptr], %[packed_ptr], %[stride]\n" \ - - RUY_STORE_FOUR_BY_FOUR() - "cmp r1, r2\n" - - "bne 1b\n" - - "2:\n" - - // Transpose 4x4 matrix. 
- "vzip.32 q0, q1\n" - "vzip.32 q2, q3\n" - - "vtrn.32 q0, q2\n" - "vtrn.32 q1, q3\n" - - "vzip.32 q0, q2\n" - "vzip.32 q1, q3\n" - - "vmov q8, q0\n" - "vmov q9, q1\n" - "vmov q10, q2\n" - "vmov q11, q3\n" - - RUY_STORE_FOUR_BY_FOUR() -#undef RUY_STORE_FOUR_BY_FOUR - "3:\n" - - "ands r2, %[rows], #3\n" - "beq 4f\n" - "mov r0, #0\n" - // Zero out q0 - q3 - "vdup.32 q0, r0\n" - "vdup.32 q1, r0\n" - "vdup.32 q2, r0\n" - "vdup.32 q3, r0\n" -#define RUY_LOAD_ONE_ROW_FIRST_HALF(R, I) \ - "cmp r2, #" #R "\n" \ - "beq 5f\n" \ - "vld1.32 { d0[" #I "] }, [%[src_ptr0]]!\n" \ - "vld1.32 { d2[" #I "] }, [%[src_ptr1]]!\n" \ - "vld1.32 { d4[" #I "] }, [%[src_ptr2]]!\n" \ - "vld1.32 { d6[" #I "] }, [%[src_ptr3]]!\n" - -#define RUY_LOAD_ONE_ROW_SECOND_HALF(R, I) \ - "vld1.32 { d1[" #I "] }, [%[src_ptr0]]!\n" \ - "vld1.32 { d3[" #I "] }, [%[src_ptr1]]!\n" \ - "vld1.32 { d5[" #I "] }, [%[src_ptr2]]!\n" \ - "vld1.32 { d7[" #I "] }, [%[src_ptr3]]!\n" - - RUY_LOAD_ONE_ROW_FIRST_HALF(0, 0) - RUY_LOAD_ONE_ROW_FIRST_HALF(1, 1) - RUY_LOAD_ONE_ROW_SECOND_HALF(2, 0) - RUY_LOAD_ONE_ROW_SECOND_HALF(3, 1) -#undef RUY_LOAD_ONE_ROW_SECOND_HALF -#undef RUY_LOAD_ONE_ROW_FIRST_HALF - "5:\n" - - // Transpose 4x4 matrix. - "vzip.32 q0, q1\n" - "vzip.32 q2, q3\n" - - "vtrn.32 q0, q2\n" - "vtrn.32 q1, q3\n" - - "vzip.32 q0, q2\n" - "vzip.32 q1, q3\n" - - "vmov q8, q0\n" - "vmov q9, q1\n" - "vmov q10, q2\n" - "vmov q11, q3\n" - - "mov r1, #32\n" - -#define RUY_STORE_ONE_ROW(ROW, REGISTER) \ - "cmp r2, #" #ROW "\n" \ - "beq 4f\n" \ - "vst1.32 {" #REGISTER "}, [%[packed_ptr]]\n" \ - "add %[packed_ptr], %[packed_ptr], %[stride]\n" - - // Store q8 - RUY_STORE_ONE_ROW(0, q8) - // Store q10 - RUY_STORE_ONE_ROW(1, q10) - // Store q9 - RUY_STORE_ONE_ROW(2, q9) - // Store q11 - RUY_STORE_ONE_ROW(3, q11) - -#undef RUY_STORE_ONE_ROW - - "4:\n" - - // clang-format on - : [ src_ptr0 ] "+r"(src_ptr0), [ src_ptr1 ] "+r"(src_ptr1), - [ src_ptr2 ] "+r"(src_ptr2), [ src_ptr3 ] "+r"(src_ptr3), - [ packed_ptr ] "+r"(packed_ptr) - : [ src_inc ] "r"(static_cast(src_inc)), - [ rows ] "r"(src_rows), [ stride ] "r"(output_stride) - : "cc", "memory", "r0", "r1", "r2", "r3", "q0", "q1", "q2", "q3", - "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11"); -} - -#endif // (RUY_PLATFORM(NEON_32) - -#if RUY_PLATFORM(NEON_64) && RUY_OPT_ENABLED(RUY_OPT_ASM) -void PackFloatNeonInOrder(const float* src_ptr0, const float* src_ptr1, - const float* src_ptr2, const float* src_ptr3, - int src_inc0, int src_inc1, int src_inc2, - int src_inc3, int src_rows, int src_zero_point, - float* packed_ptr, int start_col, int end_col) { - profiler::ScopeLabel label("Pack (kNeon, optimized for in-order cores)"); - - asm volatile( - // clang-format off - "mov w1, #0\n" - - "and w2, %w[rows], #-4\n" - "cmp w1, w2\n" - "beq 3f\n" - "ld1 {v0.4s}, [%[src_ptr0]], %[src_inc0]\n" - "ld1 {v1.4s}, [%[src_ptr1]], %[src_inc1]\n" - "ld1 {v2.4s}, [%[src_ptr2]], %[src_inc2]\n" - "ld1 {v3.4s}, [%[src_ptr3]], %[src_inc3]\n" - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #64]\n") - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #64]\n") - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #64]\n") - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #64]\n") - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #128]\n") - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #128]\n") - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #128]\n") - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #128]\n") - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #192]\n") - RUY_PREFETCH_LOAD("prfm pldl1strm, 
[%[src_ptr1], #192]\n") - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #192]\n") - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #192]\n") - "add w1, w1, #4\n" - "cmp w1, w2\n" - - "beq 2f\n" - - "1:\n" - "add w1, w1, #4\n" - - "ldr x10, [%[src_ptr0], #8]\n" - "trn1 v16.4s, v0.4s, v1.4s\n" - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #240]\n") - "ldr x11, [%[src_ptr1], #8]\n" - "trn2 v17.4s, v0.4s, v1.4s\n" - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #240]\n") - "ldr x12, [%[src_ptr2], #8]\n" - "trn1 v18.4s, v2.4s, v3.4s\n" - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #240]\n") - "ldr x13, [%[src_ptr3], #8]\n" - "trn2 v19.4s, v2.4s, v3.4s\n" - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #240]\n") - - "ld1 {v0.2s}, [%[src_ptr0]], %[src_inc0]\n" - "trn1 v20.2d, v16.2d, v18.2d\n" - "ld1 {v1.2s}, [%[src_ptr1]], %[src_inc1]\n" - "trn2 v22.2d, v16.2d, v18.2d\n" - "ld1 {v2.2s}, [%[src_ptr2]], %[src_inc2]\n" - "trn1 v21.2d, v17.2d, v19.2d\n" - "ld1 {v3.2s}, [%[src_ptr3]], %[src_inc3]\n" - "trn2 v23.2d, v17.2d, v19.2d\n" - "cmp w1, w2\n" - - "ins v0.d[1], x10\n" - "str q20, [%[packed_ptr], #0]\n" - "ins v1.d[1], x11\n" - "str q21, [%[packed_ptr], #32]\n" - "ins v2.d[1], x12\n" - "str q22, [%[packed_ptr], #64]\n" - "ins v3.d[1], x13\n" - "str q23, [%[packed_ptr], #96]\n" - - "add %[packed_ptr], %[packed_ptr], #128\n" - - "bne 1b\n" - - "2:\n" - - "trn1 v16.4s, v0.4s, v1.4s\n" - "trn2 v17.4s, v0.4s, v1.4s\n" - "trn1 v18.4s, v2.4s, v3.4s\n" - "trn2 v19.4s, v2.4s, v3.4s\n" - - "trn1 v20.2d, v16.2d, v18.2d\n" - "trn2 v22.2d, v16.2d, v18.2d\n" - "trn1 v21.2d, v17.2d, v19.2d\n" - "trn2 v23.2d, v17.2d, v19.2d\n" - - "str q20, [%[packed_ptr], #0]\n" - "str q21, [%[packed_ptr], #32]\n" - "str q22, [%[packed_ptr], #64]\n" - "str q23, [%[packed_ptr], #96]\n" - "add %[packed_ptr], %[packed_ptr], #128\n" - - "3:\n" - - "ands w2, %w[rows], #3\n" - "beq 4f\n" - "dup v0.16b, wzr\n" - "dup v1.16b, wzr\n" - "dup v2.16b, wzr\n" - "dup v3.16b, wzr\n" -#define RUY_LOAD_ONE_ROW(R) \ - "cmp w2, #" #R "\n" \ - "beq 5f\n" \ - "ld1 { v0.s }[" #R "], [%[src_ptr0]], #4\n" \ - "ld1 { v1.s }[" #R "], [%[src_ptr1]], #4\n" \ - "ld1 { v2.s }[" #R "], [%[src_ptr2]], #4\n" \ - "ld1 { v3.s }[" #R "], [%[src_ptr3]], #4\n" - - RUY_LOAD_ONE_ROW(0) - RUY_LOAD_ONE_ROW(1) - RUY_LOAD_ONE_ROW(2) - RUY_LOAD_ONE_ROW(3) -#undef RUY_LOAD_ONE_ROW - "5:\n" - - "trn1 v16.4s, v0.4s, v1.4s\n" - "trn2 v17.4s, v0.4s, v1.4s\n" - "trn1 v18.4s, v2.4s, v3.4s\n" - "trn2 v19.4s, v2.4s, v3.4s\n" - - "trn1 v20.2d, v16.2d, v18.2d\n" - "trn2 v22.2d, v16.2d, v18.2d\n" - "trn1 v21.2d, v17.2d, v19.2d\n" - "trn2 v23.2d, v17.2d, v19.2d\n" - - "mov x1, #32\n" - -#define RUY_STORE_ONE_ROW(ROW, REGISTER) \ - "cmp w2, #" #ROW "\n" \ - "beq 4f\n" \ - "st1 {" #REGISTER ".4s}, [%[packed_ptr]], x1\n" - - RUY_STORE_ONE_ROW(0, v20) - RUY_STORE_ONE_ROW(1, v21) - RUY_STORE_ONE_ROW(2, v22) - RUY_STORE_ONE_ROW(3, v23) - -#undef RUY_STORE_ONE_ROW - - "4:\n" - - // clang-format on - - : [ src_ptr0 ] "+r"(src_ptr0), [src_ptr1] "+r"(src_ptr1), [src_ptr2] "+r"(src_ptr2), - [src_ptr3] "+r"(src_ptr3), [packed_ptr] "+r"(packed_ptr) - : [ src_inc0 ] "r"(static_cast(src_inc0)), [src_inc1] "r"(static_cast(src_inc1)), [src_inc2] "r"(static_cast(src_inc2)), - [src_inc3] "r"(static_cast(src_inc3)), [rows] "r"(src_rows) - : "cc", "memory", "x1", "x2", "x10", "x11", "x12", "x13", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", - "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", - "v22", "v23", "v24", "v25", "v26", "v27", 
"v28", "v29", "v30", "v31"); -} -#endif // RUY_PLATFORM(NEON_64) && RUY_OPT_ENABLED(RUY_OPT_ASM) - -} // namespace ruy diff --git a/tensorflow/lite/experimental/ruy/ruy/pack_arm.h b/tensorflow/lite/experimental/ruy/ruy/pack_arm.h deleted file mode 100644 index f4691d66fcb..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/pack_arm.h +++ /dev/null @@ -1,497 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// # What is "packing"? -// -// Before feeding data to the gemm kernels (the parts of Ruy that do lots -// of multiply-add operations), Ruy first performs a data transformation (which -// we call "packing") on the input matrices. This transformation has two main -// goals: -// - rearrange data into blocks that are a convenient size/layout for the gemm -// kernels to consume. This helps make the memory access pattern of the gemm -// kernel simpler and more contiguous, and puts the data in a layout most -// convenient for specific arithmetic instructions in the gemm kernel. -// - compute row/column sums needed for handling quantization with non-symmetric -// zero points. -// -// # Simplified algorithmic analysis of packing -// -// Packing is a relatively simple transformation which does a small constant -// amount of work on each element of an input matrix, and hence for an NxM -// matrix performs O(N*M) work. If N and M are of the same order, then this is -// O(N^2) work. -// -// A NxKxM matrix multiplication requires N*K*M multiply-accumulate operations. -// Note that if N, K, and M are all the same order, then the number of -// multiply-accumulate operations is O(N^3). -// -// Thus, the O(N^2) cost of packing is small compared to the O(N^3) work, in the -// case of all dimensions being roughly the same order. -// -// # Packing cost can be significant -// -// When matrix * matrix multiplications begin to look more like matrix * vector -// multiplications, packing cost can become significant. We sometimes call these -// cases "gemv-like". -// -// Continuing the algorithmic analysis above, if we consider a case where an -// NxKxM matrix multiplication has either N = O(1) or M = O(1), then the -// situation is different. In this case, the multiply-accumulate work is only -// quadratic, so the quadratic cost of packing can be come significant. -// -// Another way to say this is that the cost of packing an input matrix (either -// the LHS or RHS) is amortized across the non-depth dimension of the opposite -// input matrix. Thus, when the LHS has very few rows or the RHS has very few -// columns, the cost of packing the opposite input matrix can become -// significant. -// -// As a rough rule of thumb, the cost of packing starts to become significant -// when either N or M is below 32 (and other dimensions are hundreds), with very -// significant packing costs at 8 or below. This varies by data type, Path, and -// tuning, so these numbers are only rough guides. 
-// -// One practical use case that is affected by this is inference of -// fully connected neural network layers with a low batch size. The weight -// matrix (which is a constant for inference) is the one affected by significant -// packing cost. -// -// Ruy provides an API in ruy_advanced.h for advanced users to pre-pack -// input matrices that are affected by significant packing costs. -// -// # Implementation notes -// -// Ruy's packing routines always operate on a range of columns and can be -// applied to either the LHS or RHS. This is possible because Ruy internally -// implements a TrMul, so the accumulation along depth is done along columns of -// both the LHS and RHS (whereas for a normal Mul the accumulation along depth -// for the LHS is along rows). As another example, we are always computing -// column sums for quantization (and never row sums, since the LHS is -// transposed). - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PACK_ARM_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PACK_ARM_H_ - -#include -#include - -#include "tensorflow/lite/experimental/ruy/ruy/check_macros.h" -#include "tensorflow/lite/experimental/ruy/ruy/common.h" -#include "tensorflow/lite/experimental/ruy/ruy/internal_matrix.h" -#include "tensorflow/lite/experimental/ruy/ruy/matrix.h" -#include "tensorflow/lite/experimental/ruy/ruy/opt_set.h" -#include "tensorflow/lite/experimental/ruy/ruy/pack_common.h" -#include "tensorflow/lite/experimental/ruy/ruy/path.h" -#include "tensorflow/lite/experimental/ruy/ruy/platform.h" -#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" -#include "tensorflow/lite/experimental/ruy/ruy/tune.h" - -namespace ruy { - -#if RUY_PLATFORM(NEON_64) && RUY_OPT_ENABLED(RUY_OPT_ASM) -void Pack8bitNeonOutOfOrder(const void* src_ptr0, const void* src_ptr1, - const void* src_ptr2, const void* src_ptr3, - int src_inc0, int src_inc1, int src_inc2, - int src_inc3, int src_rows, int src_zero_point, - std::int8_t* packed_ptr, int start_col, int end_col, - std::int32_t* sums_ptr, int input_xor); -void Pack8bitNeonInOrder(const void* src_ptr0, const void* src_ptr1, - const void* src_ptr2, const void* src_ptr3, - int src_inc0, int src_inc1, int src_inc2, int src_inc3, - int src_rows, int src_zero_point, - std::int8_t* packed_ptr, int start_col, int end_col, - std::int32_t* sums_ptr, int input_xor); -void Pack8bitNeonDotprodOutOfOrder(const void* src_ptr0, const void* src_ptr1, - const void* src_ptr2, const void* src_ptr3, - int src_inc0, int src_inc1, int src_inc2, - int src_inc3, int src_rows, - int src_zero_point, std::int8_t* packed_ptr, - int start_col, int end_col, - std::int32_t* sums_ptr, int input_xor); -void Pack8bitNeonDotprodInOrder(const void* src_ptr0, const void* src_ptr1, - const void* src_ptr2, const void* src_ptr3, - int src_inc0, int src_inc1, int src_inc2, - int src_inc3, int src_rows, int src_zero_point, - std::int8_t* packed_ptr, int start_col, - int end_col, std::int32_t* sums_ptr, - int input_xor); - -#elif RUY_PLATFORM(NEON_32) && RUY_OPT_ENABLED(RUY_OPT_ASM) -void Pack8bitNeonOutOfOrder4Cols(const PackParams8bit& params); -void Pack8bitNeonOutOfOrder2Cols(const PackParams8bit& params); -#endif // (RUY_PLATFORM(NEON_64)&& RUY_OPT_ENABLED(RUY_OPT_ASM) - -#if (RUY_PLATFORM(NEON_32) || RUY_PLATFORM(NEON_64)) && \ - RUY_OPT_ENABLED(RUY_OPT_ASM) - -template -struct PackImpl, Scalar, - std::int8_t, std::int32_t> { - static_assert(std::is_same::value || - std::is_same::value, - ""); - static constexpr int kInputXor = - std::is_same::value ? 
0 : 0x80; - - static void Run(Tuning tuning, const Matrix& src_matrix, - PackedMatrix* packed_matrix, int start_col, - int end_col) { - RUY_DCHECK(IsColMajor(src_matrix.layout)); - RUY_DCHECK(IsColMajor(packed_matrix->layout)); - RUY_DCHECK_EQ(start_col % 4, 0); - std::int32_t* sums = packed_matrix->sums; - Scalar zerobuf[16]; - memset(zerobuf, src_matrix.zero_point, sizeof(zerobuf)); - for (int block_col = start_col; block_col < end_col; block_col += 4) { - int src_stride = src_matrix.layout.stride; - const Scalar* src_ptr0 = src_matrix.data.get() + src_stride * block_col; - const Scalar* src_ptr1 = src_ptr0 + src_stride; - const Scalar* src_ptr2 = src_ptr1 + src_stride; - const Scalar* src_ptr3 = src_ptr2 + src_stride; - int src_inc0 = 16; - int src_inc1 = 16; - int src_inc2 = 16; - int src_inc3 = 16; - if (block_col >= src_matrix.layout.cols - 3) { - if (block_col >= src_matrix.layout.cols - 0) { - src_ptr0 = zerobuf; - src_inc0 = 0; - } - if (block_col >= src_matrix.layout.cols - 1) { - src_ptr1 = zerobuf; - src_inc1 = 0; - } - if (block_col >= src_matrix.layout.cols - 2) { - src_ptr2 = zerobuf; - src_inc2 = 0; - } - if (block_col >= src_matrix.layout.cols - 3) { - src_ptr3 = zerobuf; - src_inc3 = 0; - } - } - std::int8_t* packed_ptr = - packed_matrix->data + packed_matrix->layout.stride * block_col; - std::int32_t* sums_ptr = sums ? sums + block_col : nullptr; -#if RUY_PLATFORM(NEON_64) - if (__builtin_expect(tuning == Tuning::kInOrder, true)) { - Pack8bitNeonInOrder( - src_ptr0, src_ptr1, src_ptr2, src_ptr3, src_inc0, src_inc1, - src_inc2, src_inc3, src_matrix.layout.rows, src_matrix.zero_point, - packed_ptr, start_col, end_col, sums_ptr, kInputXor); - } else { - Pack8bitNeonOutOfOrder( - src_ptr0, src_ptr1, src_ptr2, src_ptr3, src_inc0, src_inc1, - src_inc2, src_inc3, src_matrix.layout.rows, src_matrix.zero_point, - packed_ptr, start_col, end_col, sums_ptr, kInputXor); - } -#else - // We have a more limited set of general purpose registers in ARMv7, so - // we use the "params" struct technique from the kernel code to save - // registers. - PackParams8bit params; - MakePackParams8bit(src_ptr0, src_ptr1, src_ptr2, src_ptr3, sums_ptr, - packed_ptr, src_inc0, src_inc1, src_inc2, src_inc3, - src_matrix.layout.rows, src_matrix.zero_point, - kInputXor, ¶ms); - Pack8bitNeonOutOfOrder4Cols(params); -#endif // RUY_PLATFORM(NEON_64) - } - } -}; - -#endif // (RUY_PLATFORM(NEON_32) || RUY_PLATFORM(NEON_64)) && - // RUY_OPT_ENABLED(RUY_OPT_ASM) - -#if RUY_PLATFORM(NEON_32) && RUY_OPT_ENABLED(RUY_OPT_ASM) -// The 32-bit float kernel is 4 rows X 2 columns, so we need an additional -// partial specialization for the RHS, which has a FixedKernelLayout with 2 -// columns. -template -struct PackImpl, Scalar, - std::int8_t, std::int32_t> { - static_assert(std::is_same::value || - std::is_same::value, - ""); - static constexpr int kInputXor = - std::is_same::value ? 
0 : 0x80; - static void Run(Tuning tuning, const Matrix& src_matrix, - PackedMatrix* packed_matrix, int start_col, - int end_col) { - RUY_DCHECK(IsColMajor(src_matrix.layout)); - RUY_DCHECK(IsColMajor(packed_matrix->layout)); - RUY_DCHECK_EQ(start_col % 2, 0); - std::int32_t* sums = packed_matrix->sums; - Scalar zerobuf[16]; - memset(zerobuf, src_matrix.zero_point, sizeof(zerobuf)); - for (int block_col = start_col; block_col < end_col; block_col += 2) { - int src_stride = src_matrix.layout.stride; - const Scalar* src_ptr0 = src_matrix.data.get() + src_stride * block_col; - const Scalar* src_ptr1 = src_ptr0 + src_stride; - int src_inc0 = 16; - int src_inc1 = 16; - if (block_col >= src_matrix.layout.cols - 2) { - if (block_col >= src_matrix.layout.cols - 0) { - src_ptr0 = zerobuf; - src_inc0 = 0; - } - if (block_col >= src_matrix.layout.cols - 1) { - src_ptr1 = zerobuf; - src_inc1 = 0; - } - } - std::int8_t* packed_ptr = - packed_matrix->data + packed_matrix->layout.stride * block_col; - std::int32_t* sums_ptr = sums ? sums + block_col : nullptr; - PackParams8bit params; - MakePackParams8bit(src_ptr0, src_ptr1, nullptr, nullptr, sums_ptr, - packed_ptr, src_inc0, src_inc1, -1, -1, - src_matrix.layout.rows, src_matrix.zero_point, - kInputXor, ¶ms); - Pack8bitNeonOutOfOrder2Cols(params); - } - } -}; -#endif // (RUY_PLATFORM(NEON_32)) && RUY_OPT_ENABLED(RUY_OPT_ASM) - -#if RUY_PLATFORM(NEON_64) && RUY_OPT_ENABLED(RUY_OPT_ASM) -template -struct PackImpl, - Scalar, std::int8_t, std::int32_t> { - static_assert(std::is_same::value || - std::is_same::value, - ""); - static constexpr int kInputXor = - std::is_same::value ? 0 : 0x80; - - static void Run(Tuning tuning, const Matrix& src_matrix, - PackedMatrix* packed_matrix, int start_col, - int end_col) { - RUY_DCHECK(IsColMajor(src_matrix.layout)); - RUY_DCHECK(IsColMajor(packed_matrix->layout)); - RUY_DCHECK_EQ(start_col % 8, 0); - std::int32_t* sums = packed_matrix->sums; - Scalar zerobuf[16]; - memset(zerobuf, src_matrix.zero_point, sizeof(zerobuf)); - for (int block_col = start_col; block_col < end_col; block_col += 4) { - int src_stride = src_matrix.layout.stride; - const Scalar* src_ptr0 = src_matrix.data.get() + src_stride * block_col; - const Scalar* src_ptr1 = src_ptr0 + src_stride; - const Scalar* src_ptr2 = src_ptr1 + src_stride; - const Scalar* src_ptr3 = src_ptr2 + src_stride; - std::int64_t src_inc0 = 16; - std::int64_t src_inc1 = 16; - std::int64_t src_inc2 = 16; - std::int64_t src_inc3 = 16; - if (block_col >= src_matrix.layout.cols - 3) { - if (block_col >= src_matrix.layout.cols - 0) { - src_ptr0 = zerobuf; - src_inc0 = 0; - } - if (block_col >= src_matrix.layout.cols - 1) { - src_ptr1 = zerobuf; - src_inc1 = 0; - } - if (block_col >= src_matrix.layout.cols - 2) { - src_ptr2 = zerobuf; - src_inc2 = 0; - } - if (block_col >= src_matrix.layout.cols - 3) { - src_ptr3 = zerobuf; - src_inc3 = 0; - } - } - std::int8_t* packed_ptr = - packed_matrix->data + - packed_matrix->layout.stride * (block_col & ~7) + - ((block_col & 4) * 4); - std::int32_t* sums_ptr = sums ? 
sums + block_col : nullptr; - if (__builtin_expect(tuning == Tuning::kInOrder, true)) { - Pack8bitNeonDotprodInOrder( - src_ptr0, src_ptr1, src_ptr2, src_ptr3, src_inc0, src_inc1, - src_inc2, src_inc3, src_matrix.layout.rows, src_matrix.zero_point, - packed_ptr, start_col, end_col, sums_ptr, kInputXor); - } else { - Pack8bitNeonDotprodOutOfOrder( - src_ptr0, src_ptr1, src_ptr2, src_ptr3, src_inc0, src_inc1, - src_inc2, src_inc3, src_matrix.layout.rows, src_matrix.zero_point, - packed_ptr, start_col, end_col, sums_ptr, kInputXor); - } - } - } -}; -#endif // (RUY_PLATFORM(NEON_64)&& RUY_OPT_ENABLED(RUY_OPT_ASM) - -#if RUY_PLATFORM(NEON_64) && RUY_OPT_ENABLED(RUY_OPT_ASM) -void PackFloatNeonOutOfOrder(const float* src_ptr0, const float* src_ptr1, - const float* src_ptr2, const float* src_ptr3, - int src_inc0, int src_inc1, int src_inc2, - int src_inc3, int src_rows, int src_zero_point, - float* packed_ptr, int start_col, int end_col); -void PackFloatNeonInOrder(const float* src_ptr0, const float* src_ptr1, - const float* src_ptr2, const float* src_ptr3, - int src_inc0, int src_inc1, int src_inc2, - int src_inc3, int src_rows, int src_zero_point, - float* packed_ptr, int start_col, int end_col); - -#elif RUY_PLATFORM(NEON_32) && RUY_OPT_ENABLED(RUY_OPT_ASM) -void PackFloatNeonOutOfOrder(const float* src_ptr0, const float* src_ptr1, - const float* src_ptr2, const float* src_ptr3, - int src_inc, int src_rows, int src_zero_point, - float* packed_ptr, int start_col, int end_col, - int stride); -#endif // (RUY_PLATFORM(NEON_64)&& RUY_OPT_ENABLED(RUY_OPT_ASM) - -#if (RUY_PLATFORM(NEON_32) || RUY_PLATFORM(NEON_64)) && \ - RUY_OPT_ENABLED(RUY_OPT_ASM) - -template <> -struct PackImpl, float, - float, float> { - static void Run(Tuning tuning, const Matrix& src_matrix, - PackedMatrix* packed_matrix, int start_col, - int end_col) { - RUY_DCHECK(IsColMajor(src_matrix.layout)); - RUY_DCHECK(IsColMajor(packed_matrix->layout)); - RUY_DCHECK_EQ(start_col % 8, 0); - const float zerobuf[4] = {0}; - for (int block_col = start_col; block_col < end_col; block_col += 4) { - int src_stride = src_matrix.layout.stride; - const float* src_ptr0 = src_matrix.data.get() + src_stride * block_col; - const float* src_ptr1 = src_ptr0 + src_stride; - const float* src_ptr2 = src_ptr1 + src_stride; - const float* src_ptr3 = src_ptr2 + src_stride; - std::int64_t src_inc0 = 16; - std::int64_t src_inc1 = 16; - std::int64_t src_inc2 = 16; - std::int64_t src_inc3 = 16; - if (block_col >= src_matrix.layout.cols - 3) { - if (block_col >= src_matrix.layout.cols - 0) { - src_ptr0 = zerobuf; - src_inc0 = 0; - } - if (block_col >= src_matrix.layout.cols - 1) { - src_ptr1 = zerobuf; - src_inc1 = 0; - } - if (block_col >= src_matrix.layout.cols - 2) { - src_ptr2 = zerobuf; - src_inc2 = 0; - } - if (block_col >= src_matrix.layout.cols - 3) { - src_ptr3 = zerobuf; - src_inc3 = 0; - } - } - float* packed_ptr = packed_matrix->data + - packed_matrix->layout.stride * (block_col & ~7) + - ((block_col & 4)); -#if RUY_PLATFORM(NEON_64) - if (__builtin_expect(tuning == Tuning::kInOrder, true)) { - PackFloatNeonInOrder(src_ptr0, src_ptr1, src_ptr2, src_ptr3, src_inc0, - src_inc1, src_inc2, src_inc3, - src_matrix.layout.rows, src_matrix.zero_point, - packed_ptr, start_col, end_col); - } else { - PackFloatNeonOutOfOrder(src_ptr0, src_ptr1, src_ptr2, src_ptr3, - src_inc0, src_inc1, src_inc2, src_inc3, - src_matrix.layout.rows, src_matrix.zero_point, - packed_ptr, start_col, end_col); - } -#else - // Encode each of src_inc0, ..., src_inc3 in lowest 4 bits 
of src_inc - // to save on registers (we have fewer general purpose registers in - // 32-bit ARM than in 64-bit ARM). For the 64-bit case, we pass four - // values that are each either 16 or 0 and use them directly. For the - // 32-bit case, bits 0, 1, 2, and 3 are used to determine if we should - // use the value 16 (bit is set) or 0 (bit is not set) for the - // respective increment value. - std::int64_t src_inc = 0; - src_inc += src_inc0 == 16 ? 1 : 0; - src_inc += src_inc1 == 16 ? 2 : 0; - src_inc += src_inc2 == 16 ? 4 : 0; - src_inc += src_inc3 == 16 ? 8 : 0; - const int kOutputStride = 32; - PackFloatNeonOutOfOrder(src_ptr0, src_ptr1, src_ptr2, src_ptr3, src_inc, - src_matrix.layout.rows, src_matrix.zero_point, - packed_ptr, start_col, end_col, kOutputStride); -#endif // RUY_PLATFORM(NEON_64) - } - } -}; - -#if RUY_PLATFORM(NEON_32) -// The 32-bit float kernel is 8 rows X 4 columns, so we need an additional -// specialization for a FixedKernelLayout with 4 columns. -template <> -struct PackImpl, float, - float, float> { - static void Run(Tuning tuning, const Matrix& src_matrix, - PackedMatrix* packed_matrix, int start_col, - int end_col) { - RUY_DCHECK(IsColMajor(src_matrix.layout)); - RUY_DCHECK(IsColMajor(packed_matrix->layout)); - RUY_DCHECK_EQ(start_col % 4, 0); - const float zerobuf[4] = {0}; - for (int block_col = start_col; block_col < end_col; block_col += 4) { - int src_stride = src_matrix.layout.stride; - const float* src_ptr0 = src_matrix.data.get() + src_stride * block_col; - const float* src_ptr1 = src_ptr0 + src_stride; - const float* src_ptr2 = src_ptr1 + src_stride; - const float* src_ptr3 = src_ptr2 + src_stride; - std::int64_t src_inc0 = 16; - std::int64_t src_inc1 = 16; - std::int64_t src_inc2 = 16; - std::int64_t src_inc3 = 16; - if (block_col >= src_matrix.layout.cols - 3) { - if (block_col >= src_matrix.layout.cols - 0) { - src_ptr0 = zerobuf; - src_inc0 = 0; - } - if (block_col >= src_matrix.layout.cols - 1) { - src_ptr1 = zerobuf; - src_inc1 = 0; - } - if (block_col >= src_matrix.layout.cols - 2) { - src_ptr2 = zerobuf; - src_inc2 = 0; - } - if (block_col >= src_matrix.layout.cols - 3) { - src_ptr3 = zerobuf; - src_inc3 = 0; - } - } - float* packed_ptr = - packed_matrix->data + packed_matrix->layout.stride * (block_col); - // Encode each of src_inc0, ..., src_inc1 in lowest 4 bits of scrc_inc - // to save registers. - std::int64_t src_inc = 0; - src_inc += src_inc0 == 16 ? 1 : 0; - src_inc += src_inc1 == 16 ? 2 : 0; - src_inc += src_inc2 == 16 ? 4 : 0; - src_inc += src_inc3 == 16 ? 8 : 0; - const int kOutputStride = 16; - PackFloatNeonOutOfOrder(src_ptr0, src_ptr1, src_ptr2, src_ptr3, src_inc, - src_matrix.layout.rows, src_matrix.zero_point, - packed_ptr, start_col, end_col, kOutputStride); - } - } -}; -#endif // (RUY_PLATFORM(NEON_32)) -#endif // (RUY_PLATFORM(NEON_64) || RUY_PLATFORM(NEON_32)) && \ - // RUY_OPT_ENABLED(RUY_OPT_ASM) - -} // namespace ruy - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PACK_ARM_H_ diff --git a/tensorflow/lite/experimental/ruy/ruy/pack_avx2.cc b/tensorflow/lite/experimental/ruy/ruy/pack_avx2.cc deleted file mode 100644 index 3575943e50e..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/pack_avx2.cc +++ /dev/null @@ -1,816 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include - -#include "tensorflow/lite/experimental/ruy/ruy/check_macros.h" -#include "tensorflow/lite/experimental/ruy/ruy/matrix.h" -#include "tensorflow/lite/experimental/ruy/ruy/opt_set.h" -#include "tensorflow/lite/experimental/ruy/ruy/pack.h" -#include "tensorflow/lite/experimental/ruy/ruy/path.h" -#include "tensorflow/lite/experimental/ruy/ruy/platform.h" -#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" - -#if RUY_PLATFORM(AVX2) && RUY_OPT_ENABLED(RUY_OPT_INTRINSICS) -#include // IWYU pragma: keep -#endif - -namespace ruy { - -#if !(RUY_PLATFORM(AVX2) && RUY_OPT_ENABLED(RUY_OPT_ASM)) - -void Pack8bitAvx2(const std::int8_t* src_ptr, std::int8_t input_xor, - const std::int8_t* zerobuf, int src_stride, - int remaining_src_cols, int src_rows, std::int8_t* packed_ptr, - std::int32_t* sums_ptr) { - // CPU-ID-based checks should disable the path that would reach this point. - RUY_DCHECK(false); -} - -void PackFloatAvx2(const float* src_ptr, const float* zerobuf, int src_stride, - int remaining_src_cols, int src_rows, float* packed_ptr) { - // CPU-ID-based checks should disable the path that would reach this point. - RUY_DCHECK(false); -} - -#else // RUY_PLATFORM(AVX2) && RUY_OPT_ENABLED(RUY_OPT_ASM) - -// The first int8_t template parameter is arbitrary: this routine is common to -// all 8-bit source matrix types. -using PackImpl8bitAvx2 = - PackImpl, - std::int8_t, std::int8_t, std::int32_t>; - -using PackImplFloatAvx2 = - PackImpl, float, - float, float>; - -namespace { - -inline __m256i MaskLoadu(int available_src_rows, std::int8_t zero_point, - const std::int8_t* addr) { - RUY_DCHECK_LT(available_src_rows, 32); - __m256i padded_data; - - if (available_src_rows >= 16) { - __m128i load_hi = _mm_set1_epi8(zero_point); - __m128i load_lo = _mm_loadu_si128(reinterpret_cast(addr)); - memcpy(&load_hi, addr + 16, available_src_rows - 16); - padded_data = _mm256_set_m128i(load_hi, load_lo); - } else { - __m128i load_hi = _mm_set1_epi8(zero_point); - __m128i load_lo = load_hi; - memcpy(&load_lo, addr, available_src_rows); - padded_data = _mm256_set_m128i(load_hi, load_lo); - } - return padded_data; -} - -inline void Pack8bitAvx2Packer(const std::int8_t* src_ptr, - std::int8_t input_xor, - const std::int8_t* zerobuf, int src_stride, - int remaining_src_cols, int src_rows, - std::int8_t* packed_ptr, std::int32_t* sums_ptr, - std::int8_t* trailing_buf) { - using Layout = PackImpl8bitAvx2::Layout; - RUY_DCHECK_EQ(Layout::kCols, 8); - RUY_DCHECK_EQ(Layout::kRows, 4); - // Each Layout::Rows is 4 contiguous input, contiguous packed elements. - // We process 8 of these chunks at a time, padding short input chunks. 
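Editorial aside on the layout described just above (an illustrative sketch; the helper name is hypothetical and not part of ruy): with Layout::kRows = 4 and Layout::kCols = 8, each 4-row chunk of a source column is stored contiguously and chunks are interleaved across the 8 columns, so a source element maps into the packed block roughly as follows.

// Hypothetical scalar model of the AVX2 8-bit packed offset for one 8-column
// block: 4-row chunks are contiguous and interleaved across 8 columns, so
// each chunk of 4 rows advances the packed offset by 8 * 4 bytes.
inline int PackedOffsetAvx2_8bit(int row, int col) {
  const int chunk = row / 4;   // which 4-row chunk of the source column
  const int within = row % 4;  // position inside that chunk
  return (chunk * 8 + col) * 4 + within;
}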
- constexpr int kNumRowChunks = 8; - constexpr int kNumChunkedSrcRows = kNumRowChunks * Layout::kRows; - - const std::int8_t* src_ptr0 = src_ptr; - const std::int8_t* src_ptr1 = src_ptr0 + src_stride; - const std::int8_t* src_ptr2 = src_ptr1 + src_stride; - const std::int8_t* src_ptr3 = src_ptr2 + src_stride; - const std::int8_t* src_ptr4 = src_ptr3 + src_stride; - const std::int8_t* src_ptr5 = src_ptr4 + src_stride; - const std::int8_t* src_ptr6 = src_ptr5 + src_stride; - const std::int8_t* src_ptr7 = src_ptr6 + src_stride; - std::int64_t src_inc0 = kNumChunkedSrcRows; - std::int64_t src_inc1 = kNumChunkedSrcRows; - std::int64_t src_inc2 = kNumChunkedSrcRows; - std::int64_t src_inc3 = kNumChunkedSrcRows; - std::int64_t src_inc4 = kNumChunkedSrcRows; - std::int64_t src_inc5 = kNumChunkedSrcRows; - std::int64_t src_inc6 = kNumChunkedSrcRows; - std::int64_t src_inc7 = kNumChunkedSrcRows; - // Handle cases where source does not have Layout::kCols (8) columns. - if (remaining_src_cols < 8) { - if (remaining_src_cols <= 0) { - src_ptr0 = zerobuf; - src_inc0 = 0; - } - if (remaining_src_cols <= 1) { - src_ptr1 = zerobuf; - src_inc1 = 0; - } - if (remaining_src_cols <= 2) { - src_ptr2 = zerobuf; - src_inc2 = 0; - } - if (remaining_src_cols <= 3) { - src_ptr3 = zerobuf; - src_inc3 = 0; - } - if (remaining_src_cols <= 4) { - src_ptr4 = zerobuf; - src_inc4 = 0; - } - if (remaining_src_cols <= 5) { - src_ptr5 = zerobuf; - src_inc5 = 0; - } - if (remaining_src_cols <= 6) { - src_ptr6 = zerobuf; - src_inc6 = 0; - } - src_ptr7 = zerobuf; - src_inc7 = 0; - } - - const std::int8_t zero_point = zerobuf[0]; - - if (sums_ptr) { - // i: Layout::kCols. - for (int i = 0; i < 8; ++i) { - sums_ptr[i] = 0; - } - } - std::int32_t sums_adjustment = 0; - const __m256i ones_16bit = _mm256_set1_epi16(1); - __m256i sums_4x2_32bit_lo = _mm256_set1_epi32(0); - __m256i sums_4x2_32bit_hi = _mm256_set1_epi32(0); - - // The overall packing effectively pads the source rows to - // (src_rows + 63) & ~63. The iteration over k may skip when m=1, and then we - // only pack for (src_rows + 31) & ~31. When there is an incomplete - // destination block, this is stored into trailing_buf instead of packed_ptr. - for (int k = 0; k < src_rows; k += kNumChunkedSrcRows) { - // Available source rows. - // If this is less than 0 (for m=1), we skip, having filled trailing - // buffer for m=0. Also, if source rows is zero on m=1, then we filled - // exactly to the end of the column in the packed buffer. - const int available_src_rows = src_rows - k; - // Effectively, - // available rows = std::max(0, std::min(8, src_rows - k)); - // treat each case separately. 
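The "effectively" comment above can be spelled out as a clamp; a minimal sketch (hypothetical helper, assuming the same 8-chunks-of-4 geometry, i.e. 32 rows per iteration even though the comment abbreviates it to 8 chunks):

#include <algorithm>

// What the branch ladder below computes implicitly: how many source rows this
// iteration can still consume, clamped to one full 32-row chunk group.
inline int AvailableRowsThisIteration(int src_rows, int k,
                                      int chunked_src_rows = 32) {
  return std::max(0, std::min(chunked_src_rows, src_rows - k));
}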
- if (available_src_rows >= kNumChunkedSrcRows) { - if (sums_ptr) { - __m256i t0, t1, t2, t3, t4, t5, t6, t7; - __m256i r0, r1, r2, r3, r4, r5, r6, r7; - const __m256i input_xor_v = _mm256_set1_epi8(input_xor); - - t0 = _mm256_loadu_si256(reinterpret_cast(src_ptr0)); - t4 = _mm256_loadu_si256(reinterpret_cast(src_ptr4)); - t1 = _mm256_loadu_si256(reinterpret_cast(src_ptr1)); - t5 = _mm256_loadu_si256(reinterpret_cast(src_ptr5)); - t2 = _mm256_loadu_si256(reinterpret_cast(src_ptr2)); - t6 = _mm256_loadu_si256(reinterpret_cast(src_ptr6)); - t3 = _mm256_loadu_si256(reinterpret_cast(src_ptr3)); - t7 = _mm256_loadu_si256(reinterpret_cast(src_ptr7)); - - r0 = _mm256_unpacklo_epi32(t0, t1); - r4 = _mm256_unpacklo_epi32(t4, t5); - r2 = _mm256_unpackhi_epi32(t0, t1); - r6 = _mm256_unpackhi_epi32(t4, t5); - r1 = _mm256_unpacklo_epi32(t2, t3); - r5 = _mm256_unpacklo_epi32(t6, t7); - r3 = _mm256_unpackhi_epi32(t2, t3); - r7 = _mm256_unpackhi_epi32(t6, t7); - - t0 = _mm256_unpacklo_epi64(r0, r1); - t4 = _mm256_unpacklo_epi64(r4, r5); - t2 = _mm256_unpackhi_epi64(r0, r1); - t6 = _mm256_unpackhi_epi64(r4, r5); - t1 = _mm256_unpacklo_epi64(r2, r3); - t5 = _mm256_unpacklo_epi64(r6, r7); - t3 = _mm256_unpackhi_epi64(r2, r3); - t7 = _mm256_unpackhi_epi64(r6, r7); - - // The preceding sets of rearrangement operations interleaved by 4 bytes - // and then by 8 bytes *within* lanes. The following set interleave by - // 16 bytes (128-bit), operating *between* AVX lanes. For instance (t0, - // t4) are interleaved to create (r0, r1). This complexity follows from - // the way that AVX is centered around MM 128-bit lanes. - r0 = _mm256_permute2x128_si256(t0, t4, 0x20); - r4 = _mm256_permute2x128_si256(t1, t5, 0x20); - r1 = _mm256_permute2x128_si256(t0, t4, 0x31); - r5 = _mm256_permute2x128_si256(t1, t5, 0x31); - r2 = _mm256_permute2x128_si256(t2, t6, 0x20); - r6 = _mm256_permute2x128_si256(t3, t7, 0x20); - r3 = _mm256_permute2x128_si256(t2, t6, 0x31); - r7 = _mm256_permute2x128_si256(t3, t7, 0x31); - - r0 = _mm256_xor_si256(r0, input_xor_v); - r1 = _mm256_xor_si256(r1, input_xor_v); - r2 = _mm256_xor_si256(r2, input_xor_v); - r3 = _mm256_xor_si256(r3, input_xor_v); - r4 = _mm256_xor_si256(r4, input_xor_v); - r5 = _mm256_xor_si256(r5, input_xor_v); - r6 = _mm256_xor_si256(r6, input_xor_v); - r7 = _mm256_xor_si256(r7, input_xor_v); - - __m256i sums_4x4_16bit_lo; - sums_4x4_16bit_lo = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r0)); - sums_4x4_16bit_lo = - _mm256_add_epi16(sums_4x4_16bit_lo, - _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r1))); - sums_4x4_16bit_lo = - _mm256_add_epi16(sums_4x4_16bit_lo, - _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r2))); - sums_4x4_16bit_lo = - _mm256_add_epi16(sums_4x4_16bit_lo, - _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r3))); - sums_4x4_16bit_lo = - _mm256_add_epi16(sums_4x4_16bit_lo, - _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r4))); - sums_4x4_16bit_lo = - _mm256_add_epi16(sums_4x4_16bit_lo, - _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r5))); - sums_4x4_16bit_lo = - _mm256_add_epi16(sums_4x4_16bit_lo, - _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r6))); - sums_4x4_16bit_lo = - _mm256_add_epi16(sums_4x4_16bit_lo, - _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r7))); - - // The sums have been performed across columns, and now we have 4x16-bit - // sums packed together. We use madd for pairwise 32-bit sums. 
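A scalar model of the madd-with-ones step mentioned above may help (illustrative only, not ruy code): multiplying by a vector of 16-bit ones turns _mm256_madd_epi16 into a pairwise add of adjacent 16-bit column sums into 32-bit lanes.

#include <cstdint>

// Pairwise-add model of _mm256_madd_epi16(x, _mm256_set1_epi16(1)):
// each 32-bit output lane is the sum of two adjacent signed 16-bit inputs.
void MaddWithOnesModel(const std::int16_t in[16], std::int32_t out[8]) {
  for (int i = 0; i < 8; ++i) {
    out[i] = static_cast<std::int32_t>(in[2 * i]) +
             static_cast<std::int32_t>(in[2 * i + 1]);
  }
}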
- const __m256i sums_4x2_32bit_lo_new = - _mm256_madd_epi16(sums_4x4_16bit_lo, ones_16bit); - sums_4x2_32bit_lo = - _mm256_add_epi32(sums_4x2_32bit_lo, sums_4x2_32bit_lo_new); - - __m256i sums_4x4_16bit_hi; - sums_4x4_16bit_hi = - _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r0, 1)); - sums_4x4_16bit_hi = _mm256_add_epi16( - sums_4x4_16bit_hi, - _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r1, 1))); - sums_4x4_16bit_hi = _mm256_add_epi16( - sums_4x4_16bit_hi, - _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r2, 1))); - sums_4x4_16bit_hi = _mm256_add_epi16( - sums_4x4_16bit_hi, - _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r3, 1))); - sums_4x4_16bit_hi = _mm256_add_epi16( - sums_4x4_16bit_hi, - _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r4, 1))); - sums_4x4_16bit_hi = _mm256_add_epi16( - sums_4x4_16bit_hi, - _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r5, 1))); - sums_4x4_16bit_hi = _mm256_add_epi16( - sums_4x4_16bit_hi, - _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r6, 1))); - sums_4x4_16bit_hi = _mm256_add_epi16( - sums_4x4_16bit_hi, - _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r7, 1))); - - const __m256i sums_4x2_32bit_hi_new = - _mm256_madd_epi16(sums_4x4_16bit_hi, ones_16bit); - sums_4x2_32bit_hi = - _mm256_add_epi32(sums_4x2_32bit_hi, sums_4x2_32bit_hi_new); - - _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 0 * 8 * 4), - r0); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 2 * 8 * 4), - r4); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 4 * 8 * 4), - r1); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 6 * 8 * 4), - r5); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 1 * 8 * 4), - r2); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 3 * 8 * 4), - r6); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 5 * 8 * 4), - r3); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 7 * 8 * 4), - r7); - } else { - __m256i t0, t1, t2, t3, t4, t5, t6, t7; - __m256i r0, r1, r2, r3, r4, r5, r6, r7; - const __m256i input_xor_v = _mm256_set1_epi8(input_xor); - - t0 = _mm256_loadu_si256(reinterpret_cast(src_ptr0)); - t4 = _mm256_loadu_si256(reinterpret_cast(src_ptr4)); - t1 = _mm256_loadu_si256(reinterpret_cast(src_ptr1)); - t5 = _mm256_loadu_si256(reinterpret_cast(src_ptr5)); - t2 = _mm256_loadu_si256(reinterpret_cast(src_ptr2)); - t6 = _mm256_loadu_si256(reinterpret_cast(src_ptr6)); - t3 = _mm256_loadu_si256(reinterpret_cast(src_ptr3)); - t7 = _mm256_loadu_si256(reinterpret_cast(src_ptr7)); - - r0 = _mm256_unpacklo_epi32(t0, t1); - r4 = _mm256_unpacklo_epi32(t4, t5); - r2 = _mm256_unpackhi_epi32(t0, t1); - r6 = _mm256_unpackhi_epi32(t4, t5); - r1 = _mm256_unpacklo_epi32(t2, t3); - r5 = _mm256_unpacklo_epi32(t6, t7); - r3 = _mm256_unpackhi_epi32(t2, t3); - r7 = _mm256_unpackhi_epi32(t6, t7); - - t0 = _mm256_unpacklo_epi64(r0, r1); - t4 = _mm256_unpacklo_epi64(r4, r5); - t2 = _mm256_unpackhi_epi64(r0, r1); - t6 = _mm256_unpackhi_epi64(r4, r5); - t1 = _mm256_unpacklo_epi64(r2, r3); - t5 = _mm256_unpacklo_epi64(r6, r7); - t3 = _mm256_unpackhi_epi64(r2, r3); - t7 = _mm256_unpackhi_epi64(r6, r7); - - // The preceding sets of rearrangement operations interleaved by 4 bytes - // and then by 8 bytes *within* lanes. The following set interleave by - // 16 bytes (128-bit), operating *between* AVX lanes. For instance (t0, - // t4) are interleaved to create (r0, r1). This complexity follows from - // the way that AVX is centered around MM 128-bit lanes. 
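For the cross-lane step described above, here is a plain model of the two _mm256_permute2x128_si256 immediates in use (the struct and function are made up for illustration; the zeroing bits of the immediate are ignored, since 0x20 and 0x31 do not set them):

#include <cstdint>
#include <cstring>

struct Ymm128Halves {
  std::uint8_t lo[16];  // low 128-bit lane
  std::uint8_t hi[16];  // high 128-bit lane
};

// imm 0x20 yields (a.lo, b.lo); imm 0x31 yields (a.hi, b.hi).
inline Ymm128Halves Permute2x128Model(const Ymm128Halves& a,
                                      const Ymm128Halves& b, int imm) {
  auto pick = [&](int sel) {
    return sel == 0 ? a.lo : sel == 1 ? a.hi : sel == 2 ? b.lo : b.hi;
  };
  Ymm128Halves r;
  std::memcpy(r.lo, pick(imm & 0x3), 16);         // low nibble picks low half
  std::memcpy(r.hi, pick((imm >> 4) & 0x3), 16);  // high nibble picks high half
  return r;
}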
- r0 = _mm256_permute2x128_si256(t0, t4, 0x20); - r4 = _mm256_permute2x128_si256(t1, t5, 0x20); - r1 = _mm256_permute2x128_si256(t0, t4, 0x31); - r5 = _mm256_permute2x128_si256(t1, t5, 0x31); - r2 = _mm256_permute2x128_si256(t2, t6, 0x20); - r6 = _mm256_permute2x128_si256(t3, t7, 0x20); - r3 = _mm256_permute2x128_si256(t2, t6, 0x31); - r7 = _mm256_permute2x128_si256(t3, t7, 0x31); - - r0 = _mm256_xor_si256(r0, input_xor_v); - r1 = _mm256_xor_si256(r1, input_xor_v); - r2 = _mm256_xor_si256(r2, input_xor_v); - r3 = _mm256_xor_si256(r3, input_xor_v); - r4 = _mm256_xor_si256(r4, input_xor_v); - r5 = _mm256_xor_si256(r5, input_xor_v); - r6 = _mm256_xor_si256(r6, input_xor_v); - r7 = _mm256_xor_si256(r7, input_xor_v); - - _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 0 * 8 * 4), - r0); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 2 * 8 * 4), - r4); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 4 * 8 * 4), - r1); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 6 * 8 * 4), - r5); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 1 * 8 * 4), - r2); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 3 * 8 * 4), - r6); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 5 * 8 * 4), - r3); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 7 * 8 * 4), - r7); - } - } else if (available_src_rows > 0) { - RUY_DCHECK_LT(available_src_rows, kNumChunkedSrcRows); - // We do not care what goes into the trailing buffer, but we want - // in_data[...] ^ input_xor == 0 for irrelevant values in the summation. - // - // We compensate for padding-with-zero_point by initializing the - // summations with the compensating offset, effectively - // ((input_xor ^ input_xor) - (zero_point ^ input_xor)) * - // 4 * (8 - ((available_src_rows + 3) >> 2)). - // - // Note that (zero_point ^ input_xor) is performed in 8-bits and then - // cast. - sums_adjustment += - -(zero_point ^ input_xor) * 4 * (8 - ((available_src_rows + 3) >> 2)); - - __m256i t0, t1, t2, t3, t4, t5, t6, t7; - __m256i r0, r1, r2, r3, r4, r5, r6, r7; - const __m256i input_xor_v = _mm256_set1_epi8(input_xor); - - t0 = MaskLoadu(available_src_rows, zero_point, src_ptr0); - t4 = MaskLoadu(available_src_rows, zero_point, src_ptr4); - t1 = MaskLoadu(available_src_rows, zero_point, src_ptr1); - t5 = MaskLoadu(available_src_rows, zero_point, src_ptr5); - t2 = MaskLoadu(available_src_rows, zero_point, src_ptr2); - t6 = MaskLoadu(available_src_rows, zero_point, src_ptr6); - t3 = MaskLoadu(available_src_rows, zero_point, src_ptr3); - t7 = MaskLoadu(available_src_rows, zero_point, src_ptr7); - - r0 = _mm256_unpacklo_epi32(t0, t1); - r4 = _mm256_unpacklo_epi32(t4, t5); - r2 = _mm256_unpackhi_epi32(t0, t1); - r6 = _mm256_unpackhi_epi32(t4, t5); - r1 = _mm256_unpacklo_epi32(t2, t3); - r5 = _mm256_unpacklo_epi32(t6, t7); - r3 = _mm256_unpackhi_epi32(t2, t3); - r7 = _mm256_unpackhi_epi32(t6, t7); - - t0 = _mm256_unpacklo_epi64(r0, r1); - t4 = _mm256_unpacklo_epi64(r4, r5); - t2 = _mm256_unpackhi_epi64(r0, r1); - t6 = _mm256_unpackhi_epi64(r4, r5); - t1 = _mm256_unpacklo_epi64(r2, r3); - t5 = _mm256_unpacklo_epi64(r6, r7); - t3 = _mm256_unpackhi_epi64(r2, r3); - t7 = _mm256_unpackhi_epi64(r6, r7); - - // The preceding sets of rearrangement operations interleaved by 4 bytes - // and then by 8 bytes *within* lanes. The following set interleave by - // 16 bytes (128-bit), operating *between* AVX lanes. For instance (t0, - // t4) are interleaved to create (r0, r1). 
This complexity follows from - // the way that AVX is centered around MM 128-bit lanes. - r0 = _mm256_permute2x128_si256(t0, t4, 0x20); - r4 = _mm256_permute2x128_si256(t1, t5, 0x20); - r1 = _mm256_permute2x128_si256(t0, t4, 0x31); - r5 = _mm256_permute2x128_si256(t1, t5, 0x31); - r2 = _mm256_permute2x128_si256(t2, t6, 0x20); - r6 = _mm256_permute2x128_si256(t3, t7, 0x20); - r3 = _mm256_permute2x128_si256(t2, t6, 0x31); - r7 = _mm256_permute2x128_si256(t3, t7, 0x31); - - r0 = _mm256_xor_si256(r0, input_xor_v); - r1 = _mm256_xor_si256(r1, input_xor_v); - r2 = _mm256_xor_si256(r2, input_xor_v); - r3 = _mm256_xor_si256(r3, input_xor_v); - r4 = _mm256_xor_si256(r4, input_xor_v); - r5 = _mm256_xor_si256(r5, input_xor_v); - r6 = _mm256_xor_si256(r6, input_xor_v); - r7 = _mm256_xor_si256(r7, input_xor_v); - - __m256i sums_4x4_16bit_lo; - sums_4x4_16bit_lo = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r0)); - sums_4x4_16bit_lo = _mm256_add_epi16( - sums_4x4_16bit_lo, _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r1))); - sums_4x4_16bit_lo = _mm256_add_epi16( - sums_4x4_16bit_lo, _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r2))); - sums_4x4_16bit_lo = _mm256_add_epi16( - sums_4x4_16bit_lo, _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r3))); - sums_4x4_16bit_lo = _mm256_add_epi16( - sums_4x4_16bit_lo, _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r4))); - sums_4x4_16bit_lo = _mm256_add_epi16( - sums_4x4_16bit_lo, _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r5))); - sums_4x4_16bit_lo = _mm256_add_epi16( - sums_4x4_16bit_lo, _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r6))); - sums_4x4_16bit_lo = _mm256_add_epi16( - sums_4x4_16bit_lo, _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r7))); - - // The sums have been performed across columns, and now we have 4x16-bit - // sums packed together. We use madd for pairwise 32-bit sums. 
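Stepping back to the start of this partial-block path: it relies on the MaskLoadu helper defined earlier in the file, which pads a short 32-byte column load with the zero point so the xor/sum steps below see well-defined data. A scalar sketch of that behaviour (illustrative only):

#include <cstdint>
#include <cstring>

// Scalar equivalent of the padded partial load: the first available_src_rows
// bytes come from the source column, the rest of the 32-byte chunk is filled
// with the zero point (available_src_rows is known to be < 32 here).
void MaskLoaduModel(int available_src_rows, std::int8_t zero_point,
                    const std::int8_t* addr, std::int8_t out[32]) {
  std::memset(out, zero_point, 32);
  std::memcpy(out, addr, available_src_rows);
}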
- const __m256i sums_4x2_32bit_lo_new = - _mm256_madd_epi16(sums_4x4_16bit_lo, ones_16bit); - sums_4x2_32bit_lo = - _mm256_add_epi32(sums_4x2_32bit_lo, sums_4x2_32bit_lo_new); - - __m256i sums_4x4_16bit_hi; - sums_4x4_16bit_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r0, 1)); - sums_4x4_16bit_hi = _mm256_add_epi16( - sums_4x4_16bit_hi, - _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r1, 1))); - sums_4x4_16bit_hi = _mm256_add_epi16( - sums_4x4_16bit_hi, - _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r2, 1))); - sums_4x4_16bit_hi = _mm256_add_epi16( - sums_4x4_16bit_hi, - _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r3, 1))); - sums_4x4_16bit_hi = _mm256_add_epi16( - sums_4x4_16bit_hi, - _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r4, 1))); - sums_4x4_16bit_hi = _mm256_add_epi16( - sums_4x4_16bit_hi, - _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r5, 1))); - sums_4x4_16bit_hi = _mm256_add_epi16( - sums_4x4_16bit_hi, - _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r6, 1))); - sums_4x4_16bit_hi = _mm256_add_epi16( - sums_4x4_16bit_hi, - _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r7, 1))); - - const __m256i sums_4x2_32bit_hi_new = - _mm256_madd_epi16(sums_4x4_16bit_hi, ones_16bit); - sums_4x2_32bit_hi = - _mm256_add_epi32(sums_4x2_32bit_hi, sums_4x2_32bit_hi_new); - - _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 0 * 8 * 4), - r0); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 2 * 8 * 4), - r4); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 4 * 8 * 4), - r1); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 6 * 8 * 4), - r5); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 1 * 8 * 4), - r2); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 3 * 8 * 4), - r6); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 5 * 8 * 4), - r3); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 7 * 8 * 4), - r7); - } - - packed_ptr += 8 * kNumChunkedSrcRows; - src_ptr0 += src_inc0; - src_ptr1 += src_inc1; - src_ptr2 += src_inc2; - src_ptr3 += src_inc3; - src_ptr4 += src_inc4; - src_ptr5 += src_inc5; - src_ptr6 += src_inc6; - src_ptr7 += src_inc7; - } - - if (sums_ptr) { - const __m256i sums_adjustment_v = _mm256_set1_epi32(sums_adjustment); - - __m256i sums = - _mm256_loadu_si256(reinterpret_cast(sums_ptr)); - const __m256i idx = _mm256_set_epi32(7, 5, 3, 1, 6, 4, 2, 0); - - // We earlier used madd for pairwise 32-bit sums, and now we deinterlace the - // neighbours, finshing up by adding them to the stored accumulated sums. 
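As a gloss on the deinterleaving just described (illustrative scalar equivalent, hypothetical name): the index vector handed to _mm256_permutevar8x32_epi32 routes the even-positioned pair sums into the low half of the register and the odd-positioned ones into the high half.

#include <cstdint>

// out lanes 0..3 receive in lanes {0, 2, 4, 6}; out lanes 4..7 receive
// {1, 3, 5, 7}, matching the permute index used above.
void DeinterleavePairSums(const std::int32_t in[8], std::int32_t out[8]) {
  for (int i = 0; i < 4; ++i) {
    out[i] = in[2 * i];
    out[4 + i] = in[2 * i + 1];
  }
}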
- const __m256i sums_2x4_32bit_lo = - _mm256_permutevar8x32_epi32(sums_4x2_32bit_lo, idx); - const __m256i sums_2x4_32bit_hi = - _mm256_permutevar8x32_epi32(sums_4x2_32bit_hi, idx); - const __m256i sums_2x4_32bit_a = - _mm256_permute2x128_si256(sums_2x4_32bit_lo, sums_2x4_32bit_hi, 0x20); - const __m256i sums_2x4_32bit_b = - _mm256_permute2x128_si256(sums_2x4_32bit_lo, sums_2x4_32bit_hi, 0x31); - sums = _mm256_add_epi32(sums, sums_adjustment_v); - sums = _mm256_add_epi32(sums, sums_2x4_32bit_a); - sums = _mm256_add_epi32(sums, sums_2x4_32bit_b); - - _mm256_storeu_si256(reinterpret_cast<__m256i*>(sums_ptr), sums); - } -} - -inline __m256 Mm256UnpackloPsx2(const __m256 a, const __m256 b) { - return _mm256_castpd_ps( - _mm256_unpacklo_pd(_mm256_castps_pd(a), _mm256_castps_pd(b))); -} - -inline __m256 Mm256UnpackhiPsx2(const __m256 a, const __m256 b) { - return _mm256_castpd_ps( - _mm256_unpackhi_pd(_mm256_castps_pd(a), _mm256_castps_pd(b))); -} - -inline void PackFloatAvx2Packer(const float* src_ptr, const float* zerobuf, - int src_stride, int remaining_src_cols, - int src_rows, float* packed_ptr, - float* trailing_buf) { - RUY_DCHECK_EQ(PackImplFloatAvx2::Layout::kCols, 8); - RUY_DCHECK_EQ(PackImplFloatAvx2::Layout::kRows, 1); - - // This packing amounts to transposition of 8x8 blocks. - static constexpr int kPackCols = 8; // Source cols packed together. - static constexpr int kPackRows = 8; // Short input is padded. - - const float* src_ptr0 = src_ptr; - const float* src_ptr1 = src_ptr0 + src_stride; - const float* src_ptr2 = src_ptr1 + src_stride; - const float* src_ptr3 = src_ptr2 + src_stride; - const float* src_ptr4 = src_ptr3 + src_stride; - const float* src_ptr5 = src_ptr4 + src_stride; - const float* src_ptr6 = src_ptr5 + src_stride; - const float* src_ptr7 = src_ptr6 + src_stride; - std::int64_t src_inc0 = 8; - std::int64_t src_inc1 = 8; - std::int64_t src_inc2 = 8; - std::int64_t src_inc3 = 8; - std::int64_t src_inc4 = 8; - std::int64_t src_inc5 = 8; - std::int64_t src_inc6 = 8; - std::int64_t src_inc7 = 8; - // Handle cases where source does not have kPackDim (8) columns. - if (remaining_src_cols < kPackCols) { - if (remaining_src_cols <= 0) { - src_ptr0 = zerobuf; - src_inc0 = 0; - } - if (remaining_src_cols <= 1) { - src_ptr1 = zerobuf; - src_inc1 = 0; - } - if (remaining_src_cols <= 2) { - src_ptr2 = zerobuf; - src_inc2 = 0; - } - if (remaining_src_cols <= 3) { - src_ptr3 = zerobuf; - src_inc3 = 0; - } - if (remaining_src_cols <= 4) { - src_ptr4 = zerobuf; - src_inc4 = 0; - } - if (remaining_src_cols <= 5) { - src_ptr5 = zerobuf; - src_inc5 = 0; - } - if (remaining_src_cols <= 6) { - src_ptr6 = zerobuf; - src_inc6 = 0; - } - src_ptr7 = zerobuf; - src_inc7 = 0; - } - - for (int k = 0; k < src_rows; k += kPackRows) { - const int available_src_rows = src_rows - k; - // Effectively, - // available_src_rows = std::max(0, std::min(kPackDim, src_rows - k)); - // but treat each case separately. 
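As the comment near the top of this float packer says, the whole routine amounts to transposing 8x8 blocks; stripped of the intrinsics, padding and trailing-buffer handling, the net effect is the following scalar reference (a sketch under the assumption of a full 8x8 block):

// Transpose one full 8x8 block: src_col holds the 8 source column pointers,
// and each group of 8 consecutive packed floats ends up holding one source
// row across those 8 columns.
void PackFloat8x8Reference(const float* const src_col[8], float packed[64]) {
  for (int r = 0; r < 8; ++r) {
    for (int c = 0; c < 8; ++c) {
      packed[r * 8 + c] = src_col[c][r];
    }
  }
}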
- if (available_src_rows >= kPackRows) { - __m256 t0, t1, t2, t3, t4, t5, t6, t7; - __m256 r0, r1, r2, r3, r4, r5, r6, r7; - - t0 = _mm256_loadu_ps(src_ptr0); - t4 = _mm256_loadu_ps(src_ptr4); - t1 = _mm256_loadu_ps(src_ptr1); - t5 = _mm256_loadu_ps(src_ptr5); - t2 = _mm256_loadu_ps(src_ptr2); - t6 = _mm256_loadu_ps(src_ptr6); - t3 = _mm256_loadu_ps(src_ptr3); - t7 = _mm256_loadu_ps(src_ptr7); - - r0 = _mm256_unpacklo_ps(t0, t1); - r4 = _mm256_unpacklo_ps(t4, t5); - r2 = _mm256_unpackhi_ps(t0, t1); - r6 = _mm256_unpackhi_ps(t4, t5); - r1 = _mm256_unpacklo_ps(t2, t3); - r5 = _mm256_unpacklo_ps(t6, t7); - r3 = _mm256_unpackhi_ps(t2, t3); - r7 = _mm256_unpackhi_ps(t6, t7); - - t0 = Mm256UnpackloPsx2(r0, r1); - t4 = Mm256UnpackloPsx2(r4, r5); - t2 = Mm256UnpackhiPsx2(r0, r1); - t6 = Mm256UnpackhiPsx2(r4, r5); - t1 = Mm256UnpackloPsx2(r2, r3); - t5 = Mm256UnpackloPsx2(r6, r7); - t3 = Mm256UnpackhiPsx2(r2, r3); - t7 = Mm256UnpackhiPsx2(r6, r7); - - // The preceding sets of rearrangement operations interleaved by 4 bytes - // and then by 8 bytes *within* lanes. The following set interleave by 16 - // bytes (128-bit), operating *between* AVX lanes. For instance (t0, t4) - // are interleaved to create (r0, r1). This complexity follows from the - // way that AVX is centered around MM 128-bit lanes. - r0 = _mm256_permute2f128_ps(t0, t4, 0x20); - r4 = _mm256_permute2f128_ps(t1, t5, 0x20); - r1 = _mm256_permute2f128_ps(t0, t4, 0x31); - r5 = _mm256_permute2f128_ps(t1, t5, 0x31); - r2 = _mm256_permute2f128_ps(t2, t6, 0x20); - r6 = _mm256_permute2f128_ps(t3, t7, 0x20); - r3 = _mm256_permute2f128_ps(t2, t6, 0x31); - r7 = _mm256_permute2f128_ps(t3, t7, 0x31); - - _mm256_storeu_ps(packed_ptr + 0 * 8, r0); - _mm256_storeu_ps(packed_ptr + 2 * 8, r4); - _mm256_storeu_ps(packed_ptr + 4 * 8, r1); - _mm256_storeu_ps(packed_ptr + 6 * 8, r5); - _mm256_storeu_ps(packed_ptr + 1 * 8, r2); - _mm256_storeu_ps(packed_ptr + 3 * 8, r6); - _mm256_storeu_ps(packed_ptr + 5 * 8, r3); - _mm256_storeu_ps(packed_ptr + 7 * 8, r7); - } else if (available_src_rows > 0) { - const __m256i series = _mm256_set_epi32(7, 6, 5, 4, 3, 2, 1, 0); - const __m256i row_mask_v = - _mm256_cmpgt_epi32(_mm256_set1_epi32(available_src_rows), series); - - __m256 t0, t1, t2, t3, t4, t5, t6, t7; - __m256 r0, r1, r2, r3, r4, r5, r6, r7; - - t0 = _mm256_maskload_ps(src_ptr0, row_mask_v); - t4 = _mm256_maskload_ps(src_ptr4, row_mask_v); - t1 = _mm256_maskload_ps(src_ptr1, row_mask_v); - t5 = _mm256_maskload_ps(src_ptr5, row_mask_v); - t2 = _mm256_maskload_ps(src_ptr2, row_mask_v); - t6 = _mm256_maskload_ps(src_ptr6, row_mask_v); - t3 = _mm256_maskload_ps(src_ptr3, row_mask_v); - t7 = _mm256_maskload_ps(src_ptr7, row_mask_v); - - r0 = _mm256_unpacklo_ps(t0, t1); - r4 = _mm256_unpacklo_ps(t4, t5); - r2 = _mm256_unpackhi_ps(t0, t1); - r6 = _mm256_unpackhi_ps(t4, t5); - r1 = _mm256_unpacklo_ps(t2, t3); - r5 = _mm256_unpacklo_ps(t6, t7); - r3 = _mm256_unpackhi_ps(t2, t3); - r7 = _mm256_unpackhi_ps(t6, t7); - - t0 = Mm256UnpackloPsx2(r0, r1); - t4 = Mm256UnpackloPsx2(r4, r5); - t2 = Mm256UnpackhiPsx2(r0, r1); - t6 = Mm256UnpackhiPsx2(r4, r5); - t1 = Mm256UnpackloPsx2(r2, r3); - t5 = Mm256UnpackloPsx2(r6, r7); - t3 = Mm256UnpackhiPsx2(r2, r3); - t7 = Mm256UnpackhiPsx2(r6, r7); - - // The preceding sets of rearrangement operations interleaved by 4 bytes - // and then by 8 bytes *within* lanes. The following set interleave by 16 - // bytes (128-bit), operating *between* AVX lanes. For instance (t0, t4) - // are interleaved to create (r0, r1). 
This complexity follows from the - // way that AVX is centered around MM 128-bit lanes. - r0 = _mm256_permute2f128_ps(t0, t4, 0x20); - r4 = _mm256_permute2f128_ps(t1, t5, 0x20); - r1 = _mm256_permute2f128_ps(t0, t4, 0x31); - r5 = _mm256_permute2f128_ps(t1, t5, 0x31); - r2 = _mm256_permute2f128_ps(t2, t6, 0x20); - r6 = _mm256_permute2f128_ps(t3, t7, 0x20); - r3 = _mm256_permute2f128_ps(t2, t6, 0x31); - // r7 no longer needed. - - _mm256_storeu_ps(trailing_buf + 0 * 8, r0); - _mm256_storeu_ps(trailing_buf + 2 * 8, r4); - _mm256_storeu_ps(trailing_buf + 4 * 8, r1); - _mm256_storeu_ps(trailing_buf + 6 * 8, r5); - _mm256_storeu_ps(trailing_buf + 1 * 8, r2); - _mm256_storeu_ps(trailing_buf + 3 * 8, r6); - _mm256_storeu_ps(trailing_buf + 5 * 8, r3); - // No store to (trailing_buf + 7 * 8), space not allocated. - } - - packed_ptr += kPackRows * kPackCols; - src_ptr0 += src_inc0; - src_ptr1 += src_inc1; - src_ptr2 += src_inc2; - src_ptr3 += src_inc3; - src_ptr4 += src_inc4; - src_ptr5 += src_inc5; - src_ptr6 += src_inc6; - src_ptr7 += src_inc7; - } -} - -} // namespace. - -void Pack8bitAvx2(const std::int8_t* src_ptr, std::int8_t input_xor, - const std::int8_t* zerobuf, int src_stride, - int remaining_src_cols, int src_rows, std::int8_t* packed_ptr, - std::int32_t* sums_ptr) { - profiler::ScopeLabel label("Pack kAvx2 8bit"); - - using Layout = PackImpl8bitAvx2::Layout; - RUY_DCHECK_EQ(Layout::kCols, 8); - RUY_DCHECK_EQ(Layout::kRows, 4); - - // Each Layout::Rows is 4 contiguous input, contiguous packed elements. - // We process 8 of these chunks at a time, padding short input chunks. - static constexpr int kNumRowChunks = 8; // Short input is padded. - - // Each packed block is 4*8, and there are normally 8. The trailing block is - // only slightly shorter. - constexpr int kTrailingBufSize = - kNumRowChunks * Layout::kCols * Layout::kRows; - std::int8_t trailing_buf[kTrailingBufSize]; - memset(trailing_buf, 0, kTrailingBufSize * sizeof(std::int8_t)); - - Pack8bitAvx2Packer(src_ptr, input_xor, zerobuf, src_stride, - remaining_src_cols, src_rows, packed_ptr, sums_ptr, - trailing_buf); - - constexpr int kChunkedRowMask = kNumRowChunks * Layout::kRows - 1; - const bool trailing_data = (src_rows & kChunkedRowMask) > 0; - // If the number of source rows is not a multiple of kChunkedRowMask, there - // will be data in the trailing buffer, - if (trailing_data > 0) { - const int non_trailing_rows = src_rows & ~kChunkedRowMask; - // Destination "rows" are padded to next highest multiple of Layout::kRows. - const int dst_rows = (src_rows + 3) & ~3; - const int trailing_rows = dst_rows - non_trailing_rows; - memcpy(packed_ptr + Layout::kCols * non_trailing_rows, trailing_buf, - Layout::kCols * trailing_rows * sizeof(std::int8_t)); - } -} - -void PackFloatAvx2(const float* src_ptr, const float* zerobuf, int src_stride, - int remaining_src_cols, int src_rows, float* packed_ptr) { - profiler::ScopeLabel label("Pack kAvx2 float"); - static constexpr int kPackCols = 8; // Source cols packed together. - static constexpr int kPackRows = 8; // Short input is padded. 
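One note on the masked tail a few lines above (illustrative model, hypothetical name): the row mask is produced by comparing the lane index vector {0, ..., 7} against available_src_rows, which leaves an all-ones lane exactly for the rows that exist.

#include <cstdint>

// mask[i] is all-ones (-1) when lane i is a valid source row, else 0,
// mirroring _mm256_cmpgt_epi32(set1(available_src_rows), {0, 1, ..., 7}).
void BuildRowMaskModel(int available_src_rows, std::int32_t mask[8]) {
  for (int i = 0; i < 8; ++i) {
    mask[i] = (available_src_rows > i) ? -1 : 0;
  }
}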
- float trailing_buf[(kPackRows - 1) * kPackCols]; - if (remaining_src_cols < 8) { - memset(trailing_buf, 0, sizeof(trailing_buf)); - } - PackFloatAvx2Packer(src_ptr, zerobuf, src_stride, remaining_src_cols, - src_rows, packed_ptr, trailing_buf); - - const int trailing_rows = src_rows & (kPackRows - 1); - if (trailing_rows > 0) { - const int non_trailing_rows = src_rows & ~(kPackRows - 1); - memcpy(packed_ptr + kPackCols * non_trailing_rows, trailing_buf, - kPackCols * trailing_rows * sizeof(float)); - } -} - -#endif // RUY_PLATFORM(AVX2) && RUY_OPT_ENABLED(RUY_OPT_INTRINSICS) - -} // namespace ruy diff --git a/tensorflow/lite/experimental/ruy/ruy/pack_avx512.cc b/tensorflow/lite/experimental/ruy/ruy/pack_avx512.cc deleted file mode 100644 index d5636572eed..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/pack_avx512.cc +++ /dev/null @@ -1,693 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include - -#include "tensorflow/lite/experimental/ruy/ruy/check_macros.h" -#include "tensorflow/lite/experimental/ruy/ruy/matrix.h" -#include "tensorflow/lite/experimental/ruy/ruy/opt_set.h" -#include "tensorflow/lite/experimental/ruy/ruy/pack.h" -#include "tensorflow/lite/experimental/ruy/ruy/path.h" -#include "tensorflow/lite/experimental/ruy/ruy/platform.h" -#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" - -#if RUY_PLATFORM(AVX512) && RUY_OPT_ENABLED(RUY_OPT_INTRINSICS) -#include // IWYU pragma: keep -#endif - -namespace ruy { - -#if !(RUY_PLATFORM(AVX512) && RUY_OPT_ENABLED(RUY_OPT_ASM)) - -void Pack8bitAvx512(const std::int8_t* src_ptr, std::int8_t input_xor, - const std::int8_t* zerobuf, int src_stride, - int remaining_src_cols, int src_rows, - std::int8_t* packed_ptr, std::int32_t* sums_ptr) { - // CPU-ID-based checks should disable the path that would reach this point. - RUY_DCHECK(false); -} - -void PackFloatAvx512(const float* src_ptr, const float* zerobuf, int src_stride, - int remaining_src_cols, int src_rows, float* packed_ptr) { - // CPU-ID-based checks should disable the path that would reach this point. - RUY_DCHECK(false); -} - -#else // RUY_PLATFORM(AVX512) && RUY_OPT_ENABLED(RUY_OPT_ASM) - -// The first int8_t template parameter is arbitrary: this routine is common to -// all 8-bit source matrix types. -using PackImpl8bitAvx512 = - PackImpl, - std::int8_t, std::int8_t, std::int32_t>; - -namespace { - -inline void ZeroHalf8bitAvx512(int src_rows, std::int8_t packed_zero_point, - std::int8_t* packed_ptr) { - using Layout = PackImpl8bitAvx512::Layout; - static constexpr int kHalfLayoutCols = - PackImpl8bitAvx512::kHalfLayoutCols; // Half the number of cols in a - // block. - RUY_DCHECK_EQ(kHalfLayoutCols, 8); - RUY_DCHECK_EQ(Layout::kCols, 16); - RUY_DCHECK_EQ(Layout::kRows, 4); - - const int non_trailing_blocks = (src_rows & ~31) >> 2; - // This routine fills half blocks, and typically fills the second halves. 
- // Thus packed_ptr is already offset by 8 * 4. - for (int k = 0; k < non_trailing_blocks; ++k) { - for (int j = 0; j < (kHalfLayoutCols * Layout::kRows); ++j) { - packed_ptr[Layout::kCols * Layout::kRows * k + j] = packed_zero_point; - } - } -} - -inline __m512i LoaduTwo(const std::int8_t* addr_lo, - const std::int8_t* addr_hi) { - __m512i lower_filled = _mm512_castsi256_si512(_mm256_loadu_epi8(addr_lo)); - return _mm512_inserti32x8(lower_filled, _mm256_loadu_epi8(addr_hi), 1); -} - -inline __m512i MaskLoaduTwo(__mmask32 row_mask, const __m256i default_value_v, - const std::int8_t* addr_lo, - const std::int8_t* addr_hi) { - const __m512i lower_filled = _mm512_castsi256_si512( - _mm256_mask_loadu_epi8(default_value_v, row_mask, addr_lo)); - return _mm512_inserti32x8( - lower_filled, _mm256_mask_loadu_epi8(default_value_v, row_mask, addr_hi), - 1); -} - -inline void HalfPack8bitAvx512(const std::int8_t* src_ptr, - std::int8_t input_xor, - const std::int8_t* zerobuf, int src_stride, - int remaining_src_cols, int src_rows, - std::int8_t* packed_ptr, std::int32_t* sums_ptr, - std::int8_t* trailing_buf) { - using Layout = PackImpl8bitAvx512::Layout; - RUY_DCHECK_EQ(Layout::kCols, 16); - RUY_DCHECK_EQ(Layout::kRows, 4); - // Each Layout::Rows is 4 contiguous input, contiguous packed elements. - // We process 8 of these chunks at a time, padding short input chunks. - constexpr int kNumRowChunks = 8; - constexpr int kNumChunkedSrcRows = kNumRowChunks * Layout::kRows; - - const std::int8_t* src_ptr0 = src_ptr; - const std::int8_t* src_ptr1 = src_ptr0 + src_stride; - const std::int8_t* src_ptr2 = src_ptr1 + src_stride; - const std::int8_t* src_ptr3 = src_ptr2 + src_stride; - const std::int8_t* src_ptr4 = src_ptr3 + src_stride; - const std::int8_t* src_ptr5 = src_ptr4 + src_stride; - const std::int8_t* src_ptr6 = src_ptr5 + src_stride; - const std::int8_t* src_ptr7 = src_ptr6 + src_stride; - std::int64_t src_inc0 = kNumChunkedSrcRows; - std::int64_t src_inc1 = kNumChunkedSrcRows; - std::int64_t src_inc2 = kNumChunkedSrcRows; - std::int64_t src_inc3 = kNumChunkedSrcRows; - std::int64_t src_inc4 = kNumChunkedSrcRows; - std::int64_t src_inc5 = kNumChunkedSrcRows; - std::int64_t src_inc6 = kNumChunkedSrcRows; - std::int64_t src_inc7 = kNumChunkedSrcRows; - // Handle cases where source does not have kHalfLayoutCols (8) columns. - if (remaining_src_cols < 8) { - if (remaining_src_cols <= 0) { - src_ptr0 = zerobuf; - src_inc0 = 0; - } - if (remaining_src_cols <= 1) { - src_ptr1 = zerobuf; - src_inc1 = 0; - } - if (remaining_src_cols <= 2) { - src_ptr2 = zerobuf; - src_inc2 = 0; - } - if (remaining_src_cols <= 3) { - src_ptr3 = zerobuf; - src_inc3 = 0; - } - if (remaining_src_cols <= 4) { - src_ptr4 = zerobuf; - src_inc4 = 0; - } - if (remaining_src_cols <= 5) { - src_ptr5 = zerobuf; - src_inc5 = 0; - } - if (remaining_src_cols <= 6) { - src_ptr6 = zerobuf; - src_inc6 = 0; - } - src_ptr7 = zerobuf; - src_inc7 = 0; - } - - const std::int8_t zero_point = zerobuf[0]; - - if (sums_ptr) { - // i: kHalfLayoutCols. - for (int i = 0; i < 8; ++i) { - sums_ptr[i] = 0; - } - } - std::int32_t sums_adjustment = 0; - const __m512i ones_16bit = _mm512_set1_epi16(1); - __m512i sums_8x2_32bit = _mm512_set1_epi32(0); - - // The overall packing effectively pads the source rows to - // (src_rows + 63) & ~63. The iteration over k may skip when m=1, and then we - // only pack for (src_rows + 31) & ~31. When there is an incomplete - // destination block, this is stored into trailing_buf instead of packed_ptr. 
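A small worked example of the padding arithmetic in the comment above (the helper and the sample row count are assumptions for illustration):

// Round x up to a power-of-two multiple.
constexpr int RoundUpPow2(int x, int multiple) {
  return (x + multiple - 1) & ~(multiple - 1);
}
// e.g. with src_rows = 70 the k loop spans RoundUpPow2(70, 64) = 128 rows,
// but once the m == 1 pass is skipped only RoundUpPow2(70, 32) = 96 rows are
// packed; the incomplete tail is routed to trailing_buf.
static_assert(RoundUpPow2(70, 64) == 128, "rows spanned by the k loop");
static_assert(RoundUpPow2(70, 32) == 96, "rows actually packed");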
- for (int k = 0; k < src_rows; k += 2 * kNumChunkedSrcRows) { - // m: {0, 1} for 2 chunks of rows. - for (int m = 0; m < 2; ++m) { - // Available source rows. - // If this is less than 0 (for m=1), we skip, having filled trailing - // buffer for m=0. Also, if source rows is zero on m=1, then we filled - // exactly to the end of the column in the packed buffer. - const int available_src_rows = src_rows - k - m * kNumChunkedSrcRows; - // Effectively, - // available rows = std::max(0, std::min(8, src_rows - k - 8 * 4 * m)); - // treat each case separately. - if (available_src_rows >= kNumChunkedSrcRows) { - // i: chunks, s: Layout::Rows. - if (sums_ptr) { - __m512i t0, t1, t2, t3; - __m512i r0, r1, r2, r3; - const __m512i input_xor_v = _mm512_set1_epi8(input_xor); - - t0 = LoaduTwo(src_ptr0, src_ptr4); - t1 = LoaduTwo(src_ptr1, src_ptr5); - t2 = LoaduTwo(src_ptr2, src_ptr6); - t3 = LoaduTwo(src_ptr3, src_ptr7); - - r0 = _mm512_unpacklo_epi32(t0, t1); - r2 = _mm512_unpackhi_epi32(t0, t1); - r1 = _mm512_unpacklo_epi32(t2, t3); - r3 = _mm512_unpackhi_epi32(t2, t3); - - t0 = _mm512_unpacklo_epi64(r0, r1); - t2 = _mm512_unpackhi_epi64(r0, r1); - t1 = _mm512_unpacklo_epi64(r2, r3); - t3 = _mm512_unpackhi_epi64(r2, r3); - - r0 = _mm512_shuffle_i32x4(t0, t1, 0x88); - r1 = _mm512_shuffle_i32x4(t0, t1, 0xdd); - r2 = _mm512_shuffle_i32x4(t2, t3, 0x88); - r3 = _mm512_shuffle_i32x4(t2, t3, 0xdd); - - r0 = _mm512_xor_si512(r0, input_xor_v); - r1 = _mm512_xor_si512(r1, input_xor_v); - r2 = _mm512_xor_si512(r2, input_xor_v); - r3 = _mm512_xor_si512(r3, input_xor_v); - - const __m256i r0_0 = _mm512_castsi512_si256(r0); - const __m256i r0_1 = _mm512_extracti32x8_epi32(r0, 1); - const __m256i r1_0 = _mm512_castsi512_si256(r1); - const __m256i r1_1 = _mm512_extracti32x8_epi32(r1, 1); - const __m256i r2_0 = _mm512_castsi512_si256(r2); - const __m256i r2_1 = _mm512_extracti32x8_epi32(r2, 1); - const __m256i r3_0 = _mm512_castsi512_si256(r3); - const __m256i r3_1 = _mm512_extracti32x8_epi32(r3, 1); - - __m512i sums_8x4_16bit; - sums_8x4_16bit = _mm512_cvtepi8_epi16(r0_0); - sums_8x4_16bit = - _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r0_1)); - sums_8x4_16bit = - _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r1_0)); - sums_8x4_16bit = - _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r1_1)); - sums_8x4_16bit = - _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r2_0)); - sums_8x4_16bit = - _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r2_1)); - sums_8x4_16bit = - _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r3_0)); - sums_8x4_16bit = - _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r3_1)); - // The sums have been performed across columns, and now we have - // 4x16-bit sums packed together. We use madd for pairwise 32-bit - // sums. 
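The _mm512_shuffle_i32x4 steps above are the 512-bit analogue of the AVX2 lane interleaving; a plain model of the two immediates used here (illustrative, made-up types), viewing each zmm register as four 128-bit lanes:

#include <cstdint>

struct Lane128 { std::uint8_t bytes[16]; };

// out = { a[imm>>0 & 3], a[imm>>2 & 3], b[imm>>4 & 3], b[imm>>6 & 3] }, so
// imm 0x88 gathers the even lanes {a0, a2, b0, b2} and imm 0xdd the odd
// lanes {a1, a3, b1, b3}.
void ShuffleI32x4Model(const Lane128 a[4], const Lane128 b[4], int imm,
                       Lane128 out[4]) {
  out[0] = a[(imm >> 0) & 3];
  out[1] = a[(imm >> 2) & 3];
  out[2] = b[(imm >> 4) & 3];
  out[3] = b[(imm >> 6) & 3];
}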
- const __m512i sums_8x2_32bit_new = - _mm512_madd_epi16(sums_8x4_16bit, ones_16bit); - sums_8x2_32bit = _mm512_add_epi32(sums_8x2_32bit, sums_8x2_32bit_new); - - _mm256_storeu_epi8(packed_ptr + 0 * 16 * 4, r0_0); - _mm256_storeu_epi8(packed_ptr + 2 * 16 * 4, r0_1); - _mm256_storeu_epi8(packed_ptr + 4 * 16 * 4, r1_0); - _mm256_storeu_epi8(packed_ptr + 6 * 16 * 4, r1_1); - _mm256_storeu_epi8(packed_ptr + 1 * 16 * 4, r2_0); - _mm256_storeu_epi8(packed_ptr + 3 * 16 * 4, r2_1); - _mm256_storeu_epi8(packed_ptr + 5 * 16 * 4, r3_0); - _mm256_storeu_epi8(packed_ptr + 7 * 16 * 4, r3_1); - } else { - __m512i t0, t1, t2, t3; - __m512i r0, r1, r2, r3; - const __m512i input_xor_v = _mm512_set1_epi8(input_xor); - - t0 = LoaduTwo(src_ptr0, src_ptr4); - t1 = LoaduTwo(src_ptr1, src_ptr5); - t2 = LoaduTwo(src_ptr2, src_ptr6); - t3 = LoaduTwo(src_ptr3, src_ptr7); - - r0 = _mm512_unpacklo_epi32(t0, t1); - r2 = _mm512_unpackhi_epi32(t0, t1); - r1 = _mm512_unpacklo_epi32(t2, t3); - r3 = _mm512_unpackhi_epi32(t2, t3); - - t0 = _mm512_unpacklo_epi64(r0, r1); - t2 = _mm512_unpackhi_epi64(r0, r1); - t1 = _mm512_unpacklo_epi64(r2, r3); - t3 = _mm512_unpackhi_epi64(r2, r3); - - r0 = _mm512_shuffle_i32x4(t0, t1, 0x88); - r1 = _mm512_shuffle_i32x4(t0, t1, 0xdd); - r2 = _mm512_shuffle_i32x4(t2, t3, 0x88); - r3 = _mm512_shuffle_i32x4(t2, t3, 0xdd); - - r0 = _mm512_xor_si512(r0, input_xor_v); - r1 = _mm512_xor_si512(r1, input_xor_v); - r2 = _mm512_xor_si512(r2, input_xor_v); - r3 = _mm512_xor_si512(r3, input_xor_v); - - const __m256i r0_0 = _mm512_castsi512_si256(r0); - const __m256i r0_1 = _mm512_extracti32x8_epi32(r0, 1); - const __m256i r1_0 = _mm512_castsi512_si256(r1); - const __m256i r1_1 = _mm512_extracti32x8_epi32(r1, 1); - const __m256i r2_0 = _mm512_castsi512_si256(r2); - const __m256i r2_1 = _mm512_extracti32x8_epi32(r2, 1); - const __m256i r3_0 = _mm512_castsi512_si256(r3); - const __m256i r3_1 = _mm512_extracti32x8_epi32(r3, 1); - _mm256_storeu_epi8(packed_ptr + 0 * 16 * 4, r0_0); - _mm256_storeu_epi8(packed_ptr + 2 * 16 * 4, r0_1); - _mm256_storeu_epi8(packed_ptr + 4 * 16 * 4, r1_0); - _mm256_storeu_epi8(packed_ptr + 6 * 16 * 4, r1_1); - _mm256_storeu_epi8(packed_ptr + 1 * 16 * 4, r2_0); - _mm256_storeu_epi8(packed_ptr + 3 * 16 * 4, r2_1); - _mm256_storeu_epi8(packed_ptr + 5 * 16 * 4, r3_0); - _mm256_storeu_epi8(packed_ptr + 7 * 16 * 4, r3_1); - } - } else if (available_src_rows > 0) { - RUY_DCHECK_LT(available_src_rows >> 2, kNumChunkedSrcRows); - const __mmask32 row_mask = - (static_cast(1) << available_src_rows) - 1; - - // We do not care what goes into the trailing buffer, but we want - // in_data[...] ^ input_xor == 0 for irrelevant values in the summation. - // - // We compensate for padding-with-zero_point by initializing the - // summations with the compensating offset, effectively - // ((input_xor ^ input_xor) - (zero_point ^ input_xor)) * - // 4 * (8 - ((available_src_rows + 3) >> 2)). - // - // Note that (zero_point ^ input_xor) is performed in 8-bits and then - // cast. 
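A worked instance of the compensation just described, with assumed sample values (they are not taken from the patch):

#include <cstdint>

// With available_src_rows = 10 there remain 8 - ((10 + 3) >> 2) = 5 fully
// padded 4-row chunks per column, each slot holding (zero_point ^ input_xor);
// the adjustment removes those 4 * 5 contributions from every column sum.
inline std::int32_t SampleSumsAdjustment() {
  const std::int8_t zero_point = 3;                              // assumed
  const std::int8_t input_xor = static_cast<std::int8_t>(0x80);  // assumed
  const int available_src_rows = 10;                             // assumed
  return -(zero_point ^ input_xor) * 4 *
         (8 - ((available_src_rows + 3) >> 2));  // == 125 * 4 * 5 == 2500
}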
- sums_adjustment += -(zero_point ^ input_xor) * 4 * - (8 - ((available_src_rows + 3) >> 2)); - - __m512i t0, t1, t2, t3; - __m512i r0, r1, r2, r3; - const __m512i input_xor_v = _mm512_set1_epi8(input_xor); - const __m256i zero_point_v = _mm256_set1_epi8(zero_point); - - t0 = MaskLoaduTwo(row_mask, zero_point_v, src_ptr0, src_ptr4); - t1 = MaskLoaduTwo(row_mask, zero_point_v, src_ptr1, src_ptr5); - t2 = MaskLoaduTwo(row_mask, zero_point_v, src_ptr2, src_ptr6); - t3 = MaskLoaduTwo(row_mask, zero_point_v, src_ptr3, src_ptr7); - - r0 = _mm512_unpacklo_epi32(t0, t1); - r2 = _mm512_unpackhi_epi32(t0, t1); - r1 = _mm512_unpacklo_epi32(t2, t3); - r3 = _mm512_unpackhi_epi32(t2, t3); - - t0 = _mm512_unpacklo_epi64(r0, r1); - t2 = _mm512_unpackhi_epi64(r0, r1); - t1 = _mm512_unpacklo_epi64(r2, r3); - t3 = _mm512_unpackhi_epi64(r2, r3); - - r0 = _mm512_shuffle_i32x4(t0, t1, 0x88); - r1 = _mm512_shuffle_i32x4(t0, t1, 0xdd); - r2 = _mm512_shuffle_i32x4(t2, t3, 0x88); - r3 = _mm512_shuffle_i32x4(t2, t3, 0xdd); - - r0 = _mm512_xor_si512(r0, input_xor_v); - r1 = _mm512_xor_si512(r1, input_xor_v); - r2 = _mm512_xor_si512(r2, input_xor_v); - r3 = _mm512_xor_si512(r3, input_xor_v); - - const __m256i r0_0 = _mm512_castsi512_si256(r0); - const __m256i r0_1 = _mm512_extracti32x8_epi32(r0, 1); - const __m256i r1_0 = _mm512_castsi512_si256(r1); - const __m256i r1_1 = _mm512_extracti32x8_epi32(r1, 1); - const __m256i r2_0 = _mm512_castsi512_si256(r2); - const __m256i r2_1 = _mm512_extracti32x8_epi32(r2, 1); - const __m256i r3_0 = _mm512_castsi512_si256(r3); - const __m256i r3_1 = _mm512_extracti32x8_epi32(r3, 1); - - __m512i sums_8x4_16bit; - sums_8x4_16bit = _mm512_cvtepi8_epi16(r0_0); - sums_8x4_16bit = - _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r0_1)); - sums_8x4_16bit = - _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r1_0)); - sums_8x4_16bit = - _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r1_1)); - sums_8x4_16bit = - _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r2_0)); - sums_8x4_16bit = - _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r2_1)); - sums_8x4_16bit = - _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r3_0)); - sums_8x4_16bit = - _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r3_1)); - // The sums have been performed across columns, and now we have - // 4x16-bit sums packed together. We use madd for pairwise 32-bit - // sums. 
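The masked loads at the top of this tail block use a __mmask32 with the low available_src_rows bits set; a scalar sketch of that mask and of what MaskLoaduTwo amounts to (illustrative, hypothetical names):

#include <cstdint>
#include <cstring>

// Low `available_src_rows` bits set; valid for 0 < available_src_rows < 32.
inline std::uint32_t RowMask32(int available_src_rows) {
  return (static_cast<std::uint32_t>(1) << available_src_rows) - 1;
}

// Two 32-byte column chunks, each padded with the zero point beyond the valid
// rows, stacked into one 64-byte value (the lower and upper 256-bit halves).
void MaskLoaduTwoModel(std::uint32_t row_mask, std::int8_t zero_point,
                       const std::int8_t* lo, const std::int8_t* hi,
                       std::int8_t out[64]) {
  std::memset(out, zero_point, 64);
  for (int i = 0; i < 32; ++i) {
    if (row_mask & (1u << i)) {
      out[i] = lo[i];
      out[32 + i] = hi[i];
    }
  }
}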
- const __m512i sums_8x2_32bit_new = - _mm512_madd_epi16(sums_8x4_16bit, ones_16bit); - sums_8x2_32bit = _mm512_add_epi32(sums_8x2_32bit, sums_8x2_32bit_new); - - _mm256_storeu_epi8(trailing_buf + 0 * 16 * 4, r0_0); - _mm256_storeu_epi8(trailing_buf + 2 * 16 * 4, r0_1); - _mm256_storeu_epi8(trailing_buf + 4 * 16 * 4, r1_0); - _mm256_storeu_epi8(trailing_buf + 6 * 16 * 4, r1_1); - _mm256_storeu_epi8(trailing_buf + 1 * 16 * 4, r2_0); - _mm256_storeu_epi8(trailing_buf + 3 * 16 * 4, r2_1); - _mm256_storeu_epi8(trailing_buf + 5 * 16 * 4, r3_0); - _mm256_storeu_epi8(trailing_buf + 7 * 16 * 4, r3_1); - } - - packed_ptr += 16 * kNumChunkedSrcRows; - src_ptr0 += src_inc0; - src_ptr1 += src_inc1; - src_ptr2 += src_inc2; - src_ptr3 += src_inc3; - src_ptr4 += src_inc4; - src_ptr5 += src_inc5; - src_ptr6 += src_inc6; - src_ptr7 += src_inc7; - } - } - - if (sums_ptr) { - const __m256i sums_adjustment_v = _mm256_set1_epi32(sums_adjustment); - - __m256i sums = _mm256_loadu_epi32(sums_ptr); - const __m512i idx = - _mm512_set_epi32(15, 13, 11, 9, 7, 5, 3, 1, 14, 12, 10, 8, 6, 4, 2, 0); - - // We earlier used madd for pairwise 32-bit sums, and now we deinterlace the - // neighbours, finshing up by adding them to the stored accumulated sums. - const __m512i sums_2x8_32bit = - _mm512_permutexvar_epi32(idx, sums_8x2_32bit); - sums = _mm256_add_epi32(sums, sums_adjustment_v); - sums = _mm256_add_epi32(sums, _mm512_castsi512_si256(sums_2x8_32bit)); - sums = _mm256_add_epi32(sums, _mm512_extracti32x8_epi32(sums_2x8_32bit, 1)); - - _mm256_storeu_epi32(sums_ptr, sums); - } -} - -inline __m512 LoaduTwo(const float* addr_lo, const float* addr_hi) { - const __m512 lower_filled = _mm512_castps256_ps512(_mm256_loadu_ps(addr_lo)); - return _mm512_insertf32x8(lower_filled, _mm256_loadu_ps(addr_hi), 1); -} - -inline __m512 MaskLoaduTwo(__mmask8 row_mask, const float* addr_lo, - const float* addr_hi) { - const __m512 lower_filled = - _mm512_castps256_ps512(_mm256_maskz_loadu_ps(row_mask, addr_lo)); - return _mm512_insertf32x8(lower_filled, - _mm256_maskz_loadu_ps(row_mask, addr_hi), 1); -} - -inline __m512 Mm512UnpackloPsx2(const __m512 a, const __m512 b) { - return _mm512_castpd_ps( - _mm512_unpacklo_pd(_mm512_castps_pd(a), _mm512_castps_pd(b))); -} - -inline __m512 Mm512UnpackhiPsx2(const __m512 a, const __m512 b) { - return _mm512_castpd_ps( - _mm512_unpackhi_pd(_mm512_castps_pd(a), _mm512_castps_pd(b))); -} - -inline void HalfPackFloatAvx512(const float* src_ptr, const float* zerobuf, - int src_stride, int remaining_src_cols, - int src_rows, float* packed_ptr, - float* trailing_buf) { - const float* src_ptr0 = src_ptr; - const float* src_ptr1 = src_ptr0 + src_stride; - const float* src_ptr2 = src_ptr1 + src_stride; - const float* src_ptr3 = src_ptr2 + src_stride; - const float* src_ptr4 = src_ptr3 + src_stride; - const float* src_ptr5 = src_ptr4 + src_stride; - const float* src_ptr6 = src_ptr5 + src_stride; - const float* src_ptr7 = src_ptr6 + src_stride; - std::int64_t src_inc0 = 8; - std::int64_t src_inc1 = 8; - std::int64_t src_inc2 = 8; - std::int64_t src_inc3 = 8; - std::int64_t src_inc4 = 8; - std::int64_t src_inc5 = 8; - std::int64_t src_inc6 = 8; - std::int64_t src_inc7 = 8; - if (remaining_src_cols < 8) { - if (remaining_src_cols <= 0) { - src_ptr0 = zerobuf; - src_inc0 = 0; - } - if (remaining_src_cols <= 1) { - src_ptr1 = zerobuf; - src_inc1 = 0; - } - if (remaining_src_cols <= 2) { - src_ptr2 = zerobuf; - src_inc2 = 0; - } - if (remaining_src_cols <= 3) { - src_ptr3 = zerobuf; - src_inc3 = 0; - } - if 
(remaining_src_cols <= 4) { - src_ptr4 = zerobuf; - src_inc4 = 0; - } - if (remaining_src_cols <= 5) { - src_ptr5 = zerobuf; - src_inc5 = 0; - } - if (remaining_src_cols <= 6) { - src_ptr6 = zerobuf; - src_inc6 = 0; - } - src_ptr7 = zerobuf; - src_inc7 = 0; - } - - for (int k = 0; k < src_rows; k += 16) { - for (int m = 0; m < 2; ++m) { - const int available_src_rows = src_rows - k - 8 * m; - // Effectively, - // available_src_rows = std::max(0, std::min(8, src_rows - k - 8 * m)); - // but treat each case separately. - if (available_src_rows > 7) { - __m512 t0, t1, t2, t3; - __m512 r0, r1, r2, r3; - - t0 = LoaduTwo(src_ptr0, src_ptr4); - t1 = LoaduTwo(src_ptr1, src_ptr5); - t2 = LoaduTwo(src_ptr2, src_ptr6); - t3 = LoaduTwo(src_ptr3, src_ptr7); - - r0 = _mm512_unpacklo_ps(t0, t1); - r2 = _mm512_unpackhi_ps(t0, t1); - r1 = _mm512_unpacklo_ps(t2, t3); - r3 = _mm512_unpackhi_ps(t2, t3); - - t0 = Mm512UnpackloPsx2(r0, r1); - t2 = Mm512UnpackhiPsx2(r0, r1); - t1 = Mm512UnpackloPsx2(r2, r3); - t3 = Mm512UnpackhiPsx2(r2, r3); - - r0 = _mm512_shuffle_f32x4(t0, t1, 0x88); - r1 = _mm512_shuffle_f32x4(t0, t1, 0xdd); - r2 = _mm512_shuffle_f32x4(t2, t3, 0x88); - r3 = _mm512_shuffle_f32x4(t2, t3, 0xdd); - - _mm256_storeu_ps(packed_ptr + 0 * 16, _mm512_castps512_ps256(r0)); - _mm256_storeu_ps(packed_ptr + 2 * 16, _mm512_extractf32x8_ps(r0, 1)); - _mm256_storeu_ps(packed_ptr + 4 * 16, _mm512_castps512_ps256(r1)); - _mm256_storeu_ps(packed_ptr + 6 * 16, _mm512_extractf32x8_ps(r1, 1)); - _mm256_storeu_ps(packed_ptr + 1 * 16, _mm512_castps512_ps256(r2)); - _mm256_storeu_ps(packed_ptr + 3 * 16, _mm512_extractf32x8_ps(r2, 1)); - _mm256_storeu_ps(packed_ptr + 5 * 16, _mm512_castps512_ps256(r3)); - _mm256_storeu_ps(packed_ptr + 7 * 16, _mm512_extractf32x8_ps(r3, 1)); - } else if (available_src_rows > 0) { - const __mmask8 row_mask = - (static_cast(1) << available_src_rows) - 1; - - __m512 t0, t1, t2, t3; - __m512 r0, r1, r2, r3; - - t0 = MaskLoaduTwo(row_mask, src_ptr0, src_ptr4); - t1 = MaskLoaduTwo(row_mask, src_ptr1, src_ptr5); - t2 = MaskLoaduTwo(row_mask, src_ptr2, src_ptr6); - t3 = MaskLoaduTwo(row_mask, src_ptr3, src_ptr7); - - r0 = _mm512_unpacklo_ps(t0, t1); - r2 = _mm512_unpackhi_ps(t0, t1); - r1 = _mm512_unpacklo_ps(t2, t3); - r3 = _mm512_unpackhi_ps(t2, t3); - - t0 = Mm512UnpackloPsx2(r0, r1); - t2 = Mm512UnpackhiPsx2(r0, r1); - t1 = Mm512UnpackloPsx2(r2, r3); - t3 = Mm512UnpackhiPsx2(r2, r3); - - r0 = _mm512_shuffle_f32x4(t0, t1, 0x88); - r1 = _mm512_shuffle_f32x4(t0, t1, 0xdd); - r2 = _mm512_shuffle_f32x4(t2, t3, 0x88); - r3 = _mm512_shuffle_f32x4(t2, t3, 0xdd); - - _mm256_storeu_ps(trailing_buf + 0 * 16, _mm512_castps512_ps256(r0)); - _mm256_storeu_ps(trailing_buf + 2 * 16, _mm512_extractf32x8_ps(r0, 1)); - _mm256_storeu_ps(trailing_buf + 4 * 16, _mm512_castps512_ps256(r1)); - _mm256_storeu_ps(trailing_buf + 6 * 16, _mm512_extractf32x8_ps(r1, 1)); - _mm256_storeu_ps(trailing_buf + 1 * 16, _mm512_castps512_ps256(r2)); - _mm256_storeu_ps(trailing_buf + 3 * 16, _mm512_extractf32x8_ps(r2, 1)); - _mm256_storeu_ps(trailing_buf + 5 * 16, _mm512_castps512_ps256(r3)); - // Do not store _mm512_extractf32x8_ps(r3, 1). 
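To make the addressing above concrete: each packed float row in this layout is 16 values wide, with this half-pack filling columns [0, 8) and the second invocation (made by the caller at packed_ptr + 8) filling columns [8, 16). A scalar sketch of one assembled row, under that reading:

// One packed row of the 16-wide float layout, assembled from the two
// 8-column half-packs (left = block columns 0..7, right = 8..15).
void AssemblePackedFloatRow(const float left[8], const float right[8],
                            float packed_row[16]) {
  for (int c = 0; c < 8; ++c) {
    packed_row[c] = left[c];
    packed_row[8 + c] = right[c];
  }
}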
- } - - packed_ptr += 16 * 8; - src_ptr0 += src_inc0; - src_ptr1 += src_inc1; - src_ptr2 += src_inc2; - src_ptr3 += src_inc3; - src_ptr4 += src_inc4; - src_ptr5 += src_inc5; - src_ptr6 += src_inc6; - src_ptr7 += src_inc7; - } - } -} - -inline void ZeroHalfFloatAvx512(int src_rows, float* packed_ptr) { - const int non_trailing_rows = src_rows & ~7; - for (int k = 0; k < non_trailing_rows; ++k) { - for (int j = 0; j < 8; ++j) { - packed_ptr[j] = 0.0f; - } - packed_ptr += 16; - } -} - -} // namespace. - -void Pack8bitAvx512(const std::int8_t* src_ptr, std::int8_t input_xor, - const std::int8_t* zerobuf, int src_stride, - int remaining_src_cols, int src_rows, - std::int8_t* packed_ptr, std::int32_t* sums_ptr) { - profiler::ScopeLabel label("Pack kAvx512 8bit"); - - using Layout = PackImpl8bitAvx512::Layout; - constexpr int kHalfBlockOffset = 32; - RUY_DCHECK_EQ(kHalfBlockOffset * 2, Layout::kRows * Layout::kCols); - static constexpr int kHalfLayoutCols = - PackImpl8bitAvx512::kHalfLayoutCols; // Half the number of cols in a - // block. - RUY_DCHECK_EQ(kHalfLayoutCols, 8); - RUY_DCHECK_EQ(Layout::kCols, 16); - RUY_DCHECK_EQ(Layout::kRows, 4); - - // Each Layout::Rows is 4 contiguous input, contiguous packed elements. - // We process 8 of these chunks at a time, padding short input chunks. - constexpr int kNumRowChunks = 8; - - // Each packed block is 4*16, and there are normally 8. The trailing block is - // only slightly shorter. - constexpr int kTrailingBufSize = - kNumRowChunks * Layout::kCols * Layout::kRows; - std::int8_t trailing_buf[kTrailingBufSize]; - memset(trailing_buf, 0, kTrailingBufSize * sizeof(std::int8_t)); - - std::int32_t* second_sums_ptr = - sums_ptr ? sums_ptr + kHalfLayoutCols : nullptr; - if (remaining_src_cols > kHalfLayoutCols) { - HalfPack8bitAvx512(src_ptr, input_xor, zerobuf, src_stride, - remaining_src_cols, src_rows, packed_ptr, sums_ptr, - trailing_buf); - HalfPack8bitAvx512(src_ptr + src_stride * kHalfLayoutCols, input_xor, - zerobuf, src_stride, - remaining_src_cols - kHalfLayoutCols, src_rows, - packed_ptr + kHalfBlockOffset, second_sums_ptr, - trailing_buf + kHalfBlockOffset); - } else { - HalfPack8bitAvx512(src_ptr, input_xor, zerobuf, src_stride, - remaining_src_cols, src_rows, packed_ptr, sums_ptr, - trailing_buf); - ZeroHalf8bitAvx512(src_rows, zerobuf[0] ^ input_xor, - packed_ptr + kHalfBlockOffset); - // The kernel may not need the second half-blocks sums to be set. - if (second_sums_ptr) { - for (int i = 0; i < kHalfLayoutCols; ++i) { - second_sums_ptr[i] = (zerobuf[0] ^ input_xor) * ((src_rows + 3) & ~3); - } - } - } - constexpr int kChunkedRowMask = kNumRowChunks * Layout::kRows - 1; - const bool trailing_data = (src_rows & kChunkedRowMask) > 0; - // If the number of source rows is not a multiple of kChunkedRowMask, there - // will be data in the trailing buffer, - if (trailing_data > 0) { - const int non_trailing_rows = src_rows & ~kChunkedRowMask; - // Destination "rows" are padded to next highest multiple of Layout::kRows. 
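A worked example of the trailing copy-back being set up here, with an assumed src_rows = 70 (Layout::kRows = 4, so kChunkedRowMask = 31; values chosen for illustration only):

// 70 & 31 == 6, so the last partial chunk went to trailing_buf. It is copied
// back as packed rows [64, 72): 72 - 64 == 8 rows of Layout::kCols bytes each.
constexpr int kExampleSrcRows = 70;
constexpr int kExampleNonTrailingRows = kExampleSrcRows & ~31;  // 64
constexpr int kExampleDstRows = (kExampleSrcRows + 3) & ~3;     // 72
constexpr int kExampleTrailingRows =
    kExampleDstRows - kExampleNonTrailingRows;                  // 8
static_assert(kExampleTrailingRows == 8, "rows copied from trailing_buf");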
- const int dst_rows = (src_rows + 3) & ~3; - const int trailing_rows = dst_rows - non_trailing_rows; - memcpy(packed_ptr + Layout::kCols * non_trailing_rows, trailing_buf, - Layout::kCols * trailing_rows * sizeof(std::int8_t)); - } -} - -void PackFloatAvx512(const float* src_ptr, const float* zerobuf, int src_stride, - int remaining_src_cols, int src_rows, float* packed_ptr) { - profiler::ScopeLabel label("Pack kAvx512 float"); - float trailing_buf[7 * 16]; - if (remaining_src_cols > 8) { - HalfPackFloatAvx512(src_ptr, zerobuf, src_stride, remaining_src_cols, - src_rows, packed_ptr, trailing_buf); - HalfPackFloatAvx512(src_ptr + src_stride * 8, zerobuf, src_stride, - remaining_src_cols - 8, src_rows, packed_ptr + 8, - trailing_buf + 8); - } else { - memset(trailing_buf, 0, sizeof(trailing_buf)); - HalfPackFloatAvx512(src_ptr, zerobuf, src_stride, remaining_src_cols, - src_rows, packed_ptr, trailing_buf); - ZeroHalfFloatAvx512(src_rows, packed_ptr + 8); - } - const int trailing_rows = src_rows & 7; - if (trailing_rows > 0) { - const int non_trailing_rows = src_rows & ~7; - memcpy(packed_ptr + 16 * non_trailing_rows, trailing_buf, - 16 * trailing_rows * sizeof(float)); - } -} - -#endif // RUY_PLATFORM(AVX512) && RUY_OPT_ENABLED(RUY_OPT_INTRINSICS) - -} // namespace ruy diff --git a/tensorflow/lite/experimental/ruy/ruy/pack_avxvnni.cc b/tensorflow/lite/experimental/ruy/ruy/pack_avxvnni.cc deleted file mode 100644 index 49b4a1f978c..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/pack_avxvnni.cc +++ /dev/null @@ -1,478 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include - -#include "tensorflow/lite/experimental/ruy/ruy/check_macros.h" -#include "tensorflow/lite/experimental/ruy/ruy/matrix.h" -#include "tensorflow/lite/experimental/ruy/ruy/opt_set.h" -#include "tensorflow/lite/experimental/ruy/ruy/pack.h" -#include "tensorflow/lite/experimental/ruy/ruy/path.h" -#include "tensorflow/lite/experimental/ruy/ruy/platform.h" -#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" - -#if RUY_PLATFORM(AVX_VNNI) && RUY_OPT_ENABLED(RUY_OPT_INTRINSICS) -#include // IWYU pragma: keep -#endif - -namespace ruy { - -#if !(RUY_PLATFORM(AVX_VNNI) && RUY_OPT_ENABLED(RUY_OPT_ASM)) - -void Pack8bitAvxVnni(const std::int8_t* src_ptr, std::int8_t input_xor, - const std::int8_t* zerobuf, int src_stride, - int remaining_src_cols, int src_rows, - std::int8_t* packed_ptr, std::int32_t* sums_ptr) { - // CPU-ID-based checks should disable the path that would reach this point. - RUY_DCHECK(false); -} - -void PackFloatAvxVnni(const float* src_ptr, const float* zerobuf, - int src_stride, int remaining_src_cols, int src_rows, - float* packed_ptr) { - // CPU-ID-based checks should disable the path that would reach this point. 
- RUY_DCHECK(false); -} - -#else // RUY_PLATFORM(AVX_VNNI) && RUY_OPT_ENABLED(RUY_OPT_ASM) - -// The first int8_t template parameter is arbitrary: this routine is common to -// all 8-bit source matrix types. -using PackImpl8bitAvxVnni = - PackImpl, - std::int8_t, std::int8_t, std::int32_t>; - -namespace { - -inline void ZeroHalf8bitAvxVnni(int src_rows, std::int8_t packed_zero_point, - std::int8_t* packed_ptr) { - const int non_trailing_blocks = (src_rows & ~31) >> 2; - // This routine fills half blocks, and typically fills the second halves. Thus - // packed_ptr is already offset by 8*4. - for (int k = 0; k < non_trailing_blocks; ++k) { - for (int j = 0; j < (8 * 4); ++j) { - packed_ptr[16 * 4 * k + j] = packed_zero_point; - } - } -} - -inline void HalfPack8bitAvxVnni(const std::int8_t* src_ptr, - std::int8_t input_xor, - const std::int8_t* zerobuf, int src_stride, - int remaining_src_cols, int src_rows, - std::int8_t* packed_ptr, std::int32_t* sums_ptr, - std::int8_t* trailing_buf) { - std::int8_t in_data[8][8][4]; - - const std::int8_t* src_ptr0 = src_ptr; - const std::int8_t* src_ptr1 = src_ptr0 + src_stride; - const std::int8_t* src_ptr2 = src_ptr1 + src_stride; - const std::int8_t* src_ptr3 = src_ptr2 + src_stride; - const std::int8_t* src_ptr4 = src_ptr3 + src_stride; - const std::int8_t* src_ptr5 = src_ptr4 + src_stride; - const std::int8_t* src_ptr6 = src_ptr5 + src_stride; - const std::int8_t* src_ptr7 = src_ptr6 + src_stride; - std::int64_t src_inc0 = 8 * 4; - std::int64_t src_inc1 = 8 * 4; - std::int64_t src_inc2 = 8 * 4; - std::int64_t src_inc3 = 8 * 4; - std::int64_t src_inc4 = 8 * 4; - std::int64_t src_inc5 = 8 * 4; - std::int64_t src_inc6 = 8 * 4; - std::int64_t src_inc7 = 8 * 4; - if (remaining_src_cols < 8) { - if (remaining_src_cols <= 0) { - src_ptr0 = zerobuf; - src_inc0 = 0; - } - if (remaining_src_cols <= 1) { - src_ptr1 = zerobuf; - src_inc1 = 0; - } - if (remaining_src_cols <= 2) { - src_ptr2 = zerobuf; - src_inc2 = 0; - } - if (remaining_src_cols <= 3) { - src_ptr3 = zerobuf; - src_inc3 = 0; - } - if (remaining_src_cols <= 4) { - src_ptr4 = zerobuf; - src_inc4 = 0; - } - if (remaining_src_cols <= 5) { - src_ptr5 = zerobuf; - src_inc5 = 0; - } - if (remaining_src_cols <= 6) { - src_ptr6 = zerobuf; - src_inc6 = 0; - } - src_ptr7 = zerobuf; - src_inc7 = 0; - } - - const std::int8_t zero_point = zerobuf[0]; - - if (sums_ptr) { - for (int i = 0; i < 8; ++i) { - sums_ptr[i] = 0; - } - } - - // The overall packing effectively pads the source rows to - // (src_rows + 63) & ~63. The iteration over k may skip when m=1, and then we - // only pack for (src_rows + 31) & ~31. When there is an incomplete - // destination block, this is stored into trailing_buf instead of packed_ptr. - for (int k = 0; k < src_rows; k += 16 * 4) { - for (int m = 0; m < 2; ++m) { - // Available source rows. - // If this is less than 0 (for m=1), we skip, having filled trailing - // buffer for m=0. Also, if source rows is zero on m=1, then we filled - // exactly to the end of the column in the packed buffer. - const int packed_rows = src_rows - k - 8 * m * 4; - // Effectively, - // packed_rows = std::max(0, std::min(8, src_rows - k - 8 * m)); - // but treat each case separately. 
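      // Added note (an illustration, not original ruy code): each k step
      // covers 16 * 4 == 64 source rows, split into two 32-row half-blocks
      // (m == 0 and m == 1); for this 8-bit variant the branches below
      // effectively compute
      // packed_rows = std::max(0, std::min(8 * 4, src_rows - k - 8 * m * 4));
      // the min(8, ...) in the comment above appears to be carried over from
      // the float variant. Example with src_rows == 100: at k == 64, m == 0
      // gives 36 (>= 32, a full half-block written to packed_ptr) and m == 1
      // gives 4, so 4 rows are read, padded out to a 32-row half-block, and
      // routed to trailing_buf for the caller to copy into place.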
- if (packed_rows >= (8 * 4)) { - for (int i = 0; i < 8; ++i) { - for (int s = 0; s < 4; ++s) { - in_data[0][i][s] = src_ptr0[i * 4 + s]; - in_data[1][i][s] = src_ptr1[i * 4 + s]; - in_data[2][i][s] = src_ptr2[i * 4 + s]; - in_data[3][i][s] = src_ptr3[i * 4 + s]; - in_data[4][i][s] = src_ptr4[i * 4 + s]; - in_data[5][i][s] = src_ptr5[i * 4 + s]; - in_data[6][i][s] = src_ptr6[i * 4 + s]; - in_data[7][i][s] = src_ptr7[i * 4 + s]; - } - } - for (int i = 0; i < 8; ++i) { - for (int j = 0; j < 8; ++j) { - for (int s = 0; s < 4; ++s) { - packed_ptr[(16 * i + j) * 4 + s] = - static_cast(in_data[j][i][s] ^ input_xor); - } - if (sums_ptr) { - for (int s = 0; s < 4; ++s) { - sums_ptr[j] += in_data[j][i][s] ^ input_xor; - } - } - } - } - } else if (packed_rows > 0) { - RUY_DCHECK_LT(packed_rows >> 2, 8); - int i = 0; - for (; i < (packed_rows >> 2); ++i) { - for (int s = 0; s < 4; ++s) { - in_data[0][i][s] = src_ptr0[i * 4 + s]; - in_data[1][i][s] = src_ptr1[i * 4 + s]; - in_data[2][i][s] = src_ptr2[i * 4 + s]; - in_data[3][i][s] = src_ptr3[i * 4 + s]; - in_data[4][i][s] = src_ptr4[i * 4 + s]; - in_data[5][i][s] = src_ptr5[i * 4 + s]; - in_data[6][i][s] = src_ptr6[i * 4 + s]; - in_data[7][i][s] = src_ptr7[i * 4 + s]; - } - } - if (i < ((packed_rows + 3) >> 2)) { - int s = 0; - for (; s < (packed_rows & 3); ++s) { - in_data[0][i][s] = src_ptr0[i * 4 + s]; - in_data[1][i][s] = src_ptr1[i * 4 + s]; - in_data[2][i][s] = src_ptr2[i * 4 + s]; - in_data[3][i][s] = src_ptr3[i * 4 + s]; - in_data[4][i][s] = src_ptr4[i * 4 + s]; - in_data[5][i][s] = src_ptr5[i * 4 + s]; - in_data[6][i][s] = src_ptr6[i * 4 + s]; - in_data[7][i][s] = src_ptr7[i * 4 + s]; - } - RUY_DCHECK_LE(s, 4); - for (; s < 4; ++s) { - for (int j = 0; j < 8; ++j) { - in_data[j][i][s] = zero_point; - } - } - ++i; - } - // We do not care what goes into the trailing buffer, but we want - // in_data[...] ^ input_xor == 0 for irrelevant values in the summation. - // - // It might prove better in optimized code to pad uniformly with - // zero_point, and compensate by initializing the summations with the - // compensating offset, effectively - // ((input_xor - zero_point) ^ input_xor) * - // 4 * (8 - ((packed_rows + 3) >> 2)). - for (; i < 8; ++i) { - for (int s = 0; s < 4; ++s) { - for (int j = 0; j < 8; ++j) { - in_data[j][i][s] = input_xor; - } - } - } - // We loop through [0, 8) rather than [0, (packed_rows + 3) >> 2), since - // that emulates what we might do in fully-optimized code. 
- if (sums_ptr) { - for (int i = 0; i < 8; ++i) { - for (int j = 0; j < 8; ++j) { - for (int s = 0; s < 4; ++s) { - trailing_buf[(16 * i + j) * 4 + s] = - static_cast(in_data[j][i][s] ^ input_xor); - sums_ptr[j] += in_data[j][i][s] ^ input_xor; - } - } - } - } else { - for (int i = 0; i < 8; ++i) { - for (int j = 0; j < 8; ++j) { - for (int s = 0; s < 4; ++s) { - trailing_buf[(16 * i + j) * 4 + s] = - static_cast(in_data[j][i][s] ^ input_xor); - } - } - } - } - } - - packed_ptr += 16 * 8 * 4; - src_ptr0 += src_inc0; - src_ptr1 += src_inc1; - src_ptr2 += src_inc2; - src_ptr3 += src_inc3; - src_ptr4 += src_inc4; - src_ptr5 += src_inc5; - src_ptr6 += src_inc6; - src_ptr7 += src_inc7; - } - } -} - -inline void HalfPackFloatAvxVnni(const float* src_ptr, const float* zerobuf, - int src_stride, int remaining_src_cols, - int src_rows, float* packed_ptr, - float* trailing_buf) { - float in_data[8][8]; - - const float* src_ptr0 = src_ptr; - const float* src_ptr1 = src_ptr0 + src_stride; - const float* src_ptr2 = src_ptr1 + src_stride; - const float* src_ptr3 = src_ptr2 + src_stride; - const float* src_ptr4 = src_ptr3 + src_stride; - const float* src_ptr5 = src_ptr4 + src_stride; - const float* src_ptr6 = src_ptr5 + src_stride; - const float* src_ptr7 = src_ptr6 + src_stride; - std::int64_t src_inc0 = 8; - std::int64_t src_inc1 = 8; - std::int64_t src_inc2 = 8; - std::int64_t src_inc3 = 8; - std::int64_t src_inc4 = 8; - std::int64_t src_inc5 = 8; - std::int64_t src_inc6 = 8; - std::int64_t src_inc7 = 8; - if (remaining_src_cols < 8) { - if (remaining_src_cols <= 0) { - src_ptr0 = zerobuf; - src_inc0 = 0; - } - if (remaining_src_cols <= 1) { - src_ptr1 = zerobuf; - src_inc1 = 0; - } - if (remaining_src_cols <= 2) { - src_ptr2 = zerobuf; - src_inc2 = 0; - } - if (remaining_src_cols <= 3) { - src_ptr3 = zerobuf; - src_inc3 = 0; - } - if (remaining_src_cols <= 4) { - src_ptr4 = zerobuf; - src_inc4 = 0; - } - if (remaining_src_cols <= 5) { - src_ptr5 = zerobuf; - src_inc5 = 0; - } - if (remaining_src_cols <= 6) { - src_ptr6 = zerobuf; - src_inc6 = 0; - } - src_ptr7 = zerobuf; - src_inc7 = 0; - } - - for (int k = 0; k < src_rows; k += 16) { - for (int m = 0; m < 2; ++m) { - const int packed_rows = src_rows - k - 8 * m; - // Effectively, - // packed_rows = std::max(0, std::min(8, src_rows - k - 8 * m)); - // but treat each case separately. - if (packed_rows > 7) { - for (int i = 0; i < 8; ++i) { - in_data[0][i] = src_ptr0[i]; - in_data[1][i] = src_ptr1[i]; - in_data[2][i] = src_ptr2[i]; - in_data[3][i] = src_ptr3[i]; - in_data[4][i] = src_ptr4[i]; - in_data[5][i] = src_ptr5[i]; - in_data[6][i] = src_ptr6[i]; - in_data[7][i] = src_ptr7[i]; - } - for (int i = 0; i < 8; ++i) { - for (int j = 0; j < 8; ++j) { - packed_ptr[16 * i + j] = in_data[j][i]; - } - } - } else if (packed_rows > 0) { - for (int i = 0; i < packed_rows; ++i) { - in_data[0][i] = src_ptr0[i]; - in_data[1][i] = src_ptr1[i]; - in_data[2][i] = src_ptr2[i]; - in_data[3][i] = src_ptr3[i]; - in_data[4][i] = src_ptr4[i]; - in_data[5][i] = src_ptr5[i]; - in_data[6][i] = src_ptr6[i]; - in_data[7][i] = src_ptr7[i]; - } - for (int i = packed_rows; i < 8; ++i) { - in_data[0][i] = 0.0f; - in_data[1][i] = 0.0f; - in_data[2][i] = 0.0f; - in_data[3][i] = 0.0f; - in_data[4][i] = 0.0f; - in_data[5][i] = 0.0f; - in_data[6][i] = 0.0f; - in_data[7][i] = 0.0f; - } - // We loop through [0, 7) rather than [0, packed_rows), since that - // emulates what we might do in fully-optimized code. 
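        // Added note (an illustration, not original ruy code): at most 7 of
        // the 8 rows of a block can reach this point, because a fully
        // available 8-row block is written directly to packed_ptr in the
        // branch above. This is why the callers size the buffer as
        // float trailing_buf[7 * 16].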
- for (int i = 0; i < 7; ++i) { - for (int j = 0; j < 8; ++j) { - trailing_buf[16 * i + j] = in_data[j][i]; - } - } - } - - packed_ptr += 16 * 8; - src_ptr0 += src_inc0; - src_ptr1 += src_inc1; - src_ptr2 += src_inc2; - src_ptr3 += src_inc3; - src_ptr4 += src_inc4; - src_ptr5 += src_inc5; - src_ptr6 += src_inc6; - src_ptr7 += src_inc7; - } - } -} - -inline void ZeroHalfFloatAvxVnni(int src_rows, float* packed_ptr) { - const int non_trailing_rows = src_rows & ~7; - for (int k = 0; k < non_trailing_rows; ++k) { - for (int j = 0; j < 8; ++j) { - packed_ptr[j] = 0.0f; - } - packed_ptr += 16; - } -} - -} // namespace. - -// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. -// Optimization is not finished. In particular the dimensions of the kernel -// blocks can be changed as desired. -// -// When removing this comment, update profiling label below. -void Pack8bitAvxVnni(const std::int8_t* src_ptr, std::int8_t input_xor, - const std::int8_t* zerobuf, int src_stride, - int remaining_src_cols, int src_rows, - std::int8_t* packed_ptr, std::int32_t* sums_ptr) { - profiler::ScopeLabel label("Pack kAvxVnni 8bit (UNFINISHED)"); - - // Each packed block is 4*16, and there are normally 8. The trailing block is - // only slightly shorter. - std::int8_t trailing_buf[8 * 16 * 4]; - memset(trailing_buf, 0, 8 * 16 * 4 * sizeof(std::int8_t)); - - std::int32_t* second_sums_ptr = sums_ptr ? sums_ptr + 8 : nullptr; - if (remaining_src_cols > 8) { - HalfPack8bitAvxVnni(src_ptr, input_xor, zerobuf, src_stride, - remaining_src_cols, src_rows, packed_ptr, sums_ptr, - trailing_buf); - HalfPack8bitAvxVnni(src_ptr + src_stride * 8, input_xor, zerobuf, - src_stride, remaining_src_cols - 8, src_rows, - packed_ptr + 8 * 4, second_sums_ptr, - trailing_buf + 8 * 4); - } else { - HalfPack8bitAvxVnni(src_ptr, input_xor, zerobuf, src_stride, - remaining_src_cols, src_rows, packed_ptr, sums_ptr, - trailing_buf); - ZeroHalf8bitAvxVnni(src_rows, zerobuf[0] ^ input_xor, packed_ptr + 8 * 4); - // The kernel may not need the second half-blocks sums to be set. - if (second_sums_ptr) { - for (int i = 0; i < 8; ++i) { - second_sums_ptr[i] = (zerobuf[0] ^ input_xor) * ((src_rows + 3) & ~3); - } - } - } - const bool trailing_data = (src_rows & 31) > 0; - // If the number of source rows is not a multiple of 32, there will be data in - // the trailing buffer, - if (trailing_data > 0) { - const int non_trailing_rows = src_rows & ~31; - // Destination "rows" are padded to next highest multiple of 4. - const int dst_rows = (src_rows + 3) & ~3; - const int trailing_rows = dst_rows - non_trailing_rows; - memcpy(packed_ptr + 16 * non_trailing_rows, trailing_buf, - 16 * trailing_rows * sizeof(std::int8_t)); - } -} - -// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. -// Optimization is not finished. In particular the dimensions of the kernel -// blocks can be changed as desired. -// -// When removing this comment, update profiling label below. 
-void PackFloatAvxVnni(const float* src_ptr, const float* zerobuf, - int src_stride, int remaining_src_cols, int src_rows, - float* packed_ptr) { - profiler::ScopeLabel label("Pack kAvxVnni float (UNFINISHED)"); - float trailing_buf[7 * 16]; - if (remaining_src_cols > 8) { - HalfPackFloatAvxVnni(src_ptr, zerobuf, src_stride, remaining_src_cols, - src_rows, packed_ptr, trailing_buf); - HalfPackFloatAvxVnni(src_ptr + src_stride * 8, zerobuf, src_stride, - remaining_src_cols - 8, src_rows, packed_ptr + 8, - trailing_buf + 8); - } else { - memset(trailing_buf, 0, sizeof(trailing_buf)); - HalfPackFloatAvxVnni(src_ptr, zerobuf, src_stride, remaining_src_cols, - src_rows, packed_ptr, trailing_buf); - ZeroHalfFloatAvxVnni(src_rows, packed_ptr + 8); - } - const int trailing_rows = src_rows & 7; - if (trailing_rows > 0) { - const int non_trailing_rows = src_rows & ~7; - memcpy(packed_ptr + 16 * non_trailing_rows, trailing_buf, - 16 * trailing_rows * sizeof(float)); - } -} - -#endif // RUY_PLATFORM(AVX_VNNI) && RUY_OPT_ENABLED(RUY_OPT_INTRINSICS) - -} // namespace ruy diff --git a/tensorflow/lite/experimental/ruy/ruy/pack_common.h b/tensorflow/lite/experimental/ruy/ruy/pack_common.h deleted file mode 100644 index 91d47af8a5f..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/pack_common.h +++ /dev/null @@ -1,246 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// # What is "packing"? -// -// Before feeding data to the gemm kernels (the parts of Ruy that do lots -// of multiply-add operations), Ruy first performs a data transformation (which -// we call "packing") on the input matrices. This transformation has two main -// goals: -// - rearrange data into blocks that are a convenient size/layout for the gemm -// kernels to consume. This helps make the memory access pattern of the gemm -// kernel simpler and more contiguous, and puts the data in a layout most -// convenient for specific arithmetic instructions in the gemm kernel. -// - compute row/column sums needed for handling quantization with non-symmetric -// zero points. -// -// # Simplified algorithmic analysis of packing -// -// Packing is a relatively simple transformation which does a small constant -// amount of work on each element of an input matrix, and hence for an NxM -// matrix performs O(N*M) work. If N and M are of the same order, then this is -// O(N^2) work. -// -// A NxKxM matrix multiplication requires N*K*M multiply-accumulate operations. -// Note that if N, K, and M are all the same order, then the number of -// multiply-accumulate operations is O(N^3). -// -// Thus, the O(N^2) cost of packing is small compared to the O(N^3) work, in the -// case of all dimensions being roughly the same order. -// -// # Packing cost can be significant -// -// When matrix * matrix multiplications begin to look more like matrix * vector -// multiplications, packing cost can become significant. 
We sometimes call these
-// cases "gemv-like".
-//
-// Continuing the algorithmic analysis above, if we consider a case where an
-// NxKxM matrix multiplication has either N = O(1) or M = O(1), then the
-// situation is different. In this case, the multiply-accumulate work is only
-// quadratic, so the quadratic cost of packing can become significant.
-//
-// Another way to say this is that the cost of packing an input matrix (either
-// the LHS or RHS) is amortized across the non-depth dimension of the opposite
-// input matrix. Thus, when the LHS has very few rows or the RHS has very few
-// columns, the cost of packing the opposite input matrix can become
-// significant.
-//
-// As a rough rule of thumb, the cost of packing starts to become significant
-// when either N or M is below 32 (and other dimensions are hundreds), with very
-// significant packing costs at 8 or below. This varies by data type, Path, and
-// tuning, so these numbers are only rough guides.
-//
-// One practical use case that is affected by this is inference of
-// fully connected neural network layers with a low batch size. The weight
-// matrix (which is a constant for inference) is the one affected by significant
-// packing cost.
-//
-// Ruy provides an API in ruy_advanced.h for advanced users to pre-pack
-// input matrices that are affected by significant packing costs.
-//
-// # Implementation notes
-//
-// Ruy's packing routines always operate on a range of columns and can be
-// applied to either the LHS or RHS. This is possible because Ruy internally
-// implements a TrMul, so the accumulation along depth is done along columns of
-// both the LHS and RHS (whereas for a normal Mul the accumulation along depth
-// for the LHS is along rows). As another example, we are always computing
-// column sums for quantization (and never row sums, since the LHS is
-// transposed).
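As a rough standalone illustration of the "gemv-like" analysis above (a sketch
added for this commentary, not ruy code), the ratio of packing work
(~N*K + K*M elements touched once each) to multiply-accumulate work (N*K*M)
can be modeled in a few lines; the shapes are arbitrary, only the trend of the
ratio matters:

    #include <cstdio>

    int main() {
      // Shapes (N, K, M); the later ones are "gemv-like" in the sense above.
      const int shapes[][3] = {
          {512, 512, 512}, {512, 512, 32}, {512, 512, 8}, {512, 512, 1}};
      for (const auto& s : shapes) {
        const double n = s[0], k = s[1], m = s[2];
        const double packing_work = n * k + k * m;  // Each element packed once.
        const double mac_work = n * k * m;          // Multiply-accumulates.
        std::printf("N=%4d K=%4d M=%4d  packing/MAC work ratio = %.4f\n", s[0],
                    s[1], s[2], packing_work / mac_work);
      }
      return 0;
    }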
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PACK_COMMON_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PACK_COMMON_H_ - -#include - -#include "tensorflow/lite/experimental/ruy/ruy/check_macros.h" -#include "tensorflow/lite/experimental/ruy/ruy/common.h" -#include "tensorflow/lite/experimental/ruy/ruy/internal_matrix.h" -#include "tensorflow/lite/experimental/ruy/ruy/matrix.h" -#include "tensorflow/lite/experimental/ruy/ruy/opt_set.h" -#include "tensorflow/lite/experimental/ruy/ruy/path.h" -#include "tensorflow/lite/experimental/ruy/ruy/platform.h" -#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" -#include "tensorflow/lite/experimental/ruy/ruy/tune.h" - -namespace ruy { - -template -struct PackedTypeImpl { - using Type = Scalar; -}; - -#if RUY_PLATFORM(NEON_32) -struct PackParams8bit { - const void* src_ptr0; - const void* src_ptr1; - const void* src_ptr2; - const void* src_ptr3; - const std::int32_t* sums_ptr; - const std::int8_t* packed_ptr; - int src_inc0; - int src_inc1; - int src_inc2; - int src_inc3; - int src_rows; - int src_zero_point; - int input_xor; -}; - -inline void MakePackParams8bit(const void* src_ptr0, const void* src_ptr1, - const void* src_ptr2, const void* src_ptr3, - const std::int32_t* sums_ptr, - const std::int8_t* packed_ptr, int src_inc0, - int src_inc1, int src_inc2, int src_inc3, - int src_rows, int src_zero_point, int input_xor, - PackParams8bit* params) { - params->src_ptr0 = src_ptr0; - params->src_ptr1 = src_ptr1; - params->src_ptr2 = src_ptr2; - params->src_ptr3 = src_ptr3; - params->sums_ptr = sums_ptr; - params->packed_ptr = packed_ptr; - params->src_inc0 = src_inc0; - params->src_inc1 = src_inc1; - params->src_inc2 = src_inc2; - params->src_inc3 = src_inc3; - params->src_rows = src_rows; - params->src_zero_point = src_zero_point; - params->input_xor = input_xor; -} -#endif - -#if RUY_PLATFORM(NEON) -template <> -struct PackedTypeImpl { - using Type = std::int8_t; -}; -template <> -struct PackedTypeImpl { - using Type = std::int8_t; -}; -#elif RUY_PLATFORM(X86) -template <> -struct PackedTypeImpl { - using Type = std::int8_t; -}; -template <> -struct PackedTypeImpl { - using Type = std::int8_t; -}; -template <> -struct PackedTypeImpl { - using Type = std::int8_t; -}; -template <> -struct PackedTypeImpl { - using Type = std::int8_t; -}; -#endif - -template -using PackedType = typename PackedTypeImpl::Type; - -template -PackedScalar Pack(Scalar x) { - return x - SymmetricZeroPoint() + SymmetricZeroPoint(); -} - -template -struct PackImpl {}; - -#define RUY_INHERIT_PACK(PARENT, CHILD) \ - template \ - struct PackImpl \ - : PackImpl { \ - }; - -template -struct PackImpl { - static void Run(Tuning, const Matrix& src_matrix, - PackedMatrix* packed_matrix, int start_col, - int end_col) { - profiler::ScopeLabel label("Pack (generic)"); - RUY_DCHECK_EQ((end_col - start_col) % FixedKernelLayout::kCols, 0); - SumsType* sums = packed_matrix->sums; - for (int col = start_col; col < end_col; col++) { - SumsType accum = 0; - for (int row = 0; row < packed_matrix->layout.rows; row++) { - PackedScalar packed_val; - if (col < src_matrix.layout.cols && row < src_matrix.layout.rows) { - packed_val = Pack(Element(src_matrix, row, col)); - } else { - packed_val = packed_matrix->zero_point; - } - accum += packed_val; - *ElementPtr(packed_matrix, row, col) = packed_val; - } - if (sums) { - sums[col] = accum; - } - } - } -}; - -#if RUY_PLATFORM(NEON) -RUY_INHERIT_PACK(Path::kStandardCpp, Path::kNeon) -RUY_INHERIT_PACK(Path::kNeon, 
Path::kNeonDotprod) -#elif RUY_PLATFORM(X86) -RUY_INHERIT_PACK(Path::kStandardCpp, Path::kSse42) -RUY_INHERIT_PACK(Path::kSse42, Path::kAvx2) -RUY_INHERIT_PACK(Path::kAvx2, Path::kAvx512) -RUY_INHERIT_PACK(Path::kAvx512, Path::kAvxVnni) -#endif - -// Main entry point for packing. -template -void RunPack(Tuning tuning, const DMatrix& src_matrix, PMatrix* packed_matrix, - int start_col, int end_col) { - using SumsType = typename PackedMatrix::SumsType; - Matrix src = ToMatrix(src_matrix); - PackedMatrix packed = - ToPackedMatrix(*packed_matrix); - PackImpl::Run( - tuning, src, &packed, start_col, end_col); -} - -} // namespace ruy - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PACK_COMMON_H_ diff --git a/tensorflow/lite/experimental/ruy/ruy/pack_sse42.cc b/tensorflow/lite/experimental/ruy/ruy/pack_sse42.cc deleted file mode 100644 index ecd1cf83c6d..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/pack_sse42.cc +++ /dev/null @@ -1,471 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include - -#include "tensorflow/lite/experimental/ruy/ruy/check_macros.h" -#include "tensorflow/lite/experimental/ruy/ruy/matrix.h" -#include "tensorflow/lite/experimental/ruy/ruy/opt_set.h" -#include "tensorflow/lite/experimental/ruy/ruy/pack.h" -#include "tensorflow/lite/experimental/ruy/ruy/path.h" -#include "tensorflow/lite/experimental/ruy/ruy/platform.h" -#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" - -#if RUY_PLATFORM(SSE42) && RUY_OPT_ENABLED(RUY_OPT_INTRINSICS) -#include // IWYU pragma: keep -#endif - -namespace ruy { - -#if !(RUY_PLATFORM(SSE42) && RUY_OPT_ENABLED(RUY_OPT_ASM)) - -void Pack8bitSse42(const std::int8_t* src_ptr, std::int8_t input_xor, - const std::int8_t* zerobuf, int src_stride, - int remaining_src_cols, int src_rows, - std::int8_t* packed_ptr, std::int32_t* sums_ptr) { - // CPU-ID-based checks should disable the path that would reach this point. - RUY_DCHECK(false); -} - -void PackFloatSse42(const float* src_ptr, const float* zerobuf, int src_stride, - int remaining_src_cols, int src_rows, float* packed_ptr) { - // CPU-ID-based checks should disable the path that would reach this point. - RUY_DCHECK(false); -} - -#else // RUY_PLATFORM(SSE42) && RUY_OPT_ENABLED(RUY_OPT_ASM) - -// The first int8_t template parameter is arbitrary: this routine is common to -// all 8-bit source matrix types. 
-using PackImpl8bitSse42 = - PackImpl, - std::int8_t, std::int8_t, std::int32_t>; - -using PackImplFloatSse42 = - PackImpl, float, - float, float>; - -namespace { - -inline void Pack8bitSse42Packer(const std::int8_t* src_ptr, - std::int8_t input_xor, - const std::int8_t* zerobuf, int src_stride, - int remaining_src_cols, int src_rows, - std::int8_t* packed_ptr, std::int32_t* sums_ptr, - std::int8_t* trailing_buf) { - using Layout = PackImpl8bitSse42::Layout; - RUY_DCHECK_EQ(Layout::kCols, 8); - RUY_DCHECK_EQ(Layout::kRows, 4); - // Each Layout::Rows is 4 contiguous input, contiguous packed elements. - // We process 8 of these chunks at a time, padding short input chunks. - constexpr int kNumRowChunks = 8; - constexpr int kNumChunkedSrcRows = kNumRowChunks * Layout::kRows; - - std::int8_t in_data[Layout::kCols][kNumRowChunks][Layout::kRows]; - - const std::int8_t* src_ptr0 = src_ptr; - const std::int8_t* src_ptr1 = src_ptr0 + src_stride; - const std::int8_t* src_ptr2 = src_ptr1 + src_stride; - const std::int8_t* src_ptr3 = src_ptr2 + src_stride; - const std::int8_t* src_ptr4 = src_ptr3 + src_stride; - const std::int8_t* src_ptr5 = src_ptr4 + src_stride; - const std::int8_t* src_ptr6 = src_ptr5 + src_stride; - const std::int8_t* src_ptr7 = src_ptr6 + src_stride; - std::int64_t src_inc0 = kNumChunkedSrcRows; - std::int64_t src_inc1 = kNumChunkedSrcRows; - std::int64_t src_inc2 = kNumChunkedSrcRows; - std::int64_t src_inc3 = kNumChunkedSrcRows; - std::int64_t src_inc4 = kNumChunkedSrcRows; - std::int64_t src_inc5 = kNumChunkedSrcRows; - std::int64_t src_inc6 = kNumChunkedSrcRows; - std::int64_t src_inc7 = kNumChunkedSrcRows; - // Handle cases where source does not have Layout::kCols (8) columns. - if (remaining_src_cols < 8) { - if (remaining_src_cols <= 0) { - src_ptr0 = zerobuf; - src_inc0 = 0; - } - if (remaining_src_cols <= 1) { - src_ptr1 = zerobuf; - src_inc1 = 0; - } - if (remaining_src_cols <= 2) { - src_ptr2 = zerobuf; - src_inc2 = 0; - } - if (remaining_src_cols <= 3) { - src_ptr3 = zerobuf; - src_inc3 = 0; - } - if (remaining_src_cols <= 4) { - src_ptr4 = zerobuf; - src_inc4 = 0; - } - if (remaining_src_cols <= 5) { - src_ptr5 = zerobuf; - src_inc5 = 0; - } - if (remaining_src_cols <= 6) { - src_ptr6 = zerobuf; - src_inc6 = 0; - } - src_ptr7 = zerobuf; - src_inc7 = 0; - } - - const std::int8_t zero_point = zerobuf[0]; - - if (sums_ptr) { - // i: Layout::kCols. - for (int i = 0; i < 8; ++i) { - sums_ptr[i] = 0; - } - } - - // The overall packing effectively pads the source rows to - // (src_rows + 63) & ~63. The iteration over k may skip when m=1, and then we - // only pack for (src_rows + 31) & ~31. When there is an incomplete - // destination block, this is stored into trailing_buf instead of packed_ptr. - for (int k = 0; k < src_rows; k += kNumChunkedSrcRows) { - // Available source rows. - // If this is less than 0 (for m=1), we skip, having filled trailing - // buffer for m=0. Also, if source rows is zero on m=1, then we filled - // exactly to the end of the column in the packed buffer. - const int available_src_rows = src_rows - k; - // Effectively, - // available rows = std::max(0, std::min(8, src_rows - k)); - // treat each case separately. - if (available_src_rows >= kNumChunkedSrcRows) { - // i: chunks, s: Layout::Rows. 
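      // Added note (not original ruy code): in_data is indexed as
      // in_data[column][row_chunk][row_within_chunk], so in_data[j][i][s]
      // holds the source element at row (i * 4 + s) of this k step, column j
      // of the current 8-column strip. After XOR-ing with input_xor it lands
      // at packed_ptr[(8 * i + j) * 4 + s]: each 4-row chunk i occupies
      // Layout::kCols * Layout::kRows == 32 contiguous bytes, holding the 4
      // rows of column 0, then the 4 rows of column 1, and so on.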
- for (int i = 0; i < 8; ++i) { - for (int s = 0; s < 4; ++s) { - in_data[0][i][s] = src_ptr0[i * 4 + s]; - in_data[1][i][s] = src_ptr1[i * 4 + s]; - in_data[2][i][s] = src_ptr2[i * 4 + s]; - in_data[3][i][s] = src_ptr3[i * 4 + s]; - in_data[4][i][s] = src_ptr4[i * 4 + s]; - in_data[5][i][s] = src_ptr5[i * 4 + s]; - in_data[6][i][s] = src_ptr6[i * 4 + s]; - in_data[7][i][s] = src_ptr7[i * 4 + s]; - } - } - // i: chunks, j: Layout::kCols, s: Layout::Rows. - for (int i = 0; i < 8; ++i) { - for (int j = 0; j < 8; ++j) { - for (int s = 0; s < 4; ++s) { - // 8 * 4 * i is offset for each block, that is - // (Layout::kCols * Layout::kRows * i) - packed_ptr[(8 * i + j) * 4 + s] = in_data[j][i][s] ^ input_xor; - } - if (sums_ptr) { - for (int s = 0; s < 4; ++s) { - sums_ptr[j] += in_data[j][i][s] ^ input_xor; - } - } - } - } - } else if (available_src_rows > 0) { - RUY_DCHECK_LT(available_src_rows, kNumChunkedSrcRows); - int i = 0; - // Consume chunks of 4 rows that are complete. - for (; i < (available_src_rows >> 2); ++i) { - for (int s = 0; s < 4; ++s) { - in_data[0][i][s] = src_ptr0[i * 4 + s]; - in_data[1][i][s] = src_ptr1[i * 4 + s]; - in_data[2][i][s] = src_ptr2[i * 4 + s]; - in_data[3][i][s] = src_ptr3[i * 4 + s]; - in_data[4][i][s] = src_ptr4[i * 4 + s]; - in_data[5][i][s] = src_ptr5[i * 4 + s]; - in_data[6][i][s] = src_ptr6[i * 4 + s]; - in_data[7][i][s] = src_ptr7[i * 4 + s]; - } - } - // Consume any incomplete chunk. - if (i < ((available_src_rows + 3) >> 2)) { - int s = 0; - for (; s < (available_src_rows & 3); ++s) { - in_data[0][i][s] = src_ptr0[i * 4 + s]; - in_data[1][i][s] = src_ptr1[i * 4 + s]; - in_data[2][i][s] = src_ptr2[i * 4 + s]; - in_data[3][i][s] = src_ptr3[i * 4 + s]; - in_data[4][i][s] = src_ptr4[i * 4 + s]; - in_data[5][i][s] = src_ptr5[i * 4 + s]; - in_data[6][i][s] = src_ptr6[i * 4 + s]; - in_data[7][i][s] = src_ptr7[i * 4 + s]; - } - RUY_DCHECK_LE(s, 4); - for (; s < 4; ++s) { - // j: Layout::kCols. - for (int j = 0; j < 8; ++j) { - in_data[j][i][s] = zero_point; - } - } - ++i; - } - // We do not care what goes into the trailing buffer, but we want - // in_data[...] ^ input_xor == 0 for irrelevant values in the summation. - // - // It might prove better in optimized code to pad uniformly with - // zero_point, and compensate by initializing the summations with the - // compensating offset, effectively - // ((input_xor - zero_point) ^ input_xor) * - // 4 * (8 - ((available_src_rows + 3) >> 2)). - for (; i < 8; ++i) { - for (int s = 0; s < 4; ++s) { - for (int j = 0; j < 8; ++j) { - in_data[j][i][s] = input_xor; - } - } - } - // We loop through [0, 8) rather than - // [0, (available_src_rows + 3) >> 2), since that emulates what we might - // do in fully-optimized code. - // - // i: chunks, j: Layout::kCols, s: Layout::Rows. 
- if (sums_ptr) { - for (int i = 0; i < 8; ++i) { - for (int j = 0; j < 8; ++j) { - for (int s = 0; s < 4; ++s) { - trailing_buf[(8 * i + j) * 4 + s] = in_data[j][i][s] ^ input_xor; - sums_ptr[j] = sums_ptr[j] + (in_data[j][i][s] ^ input_xor); - } - } - } - } else { - for (int i = 0; i < 8; ++i) { - for (int j = 0; j < 8; ++j) { - for (int s = 0; s < 4; ++s) { - trailing_buf[(8 * i + j) * 4 + s] = in_data[j][i][s] ^ input_xor; - } - } - } - } - } - - packed_ptr += 8 * kNumChunkedSrcRows; - src_ptr0 += src_inc0; - src_ptr1 += src_inc1; - src_ptr2 += src_inc2; - src_ptr3 += src_inc3; - src_ptr4 += src_inc4; - src_ptr5 += src_inc5; - src_ptr6 += src_inc6; - src_ptr7 += src_inc7; - } -} - -inline void PackFloatSse42Packer(const float* src_ptr, const float* zerobuf, - int src_stride, int remaining_src_cols, - int src_rows, float* packed_ptr, - float* trailing_buf) { - using Layout = PackImplFloatSse42::Layout; - RUY_DCHECK_EQ(Layout::kCols, 8); - RUY_DCHECK_EQ(Layout::kRows, 1); - - // This packing amounts to tranposition of 8x8 blocks. - static constexpr int kPackCols = 8; // Source cols packed together. - static constexpr int kPackRows = 8; // Short input is padded. - - float in_data[kPackCols][kPackRows]; - - const float* src_ptr0 = src_ptr; - const float* src_ptr1 = src_ptr0 + src_stride; - const float* src_ptr2 = src_ptr1 + src_stride; - const float* src_ptr3 = src_ptr2 + src_stride; - const float* src_ptr4 = src_ptr3 + src_stride; - const float* src_ptr5 = src_ptr4 + src_stride; - const float* src_ptr6 = src_ptr5 + src_stride; - const float* src_ptr7 = src_ptr6 + src_stride; - std::int64_t src_inc0 = 8; - std::int64_t src_inc1 = 8; - std::int64_t src_inc2 = 8; - std::int64_t src_inc3 = 8; - std::int64_t src_inc4 = 8; - std::int64_t src_inc5 = 8; - std::int64_t src_inc6 = 8; - std::int64_t src_inc7 = 8; - // Handle cases where source does not have kPackDim (8) columns. - if (remaining_src_cols < kPackCols) { - if (remaining_src_cols <= 0) { - src_ptr0 = zerobuf; - src_inc0 = 0; - } - if (remaining_src_cols <= 1) { - src_ptr1 = zerobuf; - src_inc1 = 0; - } - if (remaining_src_cols <= 2) { - src_ptr2 = zerobuf; - src_inc2 = 0; - } - if (remaining_src_cols <= 3) { - src_ptr3 = zerobuf; - src_inc3 = 0; - } - if (remaining_src_cols <= 4) { - src_ptr4 = zerobuf; - src_inc4 = 0; - } - if (remaining_src_cols <= 5) { - src_ptr5 = zerobuf; - src_inc5 = 0; - } - if (remaining_src_cols <= 6) { - src_ptr6 = zerobuf; - src_inc6 = 0; - } - src_ptr7 = zerobuf; - src_inc7 = 0; - } - - for (int k = 0; k < src_rows; k += kPackRows) { - const int available_src_rows = src_rows - k; - // Effectively, - // available_src_rows = std::max(0, std::min(kPackDim, src_rows - k)); - // but treat each case separately. 
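      // Added note (not original ruy code): for the float path, packing a
      // full block is simply an 8x8 transpose. in_data[j][i] holds the source
      // element at row i, column j, and is stored at packed_ptr[8 * i + j],
      // so within a packed block consecutive floats are the 8 columns of one
      // source row.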
- if (available_src_rows >= kPackRows) { - for (int i = 0; i < 8; ++i) { - in_data[0][i] = src_ptr0[i]; - in_data[1][i] = src_ptr1[i]; - in_data[2][i] = src_ptr2[i]; - in_data[3][i] = src_ptr3[i]; - in_data[4][i] = src_ptr4[i]; - in_data[5][i] = src_ptr5[i]; - in_data[6][i] = src_ptr6[i]; - in_data[7][i] = src_ptr7[i]; - } - for (int i = 0; i < 8; ++i) { - for (int j = 0; j < 8; ++j) { - packed_ptr[8 * i + j] = in_data[j][i]; - } - } - } else if (available_src_rows > 0) { - for (int i = 0; i < available_src_rows; ++i) { - in_data[0][i] = src_ptr0[i]; - in_data[1][i] = src_ptr1[i]; - in_data[2][i] = src_ptr2[i]; - in_data[3][i] = src_ptr3[i]; - in_data[4][i] = src_ptr4[i]; - in_data[5][i] = src_ptr5[i]; - in_data[6][i] = src_ptr6[i]; - in_data[7][i] = src_ptr7[i]; - } - for (int i = available_src_rows; i < kPackRows; ++i) { - in_data[0][i] = 0.0f; - in_data[1][i] = 0.0f; - in_data[2][i] = 0.0f; - in_data[3][i] = 0.0f; - in_data[4][i] = 0.0f; - in_data[5][i] = 0.0f; - in_data[6][i] = 0.0f; - in_data[7][i] = 0.0f; - } - // We loop through [0, 7) rather than [0, packed_rows), since that - // emulates what we might do in fully-optimized code. - // i: (kPackRows - 1), j: kPackCols. - for (int i = 0; i < 7; ++i) { - for (int j = 0; j < 8; ++j) { - trailing_buf[kPackRows * i + j] = in_data[j][i]; - } - } - } - - packed_ptr += kPackRows * kPackCols; - src_ptr0 += src_inc0; - src_ptr1 += src_inc1; - src_ptr2 += src_inc2; - src_ptr3 += src_inc3; - src_ptr4 += src_inc4; - src_ptr5 += src_inc5; - src_ptr6 += src_inc6; - src_ptr7 += src_inc7; - } -} - -} // namespace. - -// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. -// Optimization is not finished. In particular the dimensions of the kernel -// blocks can be changed as desired. -// -// When removing this comment, update profiling label below. -void Pack8bitSse42(const std::int8_t* src_ptr, std::int8_t input_xor, - const std::int8_t* zerobuf, int src_stride, - int remaining_src_cols, int src_rows, - std::int8_t* packed_ptr, std::int32_t* sums_ptr) { - profiler::ScopeLabel label("Pack kSse42 8bit (UNFINISHED)"); - - using Layout = PackImpl8bitSse42::Layout; - RUY_DCHECK_EQ(Layout::kCols, 8); - RUY_DCHECK_EQ(Layout::kRows, 4); - - // Each Layout::Rows is 4 contiguous input, contiguous packed elements. - // We process 8 of these chunks at a time, padding short input chunks. - static constexpr int kNumRowChunks = 8; // Short input is padded. - - // Each packed block is 4*8, and there are normally 8. The trailing block is - // only slightly shorter. - constexpr int kTrailingBufSize = - kNumRowChunks * Layout::kCols * Layout::kRows; - std::int8_t trailing_buf[kTrailingBufSize]; - memset(trailing_buf, 0, kTrailingBufSize * sizeof(std::int8_t)); - - Pack8bitSse42Packer(src_ptr, input_xor, zerobuf, src_stride, - remaining_src_cols, src_rows, packed_ptr, sums_ptr, - trailing_buf); - - constexpr int kChunkedRowMask = kNumRowChunks * Layout::kRows - 1; - const bool trailing_data = (src_rows & kChunkedRowMask) > 0; - // If the number of source rows is not a multiple of kChunkedRowMask, there - // will be data in the trailing buffer, - if (trailing_data > 0) { - const int non_trailing_rows = src_rows & ~kChunkedRowMask; - // Destination "rows" are padded to next highest multiple of Layout::kRows. 
- const int dst_rows = (src_rows + 3) & ~3; - const int trailing_rows = dst_rows - non_trailing_rows; - memcpy(packed_ptr + Layout::kCols * non_trailing_rows, trailing_buf, - Layout::kCols * trailing_rows * sizeof(std::int8_t)); - } -} - -// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. -// Optimization is not finished. In particular the dimensions of the kernel -// blocks can be changed as desired. -// -// When removing this comment, update profiling label below. -void PackFloatSse42(const float* src_ptr, const float* zerobuf, int src_stride, - int remaining_src_cols, int src_rows, float* packed_ptr) { - profiler::ScopeLabel label("Pack kSse42 float (UNFINISHED)"); - static constexpr int kPackCols = 8; // Source cols packed together. - static constexpr int kPackRows = 8; // Short input is padded. - float trailing_buf[(kPackRows - 1) * kPackCols]; - if (remaining_src_cols < 8) { - memset(trailing_buf, 0, sizeof(trailing_buf)); - } - PackFloatSse42Packer(src_ptr, zerobuf, src_stride, remaining_src_cols, - src_rows, packed_ptr, trailing_buf); - - const int trailing_rows = src_rows & (kPackRows - 1); - if (trailing_rows > 0) { - const int non_trailing_rows = src_rows & ~(kPackRows - 1); - memcpy(packed_ptr + kPackCols * non_trailing_rows, trailing_buf, - kPackCols * trailing_rows * sizeof(float)); - } -} - -#endif // RUY_PLATFORM(SSE42) && RUY_OPT_ENABLED(RUY_OPT_INTRINSICS) - -} // namespace ruy diff --git a/tensorflow/lite/experimental/ruy/ruy/pack_x86.h b/tensorflow/lite/experimental/ruy/ruy/pack_x86.h deleted file mode 100644 index 8bdc88e5763..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/pack_x86.h +++ /dev/null @@ -1,461 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// # What is "packing"? -// -// Before feeding data to the gemm kernels (the parts of Ruy that do lots -// of multiply-add operations), Ruy first performs a data transformation (which -// we call "packing") on the input matrices. This transformation has two main -// goals: -// - rearrange data into blocks that are a convenient size/layout for the gemm -// kernels to consume. This helps make the memory access pattern of the gemm -// kernel simpler and more contiguous, and puts the data in a layout most -// convenient for specific arithmetic instructions in the gemm kernel. -// - compute row/column sums needed for handling quantization with non-symmetric -// zero points. -// -// # Simplified algorithmic analysis of packing -// -// Packing is a relatively simple transformation which does a small constant -// amount of work on each element of an input matrix, and hence for an NxM -// matrix performs O(N*M) work. If N and M are of the same order, then this is -// O(N^2) work. -// -// A NxKxM matrix multiplication requires N*K*M multiply-accumulate operations. 
-// Note that if N, K, and M are all the same order, then the number of -// multiply-accumulate operations is O(N^3). -// -// Thus, the O(N^2) cost of packing is small compared to the O(N^3) work, in the -// case of all dimensions being roughly the same order. -// -// # Packing cost can be significant -// -// When matrix * matrix multiplications begin to look more like matrix * vector -// multiplications, packing cost can become significant. We sometimes call these -// cases "gemv-like". -// -// Continuing the algorithmic analysis above, if we consider a case where an -// NxKxM matrix multiplication has either N = O(1) or M = O(1), then the -// situation is different. In this case, the multiply-accumulate work is only -// quadratic, so the quadratic cost of packing can be come significant. -// -// Another way to say this is that the cost of packing an input matrix (either -// the LHS or RHS) is amortized across the non-depth dimension of the opposite -// input matrix. Thus, when the LHS has very few rows or the RHS has very few -// columns, the cost of packing the opposite input matrix can become -// significant. -// -// As a rough rule of thumb, the cost of packing starts to become significant -// when either N or M is below 32 (and other dimensions are hundreds), with very -// significant packing costs at 8 or below. This varies by data type, Path, and -// tuning, so these numbers are only rough guides. -// -// One practical use case that is affected by this is inference of -// fully connected neural network layers with a low batch size. The weight -// matrix (which is a constant for inference) is the one affected by significant -// packing cost. -// -// Ruy provides an API in ruy_advanced.h for advanced users to pre-pack -// input matrices that are affected by significant packing costs. -// -// # Implementation notes -// -// Ruy's packing routines always operate on a range of columns and can be -// applied to either the LHS or RHS. This is possible because Ruy internally -// implements a TrMul, so the accumulation along depth is done along columns of -// both the LHS and RHS (whereas for a normal Mul the accumulation along depth -// for the LHS is along rows). As another example, we are always computing -// column sums for quantization (and never row sums, since the LHS is -// transposed). - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PACK_X86_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PACK_X86_H_ - -#include -#include -#include - -#include "tensorflow/lite/experimental/ruy/ruy/check_macros.h" -#include "tensorflow/lite/experimental/ruy/ruy/common.h" -#include "tensorflow/lite/experimental/ruy/ruy/internal_matrix.h" -#include "tensorflow/lite/experimental/ruy/ruy/matrix.h" -#include "tensorflow/lite/experimental/ruy/ruy/opt_set.h" -#include "tensorflow/lite/experimental/ruy/ruy/pack_common.h" -#include "tensorflow/lite/experimental/ruy/ruy/path.h" -#include "tensorflow/lite/experimental/ruy/ruy/platform.h" -#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" -#include "tensorflow/lite/experimental/ruy/ruy/tune.h" - -namespace ruy { - -#if RUY_PLATFORM(X86) -// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. -// Optimization is not finished. In particular the dimensions of the kernel -// blocks can be changed as desired. -// -// Note that source and zero buffers can be uint8 type, but in the packing -// function are reinterpreted as int8, and are XOR-ed with input_xor. 
-void Pack8bitSse42(const std::int8_t* src_ptr, std::int8_t input_xor, - const std::int8_t* zerobuf, int src_stride, - int remaining_src_cols, int src_rows, - std::int8_t* packed_ptr, std::int32_t* sums_ptr); - -template -struct PackImpl, Scalar, - std::int8_t, std::int32_t> { - static_assert(std::is_same::value || - std::is_same::value, - ""); - using Layout = FixedKernelLayout; - static constexpr std::int8_t kInputXor = - std::is_same::value ? 0 : 0x80; - - static void Run(Tuning tuning, const Matrix& src_matrix, - PackedMatrix* packed_matrix, int start_col, - int end_col) { - profiler::ScopeLabel label("Pack (SSE 4.2 8-bit)"); - - RUY_DCHECK(IsColMajor(src_matrix.layout)); - RUY_DCHECK(IsColMajor(packed_matrix->layout)); - RUY_DCHECK_EQ((end_col - start_col) % Layout::kCols, 0); - RUY_DCHECK_EQ(start_col % Layout::kCols, 0); - std::int32_t* sums = packed_matrix->sums; - Scalar zerobuf[Layout::kCols * Layout::kRows]; - memset(zerobuf, packed_matrix->zero_point ^ kInputXor, - Layout::kCols * Layout::kRows * sizeof(Scalar)); - for (int block_col = start_col; block_col < end_col; - block_col += Layout::kCols) { - std::int32_t* sums_ptr = sums ? sums + block_col : nullptr; - int src_stride = src_matrix.layout.stride; - const Scalar* src_ptr = src_matrix.data.get() + src_stride * block_col; - int remaining_src_cols = src_matrix.layout.cols - block_col; - - static constexpr int block_col_mask = ~(Layout::kCols - 1); // High bits. - std::int8_t* packed_ptr = - packed_matrix->data + - packed_matrix->layout.stride * (block_col & block_col_mask); - Pack8bitSse42(reinterpret_cast(src_ptr), kInputXor, - reinterpret_cast(zerobuf), src_stride, - remaining_src_cols, src_matrix.layout.rows, packed_ptr, - sums_ptr); - } - } -}; - -// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. -// Optimization is not finished. In particular the dimensions of the kernel -// blocks can be changed as desired. -// -void PackFloatSse42(const float* src_ptr, const float* zerobuf, int src_stride, - int remaining_src_cols, int src_rows, float* packed_ptr); - -template <> -struct PackImpl, float, - float, float> { - using Layout = FixedKernelLayout; - static void Run(Tuning, const Matrix& src_matrix, - PackedMatrix* packed_matrix, int start_col, - int end_col) { - profiler::ScopeLabel label("Pack (SSE 4.2 float)"); - - RUY_DCHECK(IsColMajor(src_matrix.layout)); - RUY_DCHECK(IsColMajor(packed_matrix->layout)); - RUY_DCHECK_EQ((end_col - start_col) % Layout::kCols, 0); - RUY_DCHECK_EQ(start_col % Layout::kCols, 0); - const float zerobuf[Layout::kCols] = { - 0.0f}; // Remainder default inits to 0.0f. - for (int block_col = start_col; block_col < end_col; - block_col += Layout::kCols) { - int src_stride = src_matrix.layout.stride; - const float* src_ptr = src_matrix.data.get() + src_stride * block_col; - int remaining_src_cols = src_matrix.layout.cols - block_col; - - static constexpr int block_col_mask = ~(Layout::kCols - 1); // High bits. - float* packed_ptr = - packed_matrix->data + - packed_matrix->layout.stride * (block_col & block_col_mask); - PackFloatSse42(src_ptr, zerobuf, src_stride, remaining_src_cols, - src_matrix.layout.rows, packed_ptr); - } - } -}; - -// Note that source and zero buffers can be uint8 type, but in the packing -// function are reinterpreted as int8, and are XOR-ed with input_xor. 
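The kInputXor == 0x80 trick used by the PackImpl specializations above works
because XOR-ing a uint8 value with 0x80 and reinterpreting the result as int8
subtracts 128, mapping [0, 255] onto [-128, 127]; for int8 sources kInputXor
is 0 and the data passes through unchanged. A minimal standalone check (a
sketch added for this commentary, assuming the usual two's-complement
narrowing, not ruy code):

    #include <cstdint>
    #include <cstdio>

    int main() {
      for (int v : {0, 1, 127, 128, 200, 255}) {
        const std::uint8_t u = static_cast<std::uint8_t>(v);
        // Same transformation the packing code applies to uint8 sources.
        const std::int8_t packed = static_cast<std::int8_t>(u ^ 0x80);
        std::printf("uint8 %3d -> packed int8 %4d (v - 128 = %4d)\n", v,
                    static_cast<int>(packed), v - 128);
      }
      return 0;
    }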
-void Pack8bitAvx2(const std::int8_t* src_ptr, std::int8_t input_xor, - const std::int8_t* zerobuf, int src_stride, - int remaining_src_cols, int src_rows, std::int8_t* packed_ptr, - std::int32_t* sums_ptr); - -template -struct PackImpl, Scalar, - std::int8_t, std::int32_t> { - static_assert(std::is_same::value || - std::is_same::value, - ""); - using Layout = FixedKernelLayout; - static constexpr std::int8_t kInputXor = - std::is_same::value ? 0 : 0x80; - - static void Run(Tuning tuning, const Matrix& src_matrix, - PackedMatrix* packed_matrix, int start_col, - int end_col) { - profiler::ScopeLabel label("Pack (AVX2 8-bit)"); - - RUY_DCHECK(IsColMajor(src_matrix.layout)); - RUY_DCHECK(IsColMajor(packed_matrix->layout)); - RUY_DCHECK_EQ((end_col - start_col) % Layout::kCols, 0); - RUY_DCHECK_EQ(start_col % Layout::kCols, 0); - std::int32_t* sums = packed_matrix->sums; - Scalar zerobuf[Layout::kCols * Layout::kRows]; - memset(zerobuf, packed_matrix->zero_point ^ kInputXor, - Layout::kCols * Layout::kRows * sizeof(Scalar)); - for (int block_col = start_col; block_col < end_col; - block_col += Layout::kCols) { - std::int32_t* sums_ptr = sums ? sums + block_col : nullptr; - int src_stride = src_matrix.layout.stride; - const Scalar* src_ptr = src_matrix.data.get() + src_stride * block_col; - int remaining_src_cols = src_matrix.layout.cols - block_col; - - static constexpr int block_col_mask = ~(Layout::kCols - 1); // High bits. - std::int8_t* packed_ptr = - packed_matrix->data + - packed_matrix->layout.stride * (block_col & block_col_mask); - Pack8bitAvx2(reinterpret_cast(src_ptr), kInputXor, - reinterpret_cast(zerobuf), src_stride, - remaining_src_cols, src_matrix.layout.rows, packed_ptr, - sums_ptr); - } - } -}; - -void PackFloatAvx2(const float* src_ptr, const float* zerobuf, int src_stride, - int remaining_src_cols, int src_rows, float* packed_ptr); - -template <> -struct PackImpl, float, - float, float> { - using Layout = FixedKernelLayout; - static void Run(Tuning, const Matrix& src_matrix, - PackedMatrix* packed_matrix, int start_col, - int end_col) { - profiler::ScopeLabel label("Pack (AVX2 float)"); - - RUY_DCHECK(IsColMajor(src_matrix.layout)); - RUY_DCHECK(IsColMajor(packed_matrix->layout)); - RUY_DCHECK_EQ((end_col - start_col) % Layout::kCols, 0); - RUY_DCHECK_EQ(start_col % Layout::kCols, 0); - const float zerobuf[Layout::kCols] = { - 0.0f}; // Remainder default inits to 0.0f. - for (int block_col = start_col; block_col < end_col; - block_col += Layout::kCols) { - int src_stride = src_matrix.layout.stride; - const float* src_ptr = src_matrix.data.get() + src_stride * block_col; - int remaining_src_cols = src_matrix.layout.cols - block_col; - - static constexpr int block_col_mask = ~(Layout::kCols - 1); // High bits. - float* packed_ptr = - packed_matrix->data + - packed_matrix->layout.stride * (block_col & block_col_mask); - PackFloatAvx2(src_ptr, zerobuf, src_stride, remaining_src_cols, - src_matrix.layout.rows, packed_ptr); - } - } -}; - -// Note that source and zero buffers can be uint8 type, but in the packing -// function are reinterpreted as int8, and are XOR-ed with input_xor. 
-void Pack8bitAvx512(const std::int8_t* src_ptr, std::int8_t input_xor, - const std::int8_t* zerobuf, int src_stride, - int remaining_src_cols, int src_rows, - std::int8_t* packed_ptr, std::int32_t* sums_ptr); - -template -struct PackImpl, - Scalar, std::int8_t, std::int32_t> { - static_assert(std::is_same::value || - std::is_same::value, - ""); - using Layout = FixedKernelLayout; - static constexpr int kHalfLayoutCols = - 8; // Half the number of cols in a block. - static constexpr std::int8_t kInputXor = - std::is_same::value ? 0 : 0x80; - - static void Run(Tuning tuning, const Matrix& src_matrix, - PackedMatrix* packed_matrix, int start_col, - int end_col) { - profiler::ScopeLabel label("Pack (AVX-512 8-bit)"); - - RUY_DCHECK(IsColMajor(src_matrix.layout)); - RUY_DCHECK(IsColMajor(packed_matrix->layout)); - RUY_DCHECK_EQ((end_col - start_col) % Layout::kCols, 0); - RUY_DCHECK_EQ(start_col % Layout::kCols, 0); - RUY_DCHECK_EQ(kHalfLayoutCols * 2, Layout::kCols); - std::int32_t* sums = packed_matrix->sums; - Scalar zerobuf[kHalfLayoutCols * Layout::kRows]; - memset(zerobuf, packed_matrix->zero_point ^ kInputXor, - kHalfLayoutCols * Layout::kRows * sizeof(Scalar)); - for (int block_col = start_col; block_col < end_col; - block_col += Layout::kCols) { - std::int32_t* sums_ptr = sums ? sums + block_col : nullptr; - int src_stride = src_matrix.layout.stride; - const Scalar* src_ptr = src_matrix.data.get() + src_stride * block_col; - int remaining_src_cols = src_matrix.layout.cols - block_col; - - static constexpr int block_col_mask = ~(Layout::kCols - 1); // High bits. - std::int8_t* packed_ptr = - packed_matrix->data + - packed_matrix->layout.stride * (block_col & block_col_mask); - Pack8bitAvx512(reinterpret_cast(src_ptr), kInputXor, - reinterpret_cast(zerobuf), src_stride, - remaining_src_cols, src_matrix.layout.rows, packed_ptr, - sums_ptr); - } - } -}; - -void PackFloatAvx512(const float* src_ptr, const float* zerobuf, int src_stride, - int remaining_src_cols, int src_rows, float* packed_ptr); - -template <> -struct PackImpl, - float, float, float> { - static void Run(Tuning, const Matrix& src_matrix, - PackedMatrix* packed_matrix, int start_col, - int end_col) { - profiler::ScopeLabel label("Pack (AVX-512 float)"); - using Layout = FixedKernelLayout; - RUY_DCHECK(IsColMajor(src_matrix.layout)); - RUY_DCHECK(IsColMajor(packed_matrix->layout)); - RUY_DCHECK_EQ((end_col - start_col) % Layout::kCols, 0); - RUY_DCHECK_EQ(start_col % Layout::kCols, 0); - const float zerobuf[Layout::kCols] = { - 0.0f}; // Remainder default inits to 0.0f. - for (int block_col = start_col; block_col < end_col; - block_col += Layout::kCols) { - int src_stride = src_matrix.layout.stride; - const float* src_ptr = src_matrix.data.get() + src_stride * block_col; - int remaining_src_cols = src_matrix.layout.cols - block_col; - - static constexpr int block_col_mask = ~(Layout::kCols - 1); // High bits. - float* packed_ptr = - packed_matrix->data + - packed_matrix->layout.stride * (block_col & block_col_mask); - PackFloatAvx512(src_ptr, zerobuf, src_stride, remaining_src_cols, - src_matrix.layout.rows, packed_ptr); - } - } -}; - -// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. -// Optimization is not finished. In particular the dimensions of the kernel -// blocks can be changed as desired. -// -// Note that source and zero buffers can be uint8 type, but in the packing -// function are reinterpreted as int8, and are XOR-ed with input_xor. 
-void Pack8bitAvxVnni(const std::int8_t* src_ptr, std::int8_t input_xor, - const std::int8_t* zerobuf, int src_stride, - int remaining_src_cols, int src_rows, - std::int8_t* packed_ptr, std::int32_t* sums_ptr); - -template -struct PackImpl, - Scalar, std::int8_t, std::int32_t> { - static_assert(std::is_same::value || - std::is_same::value, - ""); - using Layout = FixedKernelLayout; - static constexpr int kHalfLayoutCols = - 8; // Half the number of cols in a block. - static constexpr std::int8_t kInputXor = - std::is_same::value ? 0 : 0x80; - - static void Run(Tuning tuning, const Matrix& src_matrix, - PackedMatrix* packed_matrix, int start_col, - int end_col) { - profiler::ScopeLabel label("Pack (AVX-512 8-bit)"); - - RUY_DCHECK(IsColMajor(src_matrix.layout)); - RUY_DCHECK(IsColMajor(packed_matrix->layout)); - RUY_DCHECK_EQ((end_col - start_col) % Layout::kCols, 0); - RUY_DCHECK_EQ(start_col % Layout::kCols, 0); - RUY_DCHECK_EQ(kHalfLayoutCols * 2, Layout::kCols); - std::int32_t* sums = packed_matrix->sums; - Scalar zerobuf[kHalfLayoutCols * Layout::kRows]; - memset(zerobuf, packed_matrix->zero_point ^ kInputXor, - kHalfLayoutCols * Layout::kRows * sizeof(Scalar)); - for (int block_col = start_col; block_col < end_col; - block_col += Layout::kCols) { - std::int32_t* sums_ptr = sums ? sums + block_col : nullptr; - int src_stride = src_matrix.layout.stride; - const Scalar* src_ptr = src_matrix.data.get() + src_stride * block_col; - int remaining_src_cols = src_matrix.layout.cols - block_col; - - static constexpr int block_col_mask = ~(Layout::kCols - 1); // High bits. - std::int8_t* packed_ptr = - packed_matrix->data + - packed_matrix->layout.stride * (block_col & block_col_mask); - Pack8bitAvxVnni(reinterpret_cast(src_ptr), kInputXor, - reinterpret_cast(zerobuf), src_stride, - remaining_src_cols, src_matrix.layout.rows, packed_ptr, - sums_ptr); - } - } -}; - -// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. -// Optimization is not finished. In particular the dimensions of the kernel -// blocks can be changed as desired. -// -void PackFloatAvxVnni(const float* src_ptr, const float* zerobuf, - int src_stride, int remaining_src_cols, int src_rows, - float* packed_ptr); - -template <> -struct PackImpl, - float, float, float> { - static void Run(Tuning, const Matrix& src_matrix, - PackedMatrix* packed_matrix, int start_col, - int end_col) { - profiler::ScopeLabel label("Pack (AVX-512 float)"); - - using Layout = FixedKernelLayout; - RUY_DCHECK(IsColMajor(src_matrix.layout)); - RUY_DCHECK(IsColMajor(packed_matrix->layout)); - RUY_DCHECK_EQ((end_col - start_col) % Layout::kCols, 0); - RUY_DCHECK_EQ(start_col % Layout::kCols, 0); - const float zerobuf[Layout::kCols] = { - 0.0f}; // Remainder default inits to 0.0f. - for (int block_col = start_col; block_col < end_col; - block_col += Layout::kCols) { - int src_stride = src_matrix.layout.stride; - const float* src_ptr = src_matrix.data.get() + src_stride * block_col; - int remaining_src_cols = src_matrix.layout.cols - block_col; - - static constexpr int block_col_mask = ~(Layout::kCols - 1); // High bits. 
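      // Added note (not original ruy code): with Layout::kCols == 16, as used
      // by this float path, block_col_mask == ~15, so
      // (block_col & block_col_mask) rounds block_col down to a multiple of
      // 16. Since start_col is DCHECKed to be a multiple of Layout::kCols and
      // block_col advances in Layout::kCols steps, the mask appears to act as
      // an alignment guard rather than changing the value.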
- float* packed_ptr = - packed_matrix->data + - packed_matrix->layout.stride * (block_col & block_col_mask); - PackFloatAvxVnni(src_ptr, zerobuf, src_stride, remaining_src_cols, - src_matrix.layout.rows, packed_ptr); - } - } -}; -#endif // RUY_PLATFORM(X86) - -} // namespace ruy - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PACK_X86_H_ diff --git a/tensorflow/lite/experimental/ruy/ruy/path.h b/tensorflow/lite/experimental/ruy/ruy/path.h deleted file mode 100644 index 5973b8040a7..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/path.h +++ /dev/null @@ -1,162 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PATH_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PATH_H_ - -#include - -#include "tensorflow/lite/experimental/ruy/ruy/platform.h" -#include "tensorflow/lite/experimental/ruy/ruy/size_util.h" - -namespace ruy { - -// A Path is a choice of implementation path, e.g. between reference code -// and optimized code, or between different optimized code paths using different -// instruction sets. -// -// It's important that any symbol that depends on such implementation -// details, is somehow templatized in such a Path, so that different Path values -// yield different symbols, so we never have the situation where a symbols has -// multiple inequivalent definitions based on which code paths are compiled. -// That would be a violation of the ODR (One Definition Rule) which is Undefined -// Behavior, and one of the most serious issues plaguing both Eigen and -// gemmlowp. -// -// This enum is actually a bit-field: aside from kNone, all other values are -// powers of two, thus are one bit each. We define bit-wise operators below -// for this enum. Some places in Ruy accept a Path bit-field where multiple -// Paths may be selected, while some other places require a single Path (i.e. -// just one of the enum values here). Typically, user-facing parts of Ruy -// accept arbitrary bit-fields, allowing the user to compile support for -// multiple paths and to inform Ruy of all the paths that are to be enabled -// at runtime; then, typically in dispatch.h, we internally pick one -// specific path and from there on, internal Ruy code deals with only one -// path. -// -// When a user selects a set of compiled paths, Ruy internally dispatches to the -// "best" one, which typically means the newest optimized instructions for a -// given base architecture (such as ARM). Higher values of this enum correspond -// to "better" code paths within a given base architecture for which Ruy has -// optimized code paths. -// -// Values are reused across architectures. -// Rationale: Scale better to N architectures, it is good to have small values -// both for the compile-time logic to select paths, and when manually spelling -// out Path values, such as when invoking a test or benchmark. 
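The bit-field behavior described above can be illustrated with a simplified,
standalone mirror of the Path enum (a sketch added for this commentary, not
ruy code). The values and the MostSignificant helper below are illustrative
only; the real enum, including the x86 values and platform #ifs, and the real
GetMostSignificantPath built on round_down_pot follow immediately below.

    #include <cstdint>
    #include <cstdio>

    enum class Path : std::uint8_t {
      kNone = 0,
      kReference = 0x1,
      kStandardCpp = 0x2,
      kNeon = 0x4,
      kNeonDotprod = 0x8,
    };

    constexpr Path operator|(Path p, Path q) {
      return static_cast<Path>(static_cast<std::uint8_t>(p) |
                               static_cast<std::uint8_t>(q));
    }
    constexpr Path operator&(Path p, Path q) {
      return static_cast<Path>(static_cast<std::uint8_t>(p) &
                               static_cast<std::uint8_t>(q));
    }

    // Round down to the highest set bit: the "best" path in a compiled set.
    Path MostSignificant(Path mask) {
      std::uint8_t bits = static_cast<std::uint8_t>(mask);
      while (bits & (bits - 1)) {
        bits &= static_cast<std::uint8_t>(bits - 1);  // Clear the lowest bit.
      }
      return static_cast<Path>(bits);
    }

    int main() {
      // A user-facing bit-field of compiled-in paths...
      const Path compiled = Path::kReference | Path::kStandardCpp | Path::kNeon;
      // ...a membership test via operator&...
      const bool has_dotprod = (compiled & Path::kNeonDotprod) != Path::kNone;
      // ...and the single path that dispatch would pick from the set.
      std::printf("dotprod compiled in: %d, best path: 0x%x\n", has_dotprod,
                  static_cast<int>(MostSignificant(compiled)));
      return 0;
    }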
-enum class Path : std::uint8_t { - // This is a special null value, representing the absence of any path. - kNone = 0, - // Reference multiplication code. - // The main purpose of this path is to have a very simple standalone Mul - // implementation to check against. - // This path bypasses almost all of Ruy's internal implementation details. - // - // This is intended for testing/development. - kReference = 0x1, - // Standard C++ implementation of Ruy's architecture-specific parts. - // Unlike Path::kReference, this path exercises most of Ruy's internal logic. - // - // This is intended for testing/development. - kStandardCpp = 0x2, - -#if RUY_PLATFORM(ARM) - // ARM architectures. - // - // Optimized path using a widely available subset of ARM NEON instructions. - kNeon = 0x4, - // Optimized path making use of ARM NEON dot product instructions that are - // available on newer ARM cores. - kNeonDotprod = 0x8, -#endif // RUY_PLATFORM(ARM) - -#if RUY_PLATFORM(X86) - // x86 architectures. - // - // TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / - // placeholder. - // Optimization is not finished. In particular the dimensions of the kernel - // blocks can be changed as desired. - // - // Optimized for SSE 4.2. - kSse42 = 0x4, - // Optimized for AVX2. - kAvx2 = 0x8, - // Optimized for AVX-512. - kAvx512 = 0x10, - // TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / - // placeholder. - // Optimization is not finished. In particular the dimensions of the kernel - // blocks can be changed as desired. - // - // Optimized for AVX-VNNI. - kAvxVnni = 0x20, -#endif // RUY_PLATFORM(X86) -}; - -inline constexpr Path operator|(Path p, Path q) { - return static_cast(static_cast(p) | - static_cast(q)); -} - -inline constexpr Path operator&(Path p, Path q) { - return static_cast(static_cast(p) & - static_cast(q)); -} - -inline constexpr Path operator^(Path p, Path q) { - return static_cast(static_cast(p) ^ - static_cast(q)); -} - -inline constexpr Path operator~(Path p) { - return static_cast(~static_cast(p)); -} - -inline Path GetMostSignificantPath(Path path_mask) { - return static_cast(round_down_pot(static_cast(path_mask))); -} - -// ruy::kAllPaths represents all Path's that make sense to on a given -// base architecture. -#ifdef __linux__ -#if RUY_PLATFORM(NEON_64) -constexpr Path kAllPaths = - Path::kReference | Path::kStandardCpp | Path::kNeon | Path::kNeonDotprod; -#elif RUY_PLATFORM(NEON_32) -constexpr Path kAllPaths = Path::kReference | Path::kStandardCpp | Path::kNeon; -#elif RUY_PLATFORM(X86) -constexpr Path kAllPaths = Path::kReference | Path::kStandardCpp | - Path::kSse42 | Path::kAvx2 | Path::kAvx512 | - Path::kAvxVnni; -#else -constexpr Path kAllPaths = Path::kReference | Path::kStandardCpp; -#endif -#else // __linux__ -// We don't know how to do runtime dotprod detection outside of linux for now. 
-#if RUY_PLATFORM(NEON) -constexpr Path kAllPaths = Path::kReference | Path::kStandardCpp | Path::kNeon; -#elif RUY_PLATFORM(X86) -constexpr Path kAllPaths = Path::kReference | Path::kStandardCpp | - Path::kSse42 | Path::kAvx2 | Path::kAvx512 | - Path::kAvxVnni; -#else -constexpr Path kAllPaths = Path::kReference | Path::kStandardCpp; -#endif -#endif // __linux__ - -} // namespace ruy - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PATH_H_ diff --git a/tensorflow/lite/experimental/ruy/ruy/platform.h b/tensorflow/lite/experimental/ruy/ruy/platform.h deleted file mode 100644 index d6e86e6a792..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/platform.h +++ /dev/null @@ -1,156 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PLATFORM_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PLATFORM_H_ - -#ifdef __ANDROID_NDK__ -#include -#endif - -#define RUY_PLATFORM(X) ((RUY_DONOTUSEDIRECTLY_##X) != 0) - -// Architecture-level platform detection. -// -// Ruy requires these to be mutually exclusive. - -// Detect x86. -#if defined(__x86_64__) || defined(__i386__) || defined(__i386) || \ - defined(__x86__) || defined(__X86__) || defined(_X86_) || \ - defined(_M_IX86) || defined(_M_X64) -#define RUY_DONOTUSEDIRECTLY_X86 1 -#else -#define RUY_DONOTUSEDIRECTLY_X86 0 -#endif - -// Detect ARM 32-bit. -#ifdef __arm__ -#define RUY_DONOTUSEDIRECTLY_ARM_32 1 -#else -#define RUY_DONOTUSEDIRECTLY_ARM_32 0 -#endif - -// Detect ARM 64-bit. -#ifdef __aarch64__ -#define RUY_DONOTUSEDIRECTLY_ARM_64 1 -#else -#define RUY_DONOTUSEDIRECTLY_ARM_64 0 -#endif - -// Combined ARM. -#define RUY_DONOTUSEDIRECTLY_ARM \ - (RUY_DONOTUSEDIRECTLY_ARM_64 || RUY_DONOTUSEDIRECTLY_ARM_32) - -// Feature and capability platform detection. -// -// These are mostly sub-selections of architectures. - -// Detect NEON. Explicitly avoid emulation, or anything like it, on x86. -#if (defined(__ARM_NEON) || defined(__ARM_NEON__)) && !RUY_PLATFORM(X86) -#define RUY_DONOTUSEDIRECTLY_NEON 1 -#else -#define RUY_DONOTUSEDIRECTLY_NEON 0 -#endif - -// Define ARM 32-bit NEON. -#define RUY_DONOTUSEDIRECTLY_NEON_32 \ - (RUY_DONOTUSEDIRECTLY_NEON && RUY_DONOTUSEDIRECTLY_ARM_32) - -// Define ARM 64-bit NEON. -// Note: NEON is implied by ARM64, so this define is redundant. -// It still allows some conveyance of intent. -#define RUY_DONOTUSEDIRECTLY_NEON_64 \ - (RUY_DONOTUSEDIRECTLY_NEON && RUY_DONOTUSEDIRECTLY_ARM_64) - -// Disable X86 enhancements on __APPLE__ because b/138922878, see comment #8, we -// may only need to disable this on XCode <= 10.2. -// -// Disable when not using Clang-Linux, because too many user issues arise from -// compilation variations. -// -// NOTE: Consider guarding by !defined(__APPLE__) when removing Linux-only -// restriction. -// -// __EMSCRIPTEN__ is checked because the runtime Path resolution can use asm. 
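
To make the bit-field semantics of the removed `Path` enum concrete, here is a small standalone sketch. It mirrors the operators and the `GetMostSignificantPath` helper from the deleted `path.h`; note that the real header implements that helper via `round_down_pot` from `size_util.h`, so the loop-based version below is only a self-contained approximation, not the library's code.

```c++
// Standalone illustration of the Path bit-field: values are single bits,
// masks are ORs of bits, and dispatch picks the most significant enabled bit.
#include <cstdint>

enum class Path : std::uint8_t {
  kNone = 0,
  kReference = 0x1,
  kStandardCpp = 0x2,
  kNeon = 0x4,
  kNeonDotprod = 0x8,
};

inline constexpr Path operator|(Path p, Path q) {
  return static_cast<Path>(static_cast<std::uint32_t>(p) |
                           static_cast<std::uint32_t>(q));
}

inline constexpr Path operator&(Path p, Path q) {
  return static_cast<Path>(static_cast<std::uint32_t>(p) &
                           static_cast<std::uint32_t>(q));
}

// Simplified stand-in for GetMostSignificantPath: round the mask down to a
// power of two, i.e. keep only the highest set bit.
inline Path GetMostSignificantPath(Path path_mask) {
  std::uint32_t x = static_cast<std::uint32_t>(path_mask);
  while (x & (x - 1)) {
    x &= x - 1;  // clear the lowest set bit until a single bit remains
  }
  return static_cast<Path>(x);
}

int main() {
  constexpr Path compiled_paths =
      Path::kReference | Path::kStandardCpp | Path::kNeon;
  // The "best" path in this mask is kNeon, its most significant bit.
  return GetMostSignificantPath(compiled_paths) == Path::kNeon ? 0 : 1;
}
```

The same idea is what lets `kAllPaths` above be written as a plain OR of whichever paths are compiled for the target.
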
-// -// The Android NDK logic excludes earlier and very broken versions of intrinsics -// headers. -#if defined(RUY_FORCE_ENABLE_X86_ENHANCEMENTS) || \ - (defined(__clang__) && (__clang_major__ >= 8) && defined(__linux__) && \ - !defined(__EMSCRIPTEN__) && \ - (!defined(__ANDROID_NDK__) || \ - (defined(__NDK_MAJOR__) && (__NDK_MAJOR__ >= 20)))) -#define RUY_DONOTUSEDIRECTLY_X86_ENHANCEMENTS 1 -#else -#define RUY_DONOTUSEDIRECTLY_X86_ENHANCEMENTS 0 -#endif - -// These CPU capabilities will all be true when Skylake, etc, are enabled during -// compilation. -#if RUY_PLATFORM(X86_ENHANCEMENTS) && RUY_PLATFORM(X86) && \ - defined(__AVX512F__) && defined(__AVX512DQ__) && defined(__AVX512CD__) && \ - defined(__AVX512BW__) && defined(__AVX512VL__) -#define RUY_DONOTUSEDIRECTLY_AVX512 1 -#else -#define RUY_DONOTUSEDIRECTLY_AVX512 0 -#endif - -#if RUY_PLATFORM(X86_ENHANCEMENTS) && RUY_PLATFORM(X86) && defined(__AVX2__) -#define RUY_DONOTUSEDIRECTLY_AVX2 1 -#else -#define RUY_DONOTUSEDIRECTLY_AVX2 0 -#endif - -// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. -// Optimization is not finished. In particular the dimensions of the kernel -// blocks can be changed as desired. -// -// Note does not check for LZCNT or POPCNT. -#if defined(RUY_ENABLE_SSE_ENHANCEMENTS) && RUY_PLATFORM(X86_ENHANCEMENTS) && \ - RUY_PLATFORM(X86) && defined(__SSE4_2__) && defined(__FMA__) -#define RUY_DONOTUSEDIRECTLY_SSE42 1 -#else -#define RUY_DONOTUSEDIRECTLY_SSE42 0 -#endif - -// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. -// Optimization is not finished. In particular the dimensions of the kernel -// blocks can be changed as desired. -// -// Note that defined(__AVX512VBMI2__) can be false for compilation with -// -march=cascadelake. -// TODO(b/146646451) Check if we should also gate on defined(__AVX512VBMI2__). -#if defined(RUY_ENABLE_VNNI_ENHANCEMENTS) && RUY_PLATFORM(AVX512) && \ - defined(__AVX512VNNI__) -#define RUY_DONOTUSEDIRECTLY_AVX_VNNI 1 -#else -#define RUY_DONOTUSEDIRECTLY_AVX_VNNI 0 -#endif - -// Detect APPLE. -#ifdef __APPLE__ -#define RUY_DONOTUSEDIRECTLY_APPLE 1 -#else -#define RUY_DONOTUSEDIRECTLY_APPLE 0 -#endif - -// Detect Emscripten, typically Wasm. -#ifdef __EMSCRIPTEN__ -#define RUY_DONOTUSEDIRECTLY_EMSCRIPTEN 1 -#else -#define RUY_DONOTUSEDIRECTLY_EMSCRIPTEN 0 -#endif - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PLATFORM_H_ diff --git a/tensorflow/lite/experimental/ruy/ruy/pmu.cc b/tensorflow/lite/experimental/ruy/ruy/pmu.cc deleted file mode 100644 index 6405aa15e6a..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/pmu.cc +++ /dev/null @@ -1,281 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
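
For context on the `RUY_PLATFORM(X)` macro family defined by the `platform.h` removed just above, here is a hedged sketch of a consumer. The include path is the old in-tree location that this very patch deletes, and the example function is purely illustrative.

```c++
// Hypothetical consumer of the RUY_PLATFORM(X) macros (old in-tree include
// path shown; the macros themselves come from the deleted header, which
// defines each RUY_DONOTUSEDIRECTLY_* token as 0 or 1 on every target).
#include <cstdio>

#include "tensorflow/lite/experimental/ruy/ruy/platform.h"

const char* DescribeTarget() {
#if RUY_PLATFORM(NEON_64)
  return "ARM 64-bit with NEON";
#elif RUY_PLATFORM(NEON_32)
  return "ARM 32-bit with NEON";
#elif RUY_PLATFORM(X86)
  return "x86";
#else
  return "portable C++ only";
#endif
}

int main() {
  std::printf("building for: %s\n", DescribeTarget());
  return 0;
}
```
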
-==============================================================================*/ - -#include "tensorflow/lite/experimental/ruy/ruy/pmu.h" - -#include "tensorflow/lite/experimental/ruy/ruy/check_macros.h" - -#ifdef __linux__ -#include -#include -#include -#include -#include - -#include -#endif - -#include -#include -#include -#include - -namespace ruy { - -// Linux-specific. Not ARM-specific. -#ifdef __linux__ -class PerfEvent { - public: - PerfEvent(std::uint32_t type, std::uint64_t config) { - perf_event_attr pe; - memset(&pe, 0, sizeof(pe)); - pe.size = sizeof(pe); - pe.type = type; - pe.config = config; - pe.disabled = 1; - pe.exclude_kernel = 1; - pe.exclude_hv = 1; - pe.inherit = 1; - fd_ = syscall(__NR_perf_event_open, &pe, 0, -1, -1, 0); - if (fd_ == -1) { - fprintf(stderr, "perf_event_open failed for config 0x%lx\n", - static_cast(config)); - // abort(); - } - } - - ~PerfEvent() { - RUY_CHECK(!started_); - close(fd_); - } - - void Start() { - RUY_CHECK(!started_); - started_ = true; - ioctl(fd_, PERF_EVENT_IOC_RESET, 0); - ioctl(fd_, PERF_EVENT_IOC_ENABLE, 0); - count_at_start_ = Read(); - } - - void Stop() { - RUY_CHECK(started_); - started_ = false; - ioctl(fd_, PERF_EVENT_IOC_DISABLE, 0); - count_at_stop_ = Read(); - } - - std::int64_t Count() const { - RUY_CHECK(!started_); - return count_at_stop_ - count_at_start_; - } - - private: - std::int64_t Read() const { - std::int64_t count; - RUY_CHECK_NE(read(fd_, &count, sizeof(count)), -1); - return count; - } - std::int64_t count_at_start_ = -1; - std::int64_t count_at_stop_ = -1; - bool started_ = false; - int fd_ = -1; -}; -#else -// Placeholder implementation to at least compile outside of linux. -#define PERF_TYPE_RAW 0 -class PerfEvent { - public: - PerfEvent(std::uint32_t, std::uint64_t) {} - ~PerfEvent() {} - void Start() {} - void Stop() {} - std::int64_t Count() const { return 0; } -}; -#endif - -// ARM-specific. Query ARM PMU counters as Linux perf events using -// PERF_TYPE_RAW. -namespace arm_pmuv3 { - -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wunused-const-variable" - -// These event numbers are listed in the ARMv8 architecture reference manual. 
-constexpr std::uint16_t L1I_CACHE_REFILL = 0x01; -constexpr std::uint16_t L1I_TLB_REFILL = 0x02; -constexpr std::uint16_t L1D_CACHE_REFILL = 0x03; -constexpr std::uint16_t L1D_CACHE = 0x04; -constexpr std::uint16_t L1D_TLB_REFILL = 0x05; -constexpr std::uint16_t LD_RETIRED = 0x06; -constexpr std::uint16_t ST_RETIRED = 0x07; -constexpr std::uint16_t INST_RETIRED = 0x08; -constexpr std::uint16_t EXC_TAKEN = 0x09; -constexpr std::uint16_t EXC_RETURN = 0x0A; -constexpr std::uint16_t CID_WRITE_RETIRED = 0x0B; -constexpr std::uint16_t PC_WRITE_RETIRED = 0x0C; -constexpr std::uint16_t BR_IMMED_RETIRED = 0x0D; -constexpr std::uint16_t BR_RETURN_RETIRED = 0x0E; -constexpr std::uint16_t UNALIGNED_LDST_RETIRED = 0x0F; -constexpr std::uint16_t BR_MIS_PRED = 0x10; -constexpr std::uint16_t CPU_CYCLES = 0x11; -constexpr std::uint16_t BR_PRED = 0x12; -constexpr std::uint16_t MEM_ACCESS = 0x13; -constexpr std::uint16_t L1I_CACHE = 0x14; -constexpr std::uint16_t L1D_CACHE_WB = 0x15; -constexpr std::uint16_t L2D_CACHE = 0x16; -constexpr std::uint16_t L2D_CACHE_REFILL = 0x17; -constexpr std::uint16_t L2D_CACHE_WB = 0x18; -constexpr std::uint16_t BUS_ACCESS = 0x19; -constexpr std::uint16_t MEMORY_ERROR = 0x1A; -constexpr std::uint16_t INST_SPEC = 0x1B; -constexpr std::uint16_t TTBR_WRITE_RETIRED = 0x1C; -constexpr std::uint16_t BUS_CYCLES = 0x1D; -constexpr std::uint16_t CHAIN = 0x1E; -constexpr std::uint16_t L1D_CACHE_ALLOCATE = 0x1F; -constexpr std::uint16_t L2D_CACHE_ALLOCATE = 0x20; -constexpr std::uint16_t BR_RETIRED = 0x21; -constexpr std::uint16_t BR_MIS_PRED_RETIRED = 0x22; -constexpr std::uint16_t STALL_FRONTEND = 0x23; -constexpr std::uint16_t STALL_BACKEND = 0x24; -constexpr std::uint16_t L1D_TLB = 0x25; -constexpr std::uint16_t L1I_TLB = 0x26; -constexpr std::uint16_t L2I_CACHE = 0x27; -constexpr std::uint16_t L2I_CACHE_REFILL = 0x28; -constexpr std::uint16_t L3D_CACHE_ALLOCATE = 0x29; -constexpr std::uint16_t L3D_CACHE_REFILL = 0x2A; -constexpr std::uint16_t L3D_CACHE = 0x2B; -constexpr std::uint16_t L3D_CACHE_WB = 0x2C; -constexpr std::uint16_t L2D_TLB_REFILL = 0x2D; -constexpr std::uint16_t L2I_TLB_REFILL = 0x2E; -constexpr std::uint16_t L2D_TLB = 0x2F; -constexpr std::uint16_t L2I_TLB = 0x30; -constexpr std::uint16_t LL_CACHE = 0x32; -constexpr std::uint16_t LL_CACHE_MISS = 0x33; -constexpr std::uint16_t DTLB_WALK = 0x34; -constexpr std::uint16_t LL_CACHE_RD = 0x36; -constexpr std::uint16_t LL_CACHE_MISS_RD = 0x37; - -// Additional implementation-defined events found by googling around. 
-constexpr std::uint16_t L1D_CACHE_RD = 0x40; -constexpr std::uint16_t L1D_CACHE_REFILL_RD = 0x42; -constexpr std::uint16_t L1D_TLB_REFILL_RD = 0x4C; -constexpr std::uint16_t L1D_TLB_RD = 0x4E; -constexpr std::uint16_t L2D_CACHE_RD = 0x50; -constexpr std::uint16_t L2D_CACHE_REFILL_RD = 0x52; -constexpr std::uint16_t BUS_ACCESS_RD = 0x60; -constexpr std::uint16_t MEM_ACCESS_RD = 0x66; -constexpr std::uint16_t L3D_CACHE_RD = 0xA0; -constexpr std::uint16_t L3D_CACHE_REFILL_RD = 0xA2; - -#pragma GCC diagnostic pop - -}; // namespace arm_pmuv3 - -class PmuEventsPrivate { - public: - PmuEventsPrivate() - : l1d_cache_refill(PERF_TYPE_RAW, arm_pmuv3::L1D_CACHE_REFILL), - l2d_cache_refill(PERF_TYPE_RAW, arm_pmuv3::L2D_CACHE_REFILL), - l3d_cache_refill(PERF_TYPE_RAW, arm_pmuv3::L3D_CACHE_REFILL), - ll_cache_miss(PERF_TYPE_RAW, arm_pmuv3::LL_CACHE_MISS), - l1d_tlb_refill(PERF_TYPE_RAW, arm_pmuv3::L1D_TLB_REFILL), - l2d_tlb_refill(PERF_TYPE_RAW, arm_pmuv3::L2D_TLB_REFILL), - stall_frontend(PERF_TYPE_RAW, arm_pmuv3::STALL_FRONTEND), - stall_backend(PERF_TYPE_RAW, arm_pmuv3::STALL_BACKEND), - br_mis_pred(PERF_TYPE_RAW, arm_pmuv3::BR_MIS_PRED) {} - - private: - friend class PmuEvents; - PerfEvent l1d_cache_refill; - PerfEvent l2d_cache_refill; - PerfEvent l3d_cache_refill; - PerfEvent ll_cache_miss; - PerfEvent l1d_tlb_refill; - PerfEvent l2d_tlb_refill; - PerfEvent stall_frontend; - PerfEvent stall_backend; - PerfEvent br_mis_pred; -}; - -PmuEvents::PmuEvents() : priv(new PmuEventsPrivate) {} -PmuEvents::~PmuEvents() { delete priv; } - -void PmuEvents::StartRecording() { - priv->l1d_cache_refill.Start(); - priv->l2d_cache_refill.Start(); - priv->l3d_cache_refill.Start(); - priv->ll_cache_miss.Start(); - priv->l1d_tlb_refill.Start(); - priv->l2d_tlb_refill.Start(); - priv->stall_frontend.Start(); - priv->stall_backend.Start(); - priv->br_mis_pred.Start(); -} - -void PmuEvents::StopRecording() { - priv->l1d_cache_refill.Stop(); - priv->l2d_cache_refill.Stop(); - priv->l3d_cache_refill.Stop(); - priv->ll_cache_miss.Stop(); - priv->l1d_tlb_refill.Stop(); - priv->l2d_tlb_refill.Stop(); - priv->stall_frontend.Stop(); - priv->stall_backend.Stop(); - priv->br_mis_pred.Stop(); -} - -float PmuEvents::BranchMispredictionCount() const { - return static_cast(priv->br_mis_pred.Count()); -} - -float PmuEvents::FrontendStallCount() const { - return static_cast(priv->stall_frontend.Count()); -} - -float PmuEvents::BackendStallCount() const { - return static_cast(priv->stall_backend.Count()); -} - -float PmuEvents::L1RefillCount() const { - return static_cast(priv->l1d_cache_refill.Count()); -} - -float PmuEvents::L2RefillCount() const { - return static_cast(priv->l2d_cache_refill.Count()); -} - -float PmuEvents::L3RefillCount() const { - // Important: this was discovered in the context of the above experiments, - // which also tested the _RD variants of these counters. So it's possible that - // it's just not needed here with the default (non _RD) counters. - // - // Some CPUs implement LL_CACHE_MISS[_RD], some implement - // L3D_CACHE_REFILL[_RD]. It seems that either one of these two counters is - // zero, or they roughly both agree with each other. Therefore, taking the max - // of them is a reasonable way to get something more portable across various - // CPUs. 
- return static_cast( - std::max(priv->l3d_cache_refill.Count(), priv->ll_cache_miss.Count())); -} - -float PmuEvents::L1TLBRefillCount() const { - return static_cast(priv->l1d_tlb_refill.Count()); -} - -float PmuEvents::L2TLBRefillCount() const { - return static_cast(priv->l2d_tlb_refill.Count()); -} - -} // namespace ruy diff --git a/tensorflow/lite/experimental/ruy/ruy/pmu.h b/tensorflow/lite/experimental/ruy/ruy/pmu.h deleted file mode 100644 index 721c1d5f1cc..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/pmu.h +++ /dev/null @@ -1,44 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PMU_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PMU_H_ - -namespace ruy { - -class PmuEventsPrivate; - -class PmuEvents { - public: - PmuEvents(); - ~PmuEvents(); - void StartRecording(); - void StopRecording(); - float L1RefillCount() const; - float L2RefillCount() const; - float L3RefillCount() const; - float BranchMispredictionCount() const; - float FrontendStallCount() const; - float BackendStallCount() const; - float L1TLBRefillCount() const; - float L2TLBRefillCount() const; - - private: - PmuEventsPrivate* priv = nullptr; -}; - -} // namespace ruy - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PMU_H_ diff --git a/tensorflow/lite/experimental/ruy/ruy/prepack.h b/tensorflow/lite/experimental/ruy/ruy/prepack.h deleted file mode 100644 index 794b8df7b4d..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/prepack.h +++ /dev/null @@ -1,108 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Implementation of low-level pre-packing API. 
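
Stepping back to the `PmuEvents` interface declared in the `pmu.h` removed just above: no usage appears in this section, so here is a hedged sketch of how a benchmark might drive it. `RunWorkload` is a hypothetical stand-in for the code being measured; the counters are only meaningful on Linux (and the ARM-specific ones on ARM), since the non-Linux placeholder shown earlier returns zeros.

```c++
// Sketch of driving the PmuEvents wrapper from the deleted pmu.h.
#include <cstdio>

#include "tensorflow/lite/experimental/ruy/ruy/pmu.h"

volatile int sink = 0;

void RunWorkload() {  // hypothetical stand-in for the code under test
  for (int i = 0; i < 1000000; ++i) sink += i;
}

void MeasureOnce() {
  ruy::PmuEvents pmu;
  pmu.StartRecording();
  RunWorkload();
  pmu.StopRecording();
  std::printf("L1 refills: %.0f  L2 refills: %.0f  mispredicts: %.0f\n",
              pmu.L1RefillCount(), pmu.L2RefillCount(),
              pmu.BranchMispredictionCount());
}

int main() {
  MeasureOnce();
  return 0;
}
```
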
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PREPACK_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PREPACK_H_ - -#include -#include - -#include "tensorflow/lite/experimental/ruy/ruy/check_macros.h" -#include "tensorflow/lite/experimental/ruy/ruy/context.h" -#include "tensorflow/lite/experimental/ruy/ruy/dispatch.h" -#include "tensorflow/lite/experimental/ruy/ruy/internal_matrix.h" -#include "tensorflow/lite/experimental/ruy/ruy/matrix.h" -#include "tensorflow/lite/experimental/ruy/ruy/path.h" -#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" -#include "tensorflow/lite/experimental/ruy/ruy/side_pair.h" -#include "tensorflow/lite/experimental/ruy/ruy/spec.h" -#include "tensorflow/lite/experimental/ruy/ruy/trmul.h" -#include "tensorflow/lite/experimental/ruy/ruy/trmul_params.h" -#include "tensorflow/lite/experimental/ruy/ruy/tune.h" - -namespace ruy { - -template -void PrePackForMulInternal(const Matrix& lhs, - const Matrix& rhs, const Spec& spec, - Context* context, Matrix* dst, - SidePair prepacked, - std::function alloc_fn) { - profiler::ScopeLabel label("PrePackForMul"); - Path the_path = context->GetPathToTake(); - RUY_CHECK_NE(the_path, Path::kReference); - constexpr Path TrMulCompiledPaths = CompiledPaths & ~Path::kReference; - Matrix transposed_lhs(lhs); - Transpose(&transposed_lhs); - TrMulParams params; - CreateTrMulParams(transposed_lhs, rhs, spec, context, dst, - the_path, ¶ms); - - const SidePair origin{0, 0}; - const SidePair rounded_dims{params.packed[Side::kLhs].layout.cols, - params.packed[Side::kRhs].layout.cols}; - - Tuning tuning = context->GetMainThreadTuning(); - for (Side side : {Side::kLhs, Side::kRhs}) { - if (prepacked[side]) { - prepacked[side]->data_size = DataSize(params.packed[side]); - prepacked[side]->sums_size = SumsSize(params.packed[side]); - prepacked[side]->data = alloc_fn(prepacked[side]->data_size); - prepacked[side]->sums = alloc_fn(prepacked[side]->sums_size); - params.packed[side].data = prepacked[side]->data; - params.packed[side].sums = prepacked[side]->sums; - params.RunPack(side, tuning, origin[side], rounded_dims[side]); - } - } -} - -template -void MulWithPrepackedInternal(const Matrix& lhs, - const Matrix& rhs, const Spec& spec, - Context* context, Matrix* dst, - SidePair prepacked) { - profiler::ScopeLabel label("MulWithPrepacked"); - - EnforceLayoutSupport(lhs.layout, rhs.layout, dst->layout); - EnforceZeroPointSupport(lhs.zero_point, rhs.zero_point, - dst->zero_point); - - Path the_path = context->GetPathToTake(); - RUY_CHECK_NE(the_path, Path::kReference); - constexpr Path TrMulCompiledPaths = CompiledPaths & ~Path::kReference; - Matrix transposed_lhs(lhs); - Transpose(&transposed_lhs); - TrMulParams params; - CreateTrMulParams(transposed_lhs, rhs, spec, context, dst, - the_path, ¶ms); - - for (Side side : {Side::kLhs, Side::kRhs}) { - if (prepacked[side]) { - params.packed[side].data = prepacked[side]->data; - params.packed[side].sums = prepacked[side]->sums; - params.is_prepacked[side] = true; - } - } - - TrMul(¶ms, context); -} - -} // namespace ruy - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PREPACK_H_ diff --git a/tensorflow/lite/experimental/ruy/ruy/prepacked_cache.cc b/tensorflow/lite/experimental/ruy/ruy/prepacked_cache.cc deleted file mode 100644 index da683020169..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/prepacked_cache.cc +++ /dev/null @@ -1,86 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. 
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/lite/experimental/ruy/ruy/prepacked_cache.h" - -#include - -#include "tensorflow/lite/experimental/ruy/ruy/matrix.h" -#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" - -namespace ruy { - -using CacheIterator = PrepackedCache::CacheIterator; - -// Looks for an entry with `key`. If found, update its time stamp. -CacheIterator PrepackedCache::FindAndUpdate(const CacheKey &key) { - auto itr = cache_.find(key); - // If found, update with new access time for this entry. - if (itr != cache_.end()) { - const TimePoint time = CacheNow(); - itr->second.second = time; - } - // std::move() is required in the MSVC STL when NDEBUG is not set, and has no - // effect in libc++. - return std::move(itr); // NOLINT -} - -void PrepackedCache::Insert(const CacheKey &key, - const PrepackedMatrix &matrix) { - // Calculate size of this new item. - const size_t size_bytes = matrix.data_size + matrix.sums_size; - - // While we are above the threshold of ejection, eject the LRU entry. - while (!cache_.empty() && - ((TotalSize() + size_bytes) > ejection_threshold_)) { - EjectOne(); - } - DoInsert(key, matrix); - cache_size_ += matrix.data_size + matrix.sums_size; -} - -void PrepackedCache::EjectOne() { - TimePoint oldest_time = CacheNow(); - auto oldest = cache_.begin(); - { - profiler::ScopeLabel label("PepackedCacheEjection"); - for (auto itr = cache_.begin(); itr != cache_.end(); ++itr) { - if (itr->second.second < oldest_time) { - oldest_time = itr->second.second; - oldest = itr; - } - } - } - PrepackedMatrix &pmatrix = oldest->second.first; - cache_size_ -= pmatrix.data_size; - cache_size_ -= pmatrix.sums_size; - allocator_.Free(pmatrix.data); - allocator_.Free(pmatrix.sums); - cache_.erase(oldest); -} - -void PrepackedCache::AllocatePrepackedMatrix(PrepackedMatrix *pmatrix) { - pmatrix->data = allocator_.Alloc(pmatrix->data_size); - pmatrix->sums = allocator_.Alloc(pmatrix->sums_size); -} - -void PrepackedCache::DoInsert(const CacheKey &key, - const PrepackedMatrix &matrix) { - const TimePoint t = CacheNow(); - const MatrixWithTimeStamp mts({matrix, t}); - cache_.insert({key, mts}); -} - -} // namespace ruy diff --git a/tensorflow/lite/experimental/ruy/ruy/prepacked_cache.h b/tensorflow/lite/experimental/ruy/ruy/prepacked_cache.h deleted file mode 100644 index f2ee15559c7..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/prepacked_cache.h +++ /dev/null @@ -1,130 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PREPACKED_CACHE_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PREPACKED_CACHE_H_ - -#include -#include -#include -#include -#include - -#include "tensorflow/lite/experimental/ruy/ruy/allocator.h" -#include "tensorflow/lite/experimental/ruy/ruy/matrix.h" -#include "tensorflow/lite/experimental/ruy/ruy/time.h" - -namespace ruy { - -namespace detail { - -// Tracks a set of blocks allocated from the underlying system allocator. -class SystemBlockAllocator { - public: - void *Alloc(std::ptrdiff_t num_bytes) { - void *p = detail::SystemAlignedAlloc(num_bytes); - blocks_.push_back(p); - return p; - } - - void Free(void *block) { - for (auto it = blocks_.begin(); it != blocks_.end(); ++it) { - if (*it == block) { - detail::SystemAlignedFree(block); - blocks_.erase(it); - return; - } - } - RUY_DCHECK(false); // Trying to free pointer we did not allocate. - } - - ~SystemBlockAllocator() { - for (void *block : blocks_) { - detail::SystemAlignedFree(block); - } - } - - private: - std::vector blocks_; -}; - -} // namespace detail - -enum CachePolicy { kNoCache, kCacheLHSOnNarrowMul }; - -// "Low effort" Least Recently Used Cache for Prepacked Matrices -// A cache mechanism for prepacked matrices that ejects oldest entries. -// The implementation is "low effort" in the following ways: -// - we just linearly search for the oldest entry when doing an ejection -// - the ejection policy is very simple: if the new size would be above the -// . threshold, we will eject entries until the size is below the threshold. -// Current use cases (RNNs with GEMV operations) indicate that ejection is rare -// and memory constraints are tight, so we devote no additional storage to the -// LRU mechanism and accept O(n) search to eject oldest entry. In practice, -// the number of total entries has not been shown to be large. -// This class is not thread safe. In Ruy, memory allocation for packed matrices -// is done in a single threaded context and the actual packing activity may -// be done in a multi-threaded context. -class PrepackedCache { - public: - static constexpr int kDefaultEjectionThresholdBytes = 1 << 28; - - using CacheKey = std::pair; - - using MatrixWithTimeStamp = std::pair; - - using CacheIterator = std::map::const_iterator; - - using AlignedAllocator = detail::AlignedAllocator; - - explicit PrepackedCache( - int32_t ejection_threshold = kDefaultEjectionThresholdBytes) - : ejection_threshold_(ejection_threshold), cache_size_(0) {} - - // Looks for an entry with `key`. If found, update its time stamp. - CacheIterator FindAndUpdate(const CacheKey &key); - - // Returns end iterator for internal cache. The iterator type is appropriate - // to use with `FindAndUpdate`. - CacheIterator cend() const { return cache_.end(); } - - // Returns the total size (in bytes) of data held in this cache. - int TotalSize() const { return cache_size_; } - - // All calls to get current TimePoints go through here. - // TODO(b/145625614) Profile timestamps on relevant models to see if - // this level of granularity is sufficient. CoarseNow is cheap so - // it would be nice to keep it. - TimePoint CacheNow() const { return CoarseNow(); } - - // Performs the memory allocation for the `data` and `sums` members of a - // PrepackedMatrix. 
- void AllocatePrepackedMatrix(PrepackedMatrix *pmatrix); - - // Adds the PrepackedMatrix to the cache, possibly ejecting other values. - void Insert(const CacheKey &key, const PrepackedMatrix &matrix); - - private: - void EjectOne(); - void DoInsert(const CacheKey &key, const PrepackedMatrix &matrix); - detail::SystemBlockAllocator allocator_; - std::map cache_; - const int32_t ejection_threshold_; - size_t cache_size_; -}; - -} // namespace ruy - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PREPACKED_CACHE_H_ diff --git a/tensorflow/lite/experimental/ruy/ruy/prepacked_cache_test.cc b/tensorflow/lite/experimental/ruy/ruy/prepacked_cache_test.cc deleted file mode 100644 index 453190a3b88..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/prepacked_cache_test.cc +++ /dev/null @@ -1,210 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/lite/experimental/ruy/ruy/prepacked_cache.h" - -#include // NOLINT(build/c++11) - -#include -#include "tensorflow/lite/experimental/ruy/ruy/ruy.h" -#include "tensorflow/lite/experimental/ruy/ruy/time.h" - -namespace ruy { -namespace { - -TEST(PrepackedCacheTest, TestCacheEjection) { - // Create the cache. - PrepackedCache prepacked_cache(32); - // Allocate the prepacked matrix. - PrepackedMatrix mat1; - mat1.data_size = 16; - mat1.sums_size = 8; - prepacked_cache.AllocatePrepackedMatrix(&mat1); - auto cache_key1 = std::make_pair(nullptr, mat1.data); - prepacked_cache.Insert(cache_key1, mat1); - std::this_thread::sleep_for(std::chrono::milliseconds(10)); - // Get a time point after the insertion into the cache. - TimePoint current = CoarseNow(); - std::this_thread::sleep_for(std::chrono::milliseconds(10)); - PrepackedCache::CacheIterator itr = prepacked_cache.FindAndUpdate(cache_key1); - EXPECT_NE(itr, prepacked_cache.cend()); - // By finding mat1, we updated its timestamp. Verify that `current` is older - // than the time stamp now associated with mat1. - EXPECT_LT(current, itr->second.second); - PrepackedMatrix mat2; - mat2.data_size = 8; - mat2.sums_size = 4; - prepacked_cache.AllocatePrepackedMatrix(&mat2); - - auto cache_key2 = std::make_pair(nullptr, mat2.data); - prepacked_cache.Insert(cache_key2, mat2); - // The cache size was exceeded by inserting mat2. Ensure that mat1 was - // ejected. - EXPECT_EQ(prepacked_cache.FindAndUpdate(cache_key1), prepacked_cache.cend()); -} - -TEST(PrepackedCacheTest, TestCacheBasic) { - // Create the cache. - PrepackedCache prepacked_cache(48); - // Allocate the prepacked matrix. 
- PrepackedMatrix mat1; - mat1.data_size = 16; - mat1.sums_size = 8; - prepacked_cache.AllocatePrepackedMatrix(&mat1); - - auto cache_key1 = std::make_pair(nullptr, mat1.data); - prepacked_cache.Insert(cache_key1, mat1); - std::this_thread::sleep_for(std::chrono::milliseconds(10)); - EXPECT_NE(prepacked_cache.FindAndUpdate(cache_key1), prepacked_cache.cend()); - - PrepackedMatrix mat2; - mat2.data_size = 8; - mat2.sums_size = 4; - prepacked_cache.AllocatePrepackedMatrix(&mat2); - - auto cache_key2 = std::make_pair(nullptr, mat2.data); - std::this_thread::sleep_for(std::chrono::milliseconds(10)); - prepacked_cache.Insert(cache_key2, mat2); - // The cache size was not exceeded by inserting mat2. Ensure that mat1 was not - // ejected. - EXPECT_NE(prepacked_cache.FindAndUpdate(cache_key1), prepacked_cache.cend()); -} - -TEST(PrepackedCacheTest, TestCacheEjection2) { - // Create the cache. - PrepackedCache prepacked_cache(73); - // Allocate the prepacked matrix 1. - PrepackedMatrix mat1; - mat1.data_size = 16; - mat1.sums_size = 8; - prepacked_cache.AllocatePrepackedMatrix(&mat1); - auto cache_key1 = std::make_pair(nullptr, mat1.data); - prepacked_cache.Insert(cache_key1, mat1); - std::this_thread::sleep_for(std::chrono::milliseconds(10)); - - // Allocate the prepacked matrix 2. - PrepackedMatrix mat2; - mat2.data_size = 16; - mat2.sums_size = 8; - prepacked_cache.AllocatePrepackedMatrix(&mat2); - auto cache_key2 = std::make_pair(nullptr, mat2.data); - prepacked_cache.Insert(cache_key2, mat2); - std::this_thread::sleep_for(std::chrono::milliseconds(10)); - - // Allocate the prepacked matrix 3. - PrepackedMatrix mat31; - mat31.data_size = 16; - mat31.sums_size = 8; - prepacked_cache.AllocatePrepackedMatrix(&mat31); - auto cache_key3 = std::make_pair(nullptr, mat31.data); - prepacked_cache.Insert(cache_key3, mat31); - std::this_thread::sleep_for(std::chrono::milliseconds(10)); - - // The next insertion will cause the cache size to go over the ejection - // threshold. Touch matrix 1 and matrix 3 to make matrix 2 the oldest - EXPECT_NE(prepacked_cache.FindAndUpdate(cache_key1), prepacked_cache.cend()); - EXPECT_NE(prepacked_cache.FindAndUpdate(cache_key3), prepacked_cache.cend()); - std::this_thread::sleep_for(std::chrono::milliseconds(10)); - - // Allocate the prepacked matrix 4. - PrepackedMatrix mat4; - mat4.data_size = 16; - mat4.sums_size = 8; - prepacked_cache.AllocatePrepackedMatrix(&mat4); - auto cache_key4 = std::make_pair(nullptr, mat4.data); - prepacked_cache.Insert(cache_key4, mat4); - std::this_thread::sleep_for(std::chrono::milliseconds(10)); - - // Ensure that mat2 was ejected, but mat1, mat3, and mat4 were not. 
- EXPECT_EQ(prepacked_cache.FindAndUpdate(cache_key2), prepacked_cache.cend()); - EXPECT_NE(prepacked_cache.FindAndUpdate(cache_key3), prepacked_cache.cend()); - EXPECT_NE(prepacked_cache.FindAndUpdate(cache_key1), prepacked_cache.cend()); - EXPECT_NE(prepacked_cache.FindAndUpdate(cache_key4), prepacked_cache.cend()); -} - -TEST(PrepackedCacheTest, TestCacheOnCacheable) { - // Create context and set the cache policy - ruy::Context context; - context.cache_policy = ruy::kCacheLHSOnNarrowMul; - PrepackedCache* cache = context.GetPrepackedCache(); - EXPECT_EQ(cache->TotalSize(), 0); - - const float lhs_data[] = {1, 2, 3, 4}; - const float rhs_data[] = {1, 2}; - float dst_data[4]; - - ruy::Matrix lhs; - ruy::MakeSimpleLayout(2, 2, ruy::Order::kRowMajor, &lhs.layout); - lhs.data = lhs_data; - ruy::Matrix rhs; - ruy::MakeSimpleLayout(2, 1, ruy::Order::kColMajor, &rhs.layout); - rhs.data = rhs_data; - ruy::Matrix dst; - ruy::MakeSimpleLayout(2, 1, ruy::Order::kColMajor, &dst.layout); - dst.data = dst_data; - - ruy::BasicSpec spec; - // Perform the multiplication and confirm no caching occurred. - ruy::Mul(lhs, rhs, spec, &context, &dst); - EXPECT_EQ(cache->TotalSize(), 0); - - // Set cacheable for the LHS, repeat the multiplication, and see - // that caching did occur. - lhs.cacheable = true; - ruy::Mul(lhs, rhs, spec, &context, &dst); - EXPECT_NE(cache->TotalSize(), 0); -} - -TEST(PrepackedCacheTest, TestClearCache) { - // Create context and set the cache policy - ruy::Context context; - context.cache_policy = ruy::kCacheLHSOnNarrowMul; - PrepackedCache* cache = context.GetPrepackedCache(); - EXPECT_EQ(cache->TotalSize(), 0); - - const float lhs_data[] = {1, 2, 3, 4}; - const float rhs_data[] = {1, 2}; - float dst_data[4]; - - ruy::Matrix lhs; - ruy::MakeSimpleLayout(2, 2, ruy::Order::kRowMajor, &lhs.layout); - lhs.data = lhs_data; - ruy::Matrix rhs; - ruy::MakeSimpleLayout(2, 1, ruy::Order::kColMajor, &rhs.layout); - rhs.data = rhs_data; - ruy::Matrix dst; - ruy::MakeSimpleLayout(2, 1, ruy::Order::kColMajor, &dst.layout); - dst.data = dst_data; - - ruy::BasicSpec spec; - // Set cacheable for the LHS and see that caching occurs. - lhs.cacheable = true; - ruy::Mul(lhs, rhs, spec, &context, &dst); - EXPECT_NE(cache->TotalSize(), 0); - - // Clear the cache via the Context. - context.ClearPrepackedCache(); - // Verify that the cache is now empty. - cache = context.GetPrepackedCache(); - EXPECT_EQ(cache->TotalSize(), 0); -} - -} // namespace -} // namespace ruy - -int main(int argc, char** argv) { - ::testing::InitGoogleTest(&argc, argv); - return RUN_ALL_TESTS(); -} diff --git a/tensorflow/lite/experimental/ruy/ruy/profiler/BUILD b/tensorflow/lite/experimental/ruy/ruy/profiler/BUILD deleted file mode 100644 index 5e9d9bd3bae..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/profiler/BUILD +++ /dev/null @@ -1,60 +0,0 @@ -# A minimalistic profiler sampling pseudo-stacks - -package( - default_visibility = ["//visibility:public"], - licenses = ["notice"], # Apache 2.0 -) - -config_setting( - name = "ruy_profiler", - define_values = {"ruy_profiler": "true"}, -) - -# Used to build TFLite Micro RUY dependency for embedded targets outside of the -# RUY source tree. 
-filegroup( - name = "ruy_instrumentation_header", - srcs = ["instrumentation.h"], - visibility = ["//visibility:public"], -) - -cc_library( - name = "instrumentation", - srcs = ["instrumentation.cc"], - hdrs = ["instrumentation.h"], - defines = select({ - ":ruy_profiler": ["RUY_PROFILER"], - "//conditions:default": [], - }), -) - -cc_library( - name = "profiler", - srcs = [ - "profiler.cc", - "treeview.cc", - ], - hdrs = [ - "profiler.h", - "treeview.h", - ], - deps = [":instrumentation"], -) - -cc_library( - name = "test_instrumented_library", - testonly = True, - srcs = ["test_instrumented_library.cc"], - hdrs = ["test_instrumented_library.h"], - deps = [":instrumentation"], -) - -cc_test( - name = "test", - srcs = ["test.cc"], - deps = [ - ":profiler", - ":test_instrumented_library", - "@com_google_googletest//:gtest", - ], -) diff --git a/tensorflow/lite/experimental/ruy/ruy/profiler/README.md b/tensorflow/lite/experimental/ruy/ruy/profiler/README.md deleted file mode 100644 index 8d7902566b3..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/profiler/README.md +++ /dev/null @@ -1,149 +0,0 @@ -# A minimalistic profiler sampling pseudo-stacks - -## Overview - -The present directory is the "ruy profiler". As a time profiler, it allows to -measure where code is spending time. - -Contrary to most typical profilers, what it samples is not real call stacks, but -"pseudo-stacks" which are just simple data structures constructed from within -the program being profiled. Using this profiler requires manually instrumenting -code to construct such pseudo-stack information. - -Another unusual characteristic of this profiler is that it uses only the C++11 -standard library. It does not use any non-portable feature, in particular it -does not rely on signal handlers. The sampling is performed by a thread, the -"profiler thread". - -A discussion of pros/cons of this approach is appended below. - -## How to use this profiler - -### How to instrument code - -An example of instrumented code is given in `test_instrumented_library.cc`. - -Code is instrumented by constructing `ScopeLabel` objects. These are RAII -helpers, ensuring that the thread pseudo-stack contains the label during their -lifetime. In the most common use case, one would construct such an object at the -start of a function, so that its scope is the function scope and it allows to -measure how much time is spent in this function. - -```c++ -#include "ruy/profiler/instrumentation.h" - -... - -void SomeFunction() { - ruy::profiling::ScopeLabel function_label("SomeFunction"); - ... do something ... -} -``` - -A `ScopeLabel` may however have any scope, for instance: - -```c++ -if (some_case) { - ruy::profiling::ScopeLabel extra_work_label("Some more work"); - ... do some more work ... -} -``` - -The string passed to the `ScopeLabel` constructor must be just a pointer to a -literal string (a `char*` pointer). The profiler will assume that these pointers -stay valid until the profile is finalized. - -However, that literal string may be a `printf` format string, and labels may -have up to 4 parameters, of type `int`. For example: - -```c++ -void SomeFunction(int size) { - ruy::profiling::ScopeLabel function_label("SomeFunction (size=%d)", size); - -``` - -### How to run the profiler - -Profiling instrumentation is a no-op unless the preprocessor token -`RUY_PROFILER` is defined, so defining it is the first step when actually -profiling. 
When building with Bazel, the preferred way to enable that is to pass -this flag on the Bazel command line: - -``` ---define=ruy_profiler=true -``` - -To actually profile a code scope, it is enough to construct a `ScopeProfile` -object, also a RAII helper. It will start the profiler on construction, and on -destruction it will terminate the profiler and report the profile treeview on -standard output by default. Example: - -```c++ -void SomeProfiledBenchmark() { - ruy::profiling::ScopeProfile profile; - - CallSomeInstrumentedCode(); -} -``` - -An example is provided by the `:test` target in the present directory. Run it -with `--define=ruy_profiler=true` as explained above: - -``` -bazel run -c opt \ - --define=ruy_profiler=true \ - //tensorflow/lite/experimental/ruy/profiler:test -``` - -The default behavior dumping the treeview on standard output may be overridden -by passing a pointer to a `TreeView` object to the `ScopeProfile` constructor. -This causes the tree-view to be stored in that `TreeView` object, where it may -be accessed an manipulated using the functions declared in `treeview.h`. The -aforementioned `:test` provides examples for doing so. - -## Advantages and inconvenients - -Compared to a traditional profiler, e.g. Linux's "perf", the present kind of -profiler has the following inconvenients: - -* Requires manual instrumentation of code being profiled. -* Substantial overhead, modifying the performance characteristics of the code - being measured. -* Questionable accuracy. - -But also the following advantages: - -* Profiling can be driven from within a benchmark program, allowing the entire - profiling procedure to be a single command line. -* Not relying on symbol information removes removes exposure to toolchain - details and means less hassle in some build environments, especially - embedded/mobile (single command line to run and profile, no symbols files - required). -* Fully portable (all of this is standard C++11). -* Fully testable (see `:test`). Profiling becomes just another feature of the - code like any other. -* Customized instrumentation can result in easier to read treeviews (only - relevant functions, and custom labels may be more readable than function - names). -* Parametrized/formatted labels allow to do things that aren't possible with - call-stack-sampling profilers. For example, break down a profile where much - time is being spent in matrix multiplications, by the various matrix - multiplication shapes involved. - -The philosophy underlying this profiler is that software performance depends on -software engineers profiling often, and a key factor limiting that in practice -is the difficulty or cumbersome aspects of profiling with more serious profilers -such as Linux's "perf", especially in embedded/mobile development: multiple -command lines are involved to copy symbol files to devices, retrieve profile -data from the device, etc. In that context, it is useful to make profiling as -easy as benchmarking, even on embedded targets, even if the price to pay for -that is lower accuracy, higher overhead, and some intrusive instrumentation -requirement. - -Another key aspect determining what profiling approach is suitable for a given -context, is whether one already has a-priori knowledge of where much of the time -is likely being spent. When one has such a-priori knowledge, it is feasible to -instrument the known possibly-critical code as per the present approach. 
On the -other hand, in situations where one doesn't have such a-priori knowledge, a real -profiler such as Linux's "perf" allows to right away get a profile of real -stacks, from just symbol information generated by the toolchain. diff --git a/tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.cc b/tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.cc deleted file mode 100644 index b7c330c04bd..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.cc +++ /dev/null @@ -1,130 +0,0 @@ -/* Copyright 2020 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" - -#ifdef RUY_PROFILER - -namespace ruy { -namespace profiler { - -void Label::operator=(const Label& other) { - format_ = other.format_; - args_count_ = other.args_count_; - for (int i = 0; i < args_count_; i++) { - args_[i] = other.args_[i]; - } -} - -bool Label::operator==(const Label& other) const { - if (std::string(format_) != std::string(other.format_)) { - return false; - } - if (args_count_ != other.args_count_) { - return false; - } - for (int i = 0; i < args_count_; i++) { - if (args_[i] != other.args_[i]) { - return false; - } - } - return true; -} - -std::string Label::Formatted() const { - static constexpr int kBufSize = 256; - char buf[kBufSize]; - if (args_count_ == 0) { - return format_; - } - if (args_count_ == 1) { - snprintf(buf, kBufSize, format_, args_[0]); - } else if (args_count_ == 2) { - snprintf(buf, kBufSize, format_, args_[0], args_[1]); - } else if (args_count_ == 3) { - snprintf(buf, kBufSize, format_, args_[0], args_[1], args_[2]); - } else if (args_count_ == 4) { - snprintf(buf, kBufSize, format_, args_[0], args_[1], args_[2], args_[3]); - } else { - abort(); - } - return buf; -} - -namespace detail { - -std::mutex* GlobalsMutex() { - static std::mutex mutex; - return &mutex; -} - -bool& GlobalIsProfilerRunning() { - static bool b; - return b; -} - -std::vector* GlobalAllThreadStacks() { - static std::vector all_stacks; - return &all_stacks; -} - -ThreadStack* ThreadLocalThreadStack() { - thread_local static ThreadStack thread_stack; - return &thread_stack; -} - -ThreadStack::ThreadStack() { - std::lock_guard lock(*GlobalsMutex()); - static std::uint32_t global_next_thread_stack_id = 0; - stack_.id = global_next_thread_stack_id++; - GlobalAllThreadStacks()->push_back(this); -} - -ThreadStack::~ThreadStack() { - std::lock_guard lock(*GlobalsMutex()); - std::vector* all_stacks = GlobalAllThreadStacks(); - for (auto it = all_stacks->begin(); it != all_stacks->end(); ++it) { - if (*it == this) { - all_stacks->erase(it); - return; - } - } -} -int GetBufferSize(const Stack& stack) { - return sizeof(stack.id) + sizeof(stack.size) + - stack.size * sizeof(stack.labels[0]); -} - -void CopyToBuffer(const Stack& stack, char* dst) { - memcpy(dst, &stack.id, sizeof(stack.id)); - dst += sizeof(stack.id); - 
memcpy(dst, &stack.size, sizeof(stack.size)); - dst += sizeof(stack.size); - memcpy(dst, stack.labels, stack.size * sizeof(stack.labels[0])); -} - -void ReadFromBuffer(const char* src, Stack* stack) { - memcpy(&stack->id, src, sizeof(stack->id)); - src += sizeof(stack->id); - memcpy(&stack->size, src, sizeof(stack->size)); - src += sizeof(stack->size); - memcpy(stack->labels, src, stack->size * sizeof(stack->labels[0])); -} - -} // namespace detail -} // namespace profiler -} // namespace ruy - -#endif diff --git a/tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h b/tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h deleted file mode 100644 index a9046d465af..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h +++ /dev/null @@ -1,203 +0,0 @@ -/* Copyright 2020 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PROFILER_INSTRUMENTATION_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PROFILER_INSTRUMENTATION_H_ - -#ifdef RUY_PROFILER -#include -#include -#include -#endif - -namespace ruy { -namespace profiler { - -#ifdef RUY_PROFILER - -// A label is how a code scope is annotated to appear in profiles. -// The stacks that are sampled by the profiler are stacks of such labels. -// A label consists of a literal string, plus optional integer arguments. -class Label { - public: - Label() {} - template - explicit Label(Args... args) { - Set(args...); - } - void Set(const char* format) { - format_ = format; - args_count_ = 0; - } - template - void Set(const char* format, Args... args) { - format_ = format; - args_count_ = sizeof...(args); - SetArgs(0, args...); - } - - void operator=(const Label& other); - - bool operator==(const Label& other) const; - - std::string Formatted() const; - const char* format() const { return format_; } - - private: - void SetArgs(int position, int arg0) { args_[position] = arg0; } - - template - void SetArgs(int position, int arg0, Args... args) { - SetArgs(position, arg0); - SetArgs(position + 1, args...); - } - - static constexpr int kMaxArgs = 4; - const char* format_ = nullptr; - int args_count_ = 0; - int args_[kMaxArgs]; -}; - -namespace detail { - -// Forward-declaration, see class ThreadStack below. -class ThreadStack; - -bool& GlobalIsProfilerRunning(); - -// Returns the global vector of pointers to all stacks, there being one stack -// per thread executing instrumented code. -std::vector* GlobalAllThreadStacks(); - -// Returns the mutex to be locked around any access to GlobalAllThreadStacks(). -std::mutex* GlobalsMutex(); - -// Returns the thread-local stack, specific to the current thread. -ThreadStack* ThreadLocalThreadStack(); - -// This 'stack' is what may be more appropriately called a 'pseudostack': -// It contains Label entries that are 'manually' entered by instrumentation -// code. It's unrelated to real call stacks. 
-struct Stack { - std::uint32_t id = 0; - static constexpr int kMaxSize = 64; - int size = 0; - Label labels[kMaxSize]; -}; - -// Returns the buffer byte size required by CopyToSample. -int GetBufferSize(const Stack& stack); - -// Copies this Stack into a byte buffer, called a 'sample'. -void CopyToBuffer(const Stack& stack, char* dst); - -// Populates this Stack from an existing sample buffer, typically -// produced by CopyToSample. -void ReadFromBuffer(const char* src, Stack* stack); - -// ThreadStack is meant to be used as a thread-local singleton, assigning to -// each thread a Stack object holding its pseudo-stack of profile labels, -// plus a mutex allowing to synchronize accesses to this pseudo-stack between -// this thread and a possible profiler thread sampling it. -class ThreadStack { - public: - ThreadStack(); - ~ThreadStack(); - - const Stack& stack() const { return stack_; } - - // Returns the mutex to lock around any access to this stack. Each stack is - // accessed by potentially two threads: the thread that it belongs to - // (which calls Push and Pop) and the profiler thread during profiling - // (which calls CopyToSample). - std::mutex& Mutex() const { return mutex_; } - - // Pushes a new label on the top of this Stack. - template - void Push(Args... args) { - // This mutex locking is needed to guard against race conditions as both - // the current thread and the profiler thread may be concurrently accessing - // this stack. In addition to that, this mutex locking also serves the other - // purpose of acting as a barrier (of compiler code reordering, of runtime - // CPU instruction reordering, and of memory access reordering), which - // gives a measure of correctness to this profiler. The downside is some - // latency. As this lock will be uncontended most of the times, the cost - // should be roughly that of an sequentially-consistent atomic access, - // comparable to an access to the level of CPU data cache that is shared - // among all cores, typically 60 cycles on current ARM CPUs, plus side - // effects from barrier instructions. - std::lock_guard lock(mutex_); - // Avoid overrunning the stack, even in 'release' builds. This profiling - // instrumentation code should not ship in release builds anyway, the - // overhead of this check is negligible, and overrunning a stack array would - // be bad. - if (stack_.size >= Stack::kMaxSize) { - abort(); - } - stack_.labels[stack_.size++].Set(args...); - } - - // Pops the top-most label from this Stack. - void Pop() { - // See the comment in Push about this lock. While it would be tempting to - // try to remove this lock and just atomically decrement size_ with a - // store-release, that would not necessarily be a substitute for all of the - // purposes that this lock serves, or if it was done carefully to serve all - // of the same purposes, then that wouldn't be faster than this (mostly - // uncontended) lock. - std::lock_guard lock(mutex_); - stack_.size--; - } - - private: - mutable std::mutex mutex_; - Stack stack_; -}; - -} // namespace detail - -// RAII user-facing way to construct Labels associated with their life scope -// and get them pushed to / popped from the current thread stack. -class ScopeLabel { - public: - template - ScopeLabel(Args... 
args) : thread_stack_(detail::ThreadLocalThreadStack()) { - thread_stack_->Push(args...); - } - - ~ScopeLabel() { thread_stack_->Pop(); } - - private: - detail::ThreadStack* thread_stack_; -}; - -#else // no RUY_PROFILER - -class ScopeLabel { - public: - template - explicit ScopeLabel(Args...) {} - - // This destructor is needed to consistently silence clang's -Wunused-variable - // which seems to trigger semi-randomly. - ~ScopeLabel() {} -}; - -#endif - -} // namespace profiler -} // namespace ruy - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PROFILER_INSTRUMENTATION_H_ diff --git a/tensorflow/lite/experimental/ruy/ruy/profiler/profiler.cc b/tensorflow/lite/experimental/ruy/ruy/profiler/profiler.cc deleted file mode 100644 index c5ff598ee2b..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/profiler/profiler.cc +++ /dev/null @@ -1,109 +0,0 @@ -/* Copyright 2020 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/lite/experimental/ruy/ruy/profiler/profiler.h" - -#ifdef RUY_PROFILER -#include -#include // NOLINT -#include -#include -#include // NOLINT -#include -#endif - -#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" -#include "tensorflow/lite/experimental/ruy/ruy/profiler/treeview.h" - -namespace ruy { -namespace profiler { - -#ifdef RUY_PROFILER - -ScopeProfile::ScopeProfile() { Start(); } -ScopeProfile::ScopeProfile(bool enable) { - if (enable) { - Start(); - } -} -ScopeProfile::~ScopeProfile() { - if (!thread_) { - return; - } - finishing_.store(true); - thread_->join(); - Finish(); -} - -void ScopeProfile::Start() { - { - std::lock_guard lock(*detail::GlobalsMutex()); - if (detail::GlobalIsProfilerRunning()) { - fprintf(stderr, "FATAL: profiler already running!\n"); - abort(); - } - detail::GlobalIsProfilerRunning() = true; - } - finishing_ = false; - thread_.reset(new std::thread(&ScopeProfile::ThreadFunc, this)); -} - -void ScopeProfile::ThreadFunc() { - while (!finishing_.load()) { - std::this_thread::sleep_for(std::chrono::milliseconds(1)); - std::lock_guard lock(*detail::GlobalsMutex()); - auto* thread_stacks = detail::GlobalAllThreadStacks(); - for (detail::ThreadStack* thread_stack : *thread_stacks) { - Sample(*thread_stack); - } - } -} - -void ScopeProfile::Sample(const detail::ThreadStack& thread_stack) { - std::lock_guard lock(thread_stack.Mutex()); - // Drop empty stacks. - // This ensures that profiles aren't polluted by uninteresting threads. 
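As a usage sketch (not part of the original sources), the RAII ScopeLabel above is meant to be dropped into the code being profiled; when RUY_PROFILER is not defined it compiles to a no-op, so the instrumentation can stay in place. The function name and label below are hypothetical:

    #include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h"

    // Hypothetical instrumented function, following the same pattern as
    // test_instrumented_library.cc further down in this patch.
    void ProcessChunk(int chunk_index, int chunk_size) {
      // Pushes a label on this thread's pseudo-stack; pops it on scope exit.
      ruy::profiler::ScopeLabel label("ProcessChunk (index=%d, size=%d)",
                                      chunk_index, chunk_size);
      // ... the actual work goes here; samples taken while this scope is
      // alive are attributed to the label above ...
    }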
- if (thread_stack.stack().size == 0) { - return; - } - int sample_size = detail::GetBufferSize(thread_stack.stack()); - int old_buf_size = samples_buf_.size(); - samples_buf_.resize(old_buf_size + sample_size); - detail::CopyToBuffer(thread_stack.stack(), - samples_buf_.data() + old_buf_size); -} - -void ScopeProfile::Finish() { - { - std::lock_guard lock(*detail::GlobalsMutex()); - if (!detail::GlobalIsProfilerRunning()) { - fprintf(stderr, "FATAL: profiler is not running!\n"); - abort(); - } - detail::GlobalIsProfilerRunning() = false; - } - if (user_treeview_) { - user_treeview_->Populate(samples_buf_); - } else { - TreeView treeview; - treeview.Populate(samples_buf_); - Print(treeview); - } -} - -#endif // RUY_PROFILER - -} // namespace profiler -} // namespace ruy diff --git a/tensorflow/lite/experimental/ruy/ruy/profiler/profiler.h b/tensorflow/lite/experimental/ruy/ruy/profiler/profiler.h deleted file mode 100644 index 19ef0deba0c..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/profiler/profiler.h +++ /dev/null @@ -1,106 +0,0 @@ -/* Copyright 2020 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PROFILER_PROFILER_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PROFILER_PROFILER_H_ - -#include - -#ifdef RUY_PROFILER -#include -#include -#include -#include -#endif - -#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" -#include "tensorflow/lite/experimental/ruy/ruy/profiler/treeview.h" - -namespace ruy { -namespace profiler { - -#ifdef RUY_PROFILER - -// RAII user-facing way to create a profiler and let it profile a code scope, -// and print out an ASCII/MarkDown treeview upon leaving the scope. -class ScopeProfile { - public: - // Default constructor, unconditionally profiling. - ScopeProfile(); - - // Constructor allowing to choose at runtime whether to profile. - explicit ScopeProfile(bool enable); - - // Destructor. It's where the profile is reported. - ~ScopeProfile(); - - // See treeview_ member. - void SetUserTreeView(TreeView* treeview) { user_treeview_ = treeview; } - - private: - void Start(); - - // Thread entry point function for the profiler thread. This thread is - // created on construction. - void ThreadFunc(); - - // Record a stack as a sample. - void Sample(const detail::ThreadStack& stack); - - // Finalize the profile. Called on destruction. - // If user_treeview_ is non-null, it will receive the treeview. - // Otherwise the treeview will just be printed. - void Finish(); - - // Buffer where samples are recorded during profiling. - std::vector samples_buf_; - - // Used to synchronize thread termination. - std::atomic finishing_; - - // Underlying profiler thread, which will perform the sampling. - // This profiler approach relies on a thread rather than on signals. - std::unique_ptr thread_; - - // TreeView to populate upon destruction. 
If left null (the default), - // a temporary treeview will be used and dumped on stdout. The user - // may override that by passing their own TreeView object for other - // output options or to directly inspect the TreeView. - TreeView* user_treeview_ = nullptr; -}; - -#else // no RUY_PROFILER - -struct ScopeProfile { - ScopeProfile() { -#ifdef GEMMLOWP_PROFILING - fprintf( - stderr, - "\n\n\n**********\n\nWARNING:\n\nLooks like you defined " - "GEMMLOWP_PROFILING, but this code has been ported to the new ruy " - "profiler replacing the old gemmlowp profiler. You should now be " - "defining RUY_PROFILER and not GEMMLOWP_PROFILING. When building using " - "Bazel, just pass --define=ruy_profiler=true.\n\n**********\n\n\n"); -#endif - } - explicit ScopeProfile(bool) {} -}; - -#endif - -} // namespace profiler -} // namespace ruy - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PROFILER_PROFILER_H_ diff --git a/tensorflow/lite/experimental/ruy/ruy/profiler/test.cc b/tensorflow/lite/experimental/ruy/ruy/profiler/test.cc deleted file mode 100644 index feab967c87c..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/profiler/test.cc +++ /dev/null @@ -1,167 +0,0 @@ -/* Copyright 2020 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include -#include - -#include -#include "tensorflow/lite/experimental/ruy/ruy/profiler/profiler.h" -#include "tensorflow/lite/experimental/ruy/ruy/profiler/test_instrumented_library.h" -#include "tensorflow/lite/experimental/ruy/ruy/profiler/treeview.h" - -namespace ruy { -namespace profiler { -namespace { - -void DoSomeMergeSort(int size) { - std::vector data(size); - - std::default_random_engine engine; - for (auto& val : data) { - val = engine(); - } - - MergeSort(size, data.data()); -} - -// The purpose of this basic test is to cover the basic path that will be taken -// by a majority of users, not inspecting treeviews but just implicitly printing -// them on stdout, and to have this test enabled even when RUY_PROFILER is not -// defined, so that we have coverage for the non-RUY_PROFILER case. 
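For reference, a minimal end-to-end sketch of the ScopeProfile usage documented above (not part of the original files; MyInstrumentedWork stands in for any code annotated with ScopeLabel):

    #include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h"
    #include "tensorflow/lite/experimental/ruy/ruy/profiler/profiler.h"

    void MyInstrumentedWork() {
      ruy::profiler::ScopeLabel label("MyInstrumentedWork");
      // ... work to be profiled ...
    }

    int main() {
      {
        // Starts the sampling thread when RUY_PROFILER is defined; on scope
        // exit it stops sampling and prints a treeview of the samples.
        ruy::profiler::ScopeProfile profile;
        MyInstrumentedWork();
      }
      // To inspect the result programmatically instead of printing it, pass
      // a TreeView via profile.SetUserTreeView(&treeview) before the work.
      return 0;
    }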
-TEST(ProfilerTest, MergeSortSingleThreadBasicTestEvenWithoutProfiler) { - { - ScopeProfile profile; - DoSomeMergeSort(1 << 20); - } -} - -#ifdef RUY_PROFILER - -TEST(ProfilerTest, MergeSortSingleThread) { - TreeView treeview; - { - ScopeProfile profile; - profile.SetUserTreeView(&treeview); - DoSomeMergeSort(1 << 20); - } - Print(treeview); - EXPECT_EQ(treeview.thread_roots().size(), 1); - const auto& thread_root = *treeview.thread_roots().begin()->second; - EXPECT_EQ(DepthOfTreeBelow(thread_root), 22); - EXPECT_GE( - WeightBelowNodeMatchingUnformatted(thread_root, "Merging sorted halves"), - 0.1 * thread_root.weight); - EXPECT_GE(WeightBelowNodeMatchingFormatted( - thread_root, "MergeSortRecurse (level=20, size=1)"), - 0.01 * thread_root.weight); - - TreeView treeview_collapsed; - CollapseNodesMatchingUnformatted(treeview, 5, "MergeSort (size=%d)", - &treeview_collapsed); - Print(treeview_collapsed); - const auto& collapsed_thread_root = - *treeview_collapsed.thread_roots().begin()->second; - EXPECT_EQ(DepthOfTreeBelow(collapsed_thread_root), 6); - EXPECT_EQ( - WeightBelowNodeMatchingUnformatted(thread_root, "MergeSort (size=%d)"), - WeightBelowNodeMatchingUnformatted(collapsed_thread_root, - "MergeSort (size=%d)")); -} - -TEST(ProfilerTest, MemcpyFourThreads) { - TreeView treeview; - { - ScopeProfile profile; - profile.SetUserTreeView(&treeview); - std::vector> threads; - for (int i = 0; i < 4; i++) { - threads.emplace_back(new std::thread([i]() { - ScopeLabel thread_label("worker thread #%d", i); - std::this_thread::sleep_for(std::chrono::milliseconds(1000)); - ScopeLabel some_more_work_label("some more work"); - std::this_thread::sleep_for(std::chrono::milliseconds(1000)); - })); - } - for (int i = 0; i < 4; i++) { - threads[i]->join(); - } - } - Print(treeview); - // Since we cleared GlobalAllThreadStacks and the current thread hasn't - // created any ScopeLabel, only the 4 worker threads should be recorded. - EXPECT_EQ(treeview.thread_roots().size(), 4); - for (const auto& thread_root : treeview.thread_roots()) { - const TreeView::Node& root_node = *thread_root.second; - // The root node may have 1 or 2 children depending on whether there is - // an "[other]" child. - EXPECT_GE(root_node.children.size(), 1); - EXPECT_LE(root_node.children.size(), 2); - const TreeView::Node& child_node = *root_node.children[0]; - EXPECT_EQ(child_node.label.format(), "worker thread #%d"); - // There must be 2 children, since roughly half the time will be in - // "some more work" leaving the other half in "[other]". - EXPECT_EQ(child_node.children.size(), 2); - const TreeView::Node& child_child_node = *child_node.children[0]; - // Since we sample every millisecond and the threads run for >= 2000 - // milliseconds, the "thread func" label should get roughly 2000 samples. - // Not very rigorous, as we're depending on the profiler thread getting - // scheduled, so to avoid this test being flaky, we use a much more - // conservative value of 500, one quarter of that normal value 2000. - EXPECT_GE(child_node.weight, 500); - // Likewise, allow up to four times more than the normal value 2000. - EXPECT_LE(child_node.weight, 8000); - // Roughly half of time should be spent under the "some more work" label. 
- float some_more_work_percentage = - 100.f * child_child_node.weight / child_node.weight; - EXPECT_GE(some_more_work_percentage, 40.0f); - EXPECT_LE(some_more_work_percentage, 60.0f); - } -} - -TEST(ProfilerTest, OneThreadAfterAnother) { - TreeView treeview; - { - ScopeProfile profile; - profile.SetUserTreeView(&treeview); - { - std::thread thread([]() { - ScopeLabel thread_label("thread 0"); - std::this_thread::sleep_for(std::chrono::milliseconds(1000)); - }); - thread.join(); - } - { - std::thread thread([]() { - ScopeLabel thread_label("thread 1"); - std::this_thread::sleep_for(std::chrono::milliseconds(1000)); - }); - thread.join(); - } - } - Print(treeview); - EXPECT_EQ(treeview.thread_roots().size(), 2); -} - -#endif // RUY_PROFILER - -} // namespace -} // namespace profiler -} // namespace ruy - -int main(int argc, char** argv) { - ::testing::InitGoogleTest(&argc, argv); - return RUN_ALL_TESTS(); -} diff --git a/tensorflow/lite/experimental/ruy/ruy/profiler/test_instrumented_library.cc b/tensorflow/lite/experimental/ruy/ruy/profiler/test_instrumented_library.cc deleted file mode 100644 index e9b5929c9b7..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/profiler/test_instrumented_library.cc +++ /dev/null @@ -1,59 +0,0 @@ -/* Copyright 2020 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include - -#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" - -namespace { - -void MergeSortRecurse(int level, int size, int* data, int* workspace) { - ruy::profiler::ScopeLabel function_label( - "MergeSortRecurse (level=%d, size=%d)", level, size); - if (size <= 1) { - return; - } - int half_size = size / 2; - MergeSortRecurse(level + 1, half_size, data, workspace); - MergeSortRecurse(level + 1, size - half_size, data + half_size, - workspace + half_size); - - ruy::profiler::ScopeLabel merging_sorted_halves_label( - "Merging sorted halves"); - int dst_index = 0; - int left_index = 0; - int right_index = half_size; - while (dst_index < size) { - int val; - if (left_index < half_size && - ((right_index >= size) || data[left_index] < data[right_index])) { - val = data[left_index++]; - } else { - val = data[right_index++]; - } - workspace[dst_index++] = val; - } - for (int i = 0; i < size; i++) { - data[i] = workspace[i]; - } -} - -} // namespace - -void MergeSort(int size, int* data) { - ruy::profiler::ScopeLabel function_label("MergeSort (size=%d)", size); - std::vector workspace(size); - MergeSortRecurse(0, size, data, workspace.data()); -} diff --git a/tensorflow/lite/experimental/ruy/ruy/profiler/test_instrumented_library.h b/tensorflow/lite/experimental/ruy/ruy/profiler/test_instrumented_library.h deleted file mode 100644 index d6a80a09042..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/profiler/test_instrumented_library.h +++ /dev/null @@ -1,23 +0,0 @@ -/* Copyright 2020 Google LLC. All Rights Reserved. 
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PROFILER_TEST_INSTRUMENTED_LIBRARY_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PROFILER_TEST_INSTRUMENTED_LIBRARY_H_ - -#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" - -void MergeSort(int size, int* data); - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PROFILER_TEST_INSTRUMENTED_LIBRARY_H_ diff --git a/tensorflow/lite/experimental/ruy/ruy/profiler/treeview.cc b/tensorflow/lite/experimental/ruy/ruy/profiler/treeview.cc deleted file mode 100644 index 256d2a1106c..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/profiler/treeview.cc +++ /dev/null @@ -1,248 +0,0 @@ -/* Copyright 2020 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifdef RUY_PROFILER - -#include "tensorflow/lite/experimental/ruy/ruy/profiler/treeview.h" - -#include -#include -#include -#include -#include - -namespace ruy { -namespace profiler { - -namespace { - -void SortNode(TreeView::Node* node) { - using NodePtr = std::unique_ptr; - std::sort(node->children.begin(), node->children.end(), - [](const NodePtr& n1, const NodePtr& n2) { - return n1->weight > n2->weight; - }); - for (const auto& child : node->children) { - SortNode(child.get()); - } -} - -// Records a stack i.e. a sample in a treeview, by incrementing the weights -// of matching existing nodes and/or by creating new nodes as needed, -// recursively, below the given node. -void AddStack(const detail::Stack& stack, TreeView::Node* node, int level) { - node->weight++; - if (stack.size == level) { - return; - } - TreeView::Node* child_to_add_to = nullptr; - for (const auto& child : node->children) { - if (child->label == stack.labels[level]) { - child_to_add_to = child.get(); - break; - } - } - if (!child_to_add_to) { - child_to_add_to = node->children.emplace_back(new TreeView::Node).get(); - child_to_add_to->label = stack.labels[level]; - } - AddStack(stack, child_to_add_to, level + 1); -} - -// Recursively populates the treeview below the given node with 'other' -// entries documenting for each node the difference between its weight and the -// sum of its children's weight. 
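As a concrete illustration of that bookkeeping (numbers invented for the example): a node with weight 100 whose children have weights 60 and 25 gets an "[other]" child of weight 100 - (60 + 25) = 15, so that the children's weights again add up to the parent's weight; if the children already account for the full 100, no "[other]" child is added.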
-void AddOther(TreeView::Node* node) { - int top_level_children_weight = 0; - for (const auto& child : node->children) { - AddOther(child.get()); - top_level_children_weight += child->weight; - } - if (top_level_children_weight != 0 && - top_level_children_weight != node->weight) { - const auto& new_child = node->children.emplace_back(new TreeView::Node); - new_child->label = Label("[other]"); - new_child->weight = node->weight - top_level_children_weight; - } -} - -} // namespace - -void TreeView::Populate(const std::vector& samples_buf_) { - thread_roots_.clear(); - // Populate the treeview with regular nodes coming from samples. - const char* buf_ptr = samples_buf_.data(); - const char* const buf_ptr_end = buf_ptr + samples_buf_.size(); - while (buf_ptr < buf_ptr_end) { - detail::Stack stack; - detail::ReadFromBuffer(buf_ptr, &stack); - // Empty stacks should have been dropped during sampling. - assert(stack.size > 0); - buf_ptr += GetBufferSize(stack); - const int id = stack.id; - if (!thread_roots_[id]) { - thread_roots_[id].reset(new Node); - } - AddStack(stack, thread_roots_[id].get(), 0); - } - // Populate the treeview with additional 'other' nodes, sort, and set - // root labels. - for (const auto& thread_root : thread_roots_) { - std::uint32_t id = thread_root.first; - Node* root = thread_root.second.get(); - AddOther(root); - SortNode(root); - root->label.Set("Thread %x (%d samples)", id, root->weight); - } -} - -// Recursively prints the treeview below the given node. The 'root' node -// argument is only needed to compute weights ratios, with the root ratio -// as denominator. -void PrintTreeBelow(const TreeView::Node& node, const TreeView::Node& root, - int level) { - if (&node == &root) { - printf("%s\n\n", node.label.Formatted().c_str()); - } else { - for (int i = 1; i < level; i++) { - printf(" "); - } - printf("* %.2f%% %s\n", 100.0f * node.weight / root.weight, - node.label.Formatted().c_str()); - } - for (const auto& child : node.children) { - PrintTreeBelow(*child, root, level + 1); - } -} - -void Print(const TreeView& treeview) { - printf("\n"); - printf("Profile (%d threads):\n\n", - static_cast(treeview.thread_roots().size())); - for (const auto& thread_root : treeview.thread_roots()) { - const TreeView::Node& root = *thread_root.second; - PrintTreeBelow(root, root, 0); - printf("\n"); - } -} - -int DepthOfTreeBelow(const TreeView::Node& node) { - if (node.children.empty()) { - return 0; - } else { - int max_child_depth = 0; - for (const auto& child : node.children) { - max_child_depth = std::max(max_child_depth, DepthOfTreeBelow(*child)); - } - return 1 + max_child_depth; - } -} - -int WeightBelowNodeMatchingFunction( - const TreeView::Node& node, - const std::function& match) { - int weight = 0; - if (match(node.label)) { - weight += node.weight; - } - for (const auto& child : node.children) { - weight += WeightBelowNodeMatchingFunction(*child, match); - } - return weight; -} - -int WeightBelowNodeMatchingUnformatted(const TreeView::Node& node, - const std::string& format) { - return WeightBelowNodeMatchingFunction( - node, [&format](const Label& label) { return label.format() == format; }); -} - -int WeightBelowNodeMatchingFormatted(const TreeView::Node& node, - const std::string& formatted) { - return WeightBelowNodeMatchingFunction( - node, [&formatted](const Label& label) { - return label.Formatted() == formatted; - }); -} - -void CollapseNode(const TreeView::Node& node_in, int depth, - TreeView::Node* node_out) { - node_out->label = node_in.label; - 
node_out->weight = node_in.weight; - node_out->children.clear(); - if (depth > 0) { - for (const auto& child_in : node_in.children) { - auto* child_out = new TreeView::Node; - node_out->children.emplace_back(child_out); - CollapseNode(*child_in, depth - 1, child_out); - } - } -} - -void CollapseSubnodesMatchingFunction( - const TreeView::Node& node_in, int depth, - const std::function& match, TreeView::Node* node_out) { - if (match(node_in.label)) { - CollapseNode(node_in, depth, node_out); - } else { - node_out->label = node_in.label; - node_out->weight = node_in.weight; - node_out->children.clear(); - - for (const auto& child_in : node_in.children) { - auto* child_out = new TreeView::Node; - node_out->children.emplace_back(child_out); - CollapseSubnodesMatchingFunction(*child_in, depth, match, child_out); - } - } -} - -void CollapseNodesMatchingFunction( - const TreeView& treeview_in, int depth, - const std::function& match, TreeView* treeview_out) { - treeview_out->mutable_thread_roots()->clear(); - for (const auto& thread_root_in : treeview_in.thread_roots()) { - std::uint32_t id = thread_root_in.first; - const auto& root_in = *thread_root_in.second; - auto* root_out = new TreeView::Node; - treeview_out->mutable_thread_roots()->emplace(id, root_out); - CollapseSubnodesMatchingFunction(root_in, depth, match, root_out); - } -} - -void CollapseNodesMatchingUnformatted(const TreeView& treeview_in, int depth, - const std::string& format, - TreeView* treeview_out) { - CollapseNodesMatchingFunction( - treeview_in, depth, - [&format](const Label& label) { return label.format() == format; }, - treeview_out); -} - -void CollapseNodesMatchingFormatted(const TreeView& treeview_in, int depth, - const std::string& formatted, - TreeView* treeview_out) { - CollapseNodesMatchingFunction( - treeview_in, depth, - [&formatted](const Label& label) { - return label.Formatted() == formatted; - }, - treeview_out); -} - -} // namespace profiler -} // namespace ruy - -#endif // RUY_PROFILER diff --git a/tensorflow/lite/experimental/ruy/ruy/profiler/treeview.h b/tensorflow/lite/experimental/ruy/ruy/profiler/treeview.h deleted file mode 100644 index 7f48af5ece0..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/profiler/treeview.h +++ /dev/null @@ -1,130 +0,0 @@ -/* Copyright 2020 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PROFILER_TREEVIEW_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PROFILER_TREEVIEW_H_ - -#ifdef RUY_PROFILER - -#include -#include -#include -#include - -#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" - -namespace ruy { -namespace profiler { - -// A tree view of a profile. 
-class TreeView { - public: - struct Node { - std::vector> children; - Label label; - int weight = 0; - }; - - void Populate(const std::vector& samples_buf_); - - // Intentionally an *ordered* map so that threads are enumerated - // in an order that's consistent and typically putting the 'main thread' - // first. - using ThreadRootsMap = std::map>; - - const ThreadRootsMap& thread_roots() const { return thread_roots_; } - ThreadRootsMap* mutable_thread_roots() { return &thread_roots_; } - - private: - ThreadRootsMap thread_roots_; -}; - -/* Below are API functions for manipulating and printing treeviews. */ - -// Prints the treeview to stdout. -void Print(const TreeView& treeview); - -// Prints the treeview below the given node on stdout. -void PrintTreeBelow(const TreeView::Node& node); - -// Returns the tree depth below the given node. -int DepthOfTreeBelow(const TreeView::Node& node); - -// Returns the sum of weights of nodes below the given node and filtered by -// the `match` predicate. -int WeightBelowNodeMatchingFunction( - const TreeView::Node& node, const std::function& match); - -// Returns the sum of weights of nodes below the given node and whose -// unformatted label (i.e. raw format string) matches the given `format` string. -// -// This allows to aggregate nodes whose labels differ only by parameter values. -int WeightBelowNodeMatchingUnformatted(const TreeView::Node& node, - const std::string& format); - -// Returns the sum of weights of nodes below the given node and whose formatted -// label matches the `formatted` string. -// -// In the case of nodes with parametrized labels, this allows to count only -// nodes with specific parameter values. For that purpose, one may also instead -// use WeightBelowNodeMatchingFunction directly, with a `match` predicate -// comparing raw integer parameter values directly, instead of going through -// formatted strings. -int WeightBelowNodeMatchingFormatted(const TreeView::Node& node, - const std::string& formatted); - -// Produces a `node_out` that is a copy of `node_in` but with tree depth below -// it clamped at `depth`, with further subtrees aggregated into single leaf -// nodes. -void CollapseNode(const TreeView::Node& node_in, int depth, - TreeView::Node* node_out); - -// Calls CollapseNode with the given `depth` on every subnode filtered by the -// `match` predicate. Note that this does NOT limit the tree depth below -// `node_out` to `depth`, since each collapsed node below `node_out` may be -// arbitrarily far below it and `depth` is only used as the collapsing depth -// at that point. -void CollapseSubnodesMatchingFunction( - const TreeView::Node& node_in, int depth, - const std::function& match, TreeView::Node* node_out); - -// Calls CollapseNode with the given `depth` on every node filtered by the -// `match` predicate. Note that this does NOT limit the tree depth below -// `node_out` to `depth`, since each collapsed node below `node_out` may be -// arbitrarily far below it and `depth` is only used as the collapsing depth -// at that point. -void CollapseNodesMatchingFunction( - const TreeView& treeview_in, int depth, - const std::function& match, TreeView* treeview_out); - -// Special case of CollapseNodesMatchingFunction matching unformatted labels, -// i.e. raw format strings. -// See the comment on WeightBelowNodeMatchingUnformatted. 
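A short sketch of how these helpers are typically combined (not part of the original header; it assumes RUY_PROFILER is defined, a TreeView already populated via ScopeProfile::SetUserTreeView, and the same "MergeSort (size=%d)" label format used by the test code earlier in this patch):

    #include "tensorflow/lite/experimental/ruy/ruy/profiler/treeview.h"

    void SummarizeProfile(const ruy::profiler::TreeView& treeview) {
      for (const auto& thread_root : treeview.thread_roots()) {
        const ruy::profiler::TreeView::Node& root = *thread_root.second;
        // Samples spent under a parametrized label, aggregated across all
        // parameter values (matches on the raw format string).
        const int weight = ruy::profiler::WeightBelowNodeMatchingUnformatted(
            root, "MergeSort (size=%d)");
        (void)weight;
      }
      // Collapse recursion so that at most 3 levels are kept below each node
      // matching the format string, then print the collapsed treeview.
      ruy::profiler::TreeView collapsed;
      ruy::profiler::CollapseNodesMatchingUnformatted(
          treeview, 3, "MergeSort (size=%d)", &collapsed);
      ruy::profiler::Print(collapsed);
    }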
-void CollapseNodesMatchingUnformatted(const TreeView& treeview_in, int depth, - const std::string& format, - TreeView* treeview_out); - -// Special case of CollapseNodesMatchingFunction matching formatted labels. -// See the comment on WeightBelowNodeMatchingFormatted. -void CollapseNodesMatchingFormatted(const TreeView& treeview_in, int depth, - const std::string& formatted, - TreeView* treeview_out); - -} // namespace profiler -} // namespace ruy - -#endif // RUY_PROFILER - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PROFILER_TREEVIEW_H_ diff --git a/tensorflow/lite/experimental/ruy/ruy/ruy.h b/tensorflow/lite/experimental/ruy/ruy/ruy.h deleted file mode 100644 index 783c410cf82..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/ruy.h +++ /dev/null @@ -1,42 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// This is the only Ruy header that users should #include. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_RUY_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_RUY_H_ - -#include "tensorflow/lite/experimental/ruy/ruy/context.h" -#include "tensorflow/lite/experimental/ruy/ruy/dispatch.h" -#include "tensorflow/lite/experimental/ruy/ruy/matrix.h" -#include "tensorflow/lite/experimental/ruy/ruy/path.h" -#include "tensorflow/lite/experimental/ruy/ruy/spec.h" - -namespace ruy { - -// Performs a multiplication of matrices. This is Ruy's only API entry point. -// Should be self-explanatory given the above documentation for each of Matrix, -// Spec and Context. -template -void Mul(const Matrix& lhs, const Matrix& rhs, - const Spec& spec, Context* context, Matrix* dst) { - DispatchMul( - lhs, rhs, spec, context, dst); -} - -} // namespace ruy - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_RUY_H_ diff --git a/tensorflow/lite/experimental/ruy/ruy/ruy_advanced.h b/tensorflow/lite/experimental/ruy/ruy/ruy_advanced.h deleted file mode 100644 index 0b24636ef06..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/ruy_advanced.h +++ /dev/null @@ -1,69 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_RUY_ADVANCED_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_RUY_ADVANCED_H_ - -#include -#include - -#include "tensorflow/lite/experimental/ruy/ruy/context.h" -#include "tensorflow/lite/experimental/ruy/ruy/matrix.h" -#include "tensorflow/lite/experimental/ruy/ruy/path.h" -#include "tensorflow/lite/experimental/ruy/ruy/prepack.h" -#include "tensorflow/lite/experimental/ruy/ruy/side_pair.h" - -namespace ruy { - -// Low-level, explicit pre-packing API. -// -// The cost of packing an input matrix (either the LHS or RHS) is amortized -// across the non-depth dimension of the opposite input matrix. Thus, when the -// LHS has very few rows or the RHS has very few columns, the cost of packing -// the opposite input matrix can become significant. See pack.h for further -// information on packing. -// -// This file provides an API allowing a user to explicitly pack a matrix and -// reuse the pre-packed matrix, avoiding that cost. -// -// See example_prepack.cc for example usage. - -template -void PrePackForMul(const Matrix& lhs, const Matrix& rhs, - const Spec& spec, Context* context, Matrix* dst, - PrepackedMatrix* prepacked_lhs, - PrepackedMatrix* prepacked_rhs, - std::function alloc_fn) { - SidePair prepacked(prepacked_lhs, prepacked_rhs); - PrePackForMulInternal(lhs, rhs, spec, context, dst, prepacked, - alloc_fn); -} - -template -void MulWithPrepacked(const Matrix& lhs, - const Matrix& rhs, const Spec& spec, - Context* context, Matrix* dst, - PrepackedMatrix* prepacked_lhs, - PrepackedMatrix* prepacked_rhs) { - SidePair prepacked(prepacked_lhs, prepacked_rhs); - MulWithPrepackedInternal(lhs, rhs, spec, context, dst, - prepacked); -} - -} // namespace ruy - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_RUY_ADVANCED_H_ diff --git a/tensorflow/lite/experimental/ruy/ruy/ruy_test.bzl b/tensorflow/lite/experimental/ruy/ruy/ruy_test.bzl deleted file mode 100644 index ef7e8b1bb79..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/ruy_test.bzl +++ /dev/null @@ -1,34 +0,0 @@ -# Provides the ruy_test macro for type-parametrized tests. 
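For context, a minimal float multiplication through the Mul entry point declared in ruy.h above, along the lines of the removed example.cc (this sketch is not part of the original files; it assumes the usual ruy helpers of this version, MakeSimpleLayout from matrix.h, kAllPaths from path.h and BasicSpec from spec.h, none of which are quoted at this point in the patch):

    #include <cstdio>

    #include "tensorflow/lite/experimental/ruy/ruy/ruy.h"

    int main() {
      const float lhs_data[] = {1, 2, 3, 4};  // 2x2, row-major
      const float rhs_data[] = {1, 2, 3, 4};  // 2x2, column-major
      float dst_data[4] = {};

      ruy::Context context;

      ruy::Matrix<float> lhs;
      ruy::MakeSimpleLayout(2, 2, ruy::Order::kRowMajor, &lhs.layout);
      lhs.data = lhs_data;

      ruy::Matrix<float> rhs;
      ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, &rhs.layout);
      rhs.data = rhs_data;

      ruy::Matrix<float> dst;
      ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, &dst.layout);
      dst.data = dst_data;

      // Default-constructed BasicSpec: plain matrix multiplication, no bias,
      // float clamp bounds left at +/- infinity.
      ruy::BasicSpec<float, float> spec;
      ruy::Mul<ruy::kAllPaths>(lhs, rhs, spec, &context, &dst);

      // Expected result, column-major: {5, 11, 11, 25}.
      std::printf("dst[0] = %g\n", dst_data[0]);
      return 0;
    }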
-"""ruy_test is a macro for building a test with multiple paths corresponding to tuples of types for LHS, RHS, accumulator and destination.""" - -def ruy_test(name, srcs, lhs_rhs_accum_dst, copts, tags = [], deps = None): - for (lhs, rhs, accum, dst) in lhs_rhs_accum_dst: - native.cc_test( - name = "%s_%s_%s_%s_%s" % (name, lhs, rhs, accum, dst), - srcs = srcs, - copts = copts + [ - "-DRUY_TEST_LHSSCALAR=%s" % lhs, - "-DRUY_TEST_RHSSCALAR=%s" % rhs, - "-DRUY_TEST_ACCUMSCALAR=%s" % accum, - "-DRUY_TEST_DSTSCALAR=%s" % dst, - ], - deps = deps, - tags = tags, - ) - -def ruy_benchmark(name, srcs, lhs_rhs_accum_dst, copts, deps = None): - tags = ["req_dep=//third_party/gemmlowp:profiler"] - for (lhs, rhs, accum, dst) in lhs_rhs_accum_dst: - native.cc_binary( - name = "%s_%s_%s_%s_%s" % (name, lhs, rhs, accum, dst), - testonly = True, - srcs = srcs, - copts = copts + [ - "-DRUY_TEST_LHSSCALAR=%s" % lhs, - "-DRUY_TEST_RHSSCALAR=%s" % rhs, - "-DRUY_TEST_ACCUMSCALAR=%s" % accum, - "-DRUY_TEST_DSTSCALAR=%s" % dst, - ], - deps = deps, - tags = tags, - ) diff --git a/tensorflow/lite/experimental/ruy/ruy/ruy_test_ext.bzl b/tensorflow/lite/experimental/ruy/ruy/ruy_test_ext.bzl deleted file mode 100644 index 5701fffa0f7..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/ruy_test_ext.bzl +++ /dev/null @@ -1,7 +0,0 @@ -"""Allows to specialize the ruy BUILD to availability of external libraries""" - -def ruy_test_ext_defines(): - return [] - -def ruy_test_ext_deps(): - return [] diff --git a/tensorflow/lite/experimental/ruy/ruy/side_pair.h b/tensorflow/lite/experimental/ruy/ruy/side_pair.h deleted file mode 100644 index a3210e27a53..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/side_pair.h +++ /dev/null @@ -1,64 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_SIDE_PAIR_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_SIDE_PAIR_H_ - -#include "tensorflow/lite/experimental/ruy/ruy/check_macros.h" - -namespace ruy { - -// Enumeration of the sides, i.e. the operands 'slots', in a matrix -// multiplication. The numerical values of these enumeration constants matter -// because these will be used as indices into the array underlying a SidePair. -enum class Side { - // Left-hand side - kLhs = 0, - // Right-hand side - kRhs = 1 -}; - -// SidePair is a pair container where the two elements are indexed by a Side -// enum. -template -class SidePair final { - public: - SidePair() {} - SidePair(const T& a, const T& b) : elem_{a, b} {} - const T& operator[](Side side) const { - const int index = static_cast(side); - // Technically this check is vacuous, since other values would be - // out-of-range for enum Side. 
- RUY_DCHECK(index == 0 || index == 1); - return elem_[index]; - } - - T& operator[](Side side) { - const int index = static_cast(side); - // Technically this check is vacuous, since other values would be - // out-of-range for enum Side. - RUY_DCHECK(index == 0 || index == 1); - return elem_[index]; - } - - private: - static_assert(static_cast(Side::kLhs) == 0, ""); - static_assert(static_cast(Side::kRhs) == 1, ""); - T elem_[2]; -}; - -} // namespace ruy - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_SIDE_PAIR_H_ diff --git a/tensorflow/lite/experimental/ruy/ruy/size_util.h b/tensorflow/lite/experimental/ruy/ruy/size_util.h deleted file mode 100644 index 56dd095de85..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/size_util.h +++ /dev/null @@ -1,93 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_SIZE_UTIL_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_SIZE_UTIL_H_ - -#include - -#include "tensorflow/lite/experimental/ruy/ruy/check_macros.h" - -#ifdef _WIN32 -#include -#endif - -namespace ruy { - -template -inline Integer floor_log2(Integer n) { - static_assert(std::is_integral::value, ""); - static_assert(std::is_signed::value, ""); - static_assert(sizeof(Integer) == 4 || sizeof(Integer) == 8, ""); - - RUY_DCHECK_GE(n, 1); -#ifdef _WIN32 - unsigned long result; // NOLINT[runtime/int] - if (sizeof(Integer) == 4) { - _BitScanReverse(&result, n); - } else { - _BitScanReverse64(&result, n); - } - return result; -#else - if (sizeof(Integer) == 4) { - return 31 - __builtin_clz(n); - } else { - return 63 - __builtin_clzll(n); - } -#endif -} - -template -Integer ceil_log2(Integer n) { - RUY_DCHECK_GE(n, 1); - return n == 1 ? 0 : floor_log2(n - 1) + 1; -} - -template -bool is_pot(Integer value) { - return (value > 0) && ((value & (value - 1)) == 0); -} - -template -Integer pot_log2(Integer n) { - RUY_DCHECK(is_pot(n)); - return floor_log2(n); -} - -template -Integer round_down_pot(Integer value) { - return static_cast(1) << floor_log2(value); -} - -template -Integer round_up_pot(Integer value) { - return static_cast(1) << ceil_log2(value); -} - -template -Integer round_down_pot(Integer value, Modulo modulo) { - RUY_DCHECK_EQ(modulo & (modulo - 1), 0); - return value & ~(modulo - 1); -} - -template -Integer round_up_pot(Integer value, Modulo modulo) { - return round_down_pot(value + modulo - 1, modulo); -} - -} // namespace ruy - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_SIZE_UTIL_H_ diff --git a/tensorflow/lite/experimental/ruy/ruy/size_util_test.cc b/tensorflow/lite/experimental/ruy/ruy/size_util_test.cc deleted file mode 100644 index 442c31958cc..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/size_util_test.cc +++ /dev/null @@ -1,101 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. 
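To make the rounding helpers above concrete, a few worked values (illustrative only; the expectations follow directly from the definitions in size_util.h quoted above):

    #include <cassert>

    #include "tensorflow/lite/experimental/ruy/ruy/size_util.h"

    int main() {
      assert(ruy::floor_log2(50) == 5);       // 32 = 2^5 <= 50 < 2^6
      assert(ruy::ceil_log2(50) == 6);        // smallest p with 2^p >= 50
      assert(!ruy::is_pot(50));
      assert(ruy::round_down_pot(50) == 32);  // largest power of two <= 50
      assert(ruy::round_up_pot(50) == 64);    // smallest power of two >= 50
      assert(ruy::round_down_pot(50, 16) == 48);  // round down to a multiple of 16
      assert(ruy::round_up_pot(50, 16) == 64);    // round up to a multiple of 16
      return 0;
    }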
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/lite/experimental/ruy/ruy/size_util.h" - -#include -#include -#include - -#include - -namespace ruy { -namespace { - -template -void SizeUtilTestValue(Integer value) { - if (value == 0) { - return; - } - - EXPECT_LE(0, floor_log2(value)); - EXPECT_LE(floor_log2(value), ceil_log2(value)); - EXPECT_LE(ceil_log2(value), 8 * sizeof(Integer)); - - if (is_pot(value)) { - EXPECT_EQ(floor_log2(value), ceil_log2(value)); - EXPECT_EQ(floor_log2(value), pot_log2(value)); - } else { - EXPECT_EQ(floor_log2(value) + 1, ceil_log2(value)); - } - EXPECT_EQ(value >> floor_log2(value), 1); - EXPECT_EQ(round_down_pot(value), static_cast(1) - << floor_log2(value)); - EXPECT_LE(round_down_pot(value), value); - EXPECT_GE(round_down_pot(value), value >> 1); - EXPECT_TRUE(is_pot(round_down_pot(value))); - - if (ceil_log2(value) < 8 * sizeof(Integer) - 1) { - EXPECT_EQ(value >> ceil_log2(value), is_pot(value) ? 1 : 0); - EXPECT_EQ(round_up_pot(value), static_cast(1) << ceil_log2(value)); - EXPECT_GE(round_up_pot(value), value); - EXPECT_LE(round_up_pot(value) >> 1, value); - EXPECT_TRUE(is_pot(round_up_pot(value))); - } - - for (std::uint8_t modulo : {1, 2, 8, 32, 128}) { - EXPECT_GE(value, round_down_pot(value, modulo)); - EXPECT_EQ(round_down_pot(value, modulo) % modulo, 0); - - if (value <= std::numeric_limits::max() - modulo) { - EXPECT_LE(value, round_up_pot(value, modulo)); - EXPECT_EQ(round_up_pot(value, modulo) % modulo, 0); - } - } -} - -template -void SizeUtilTest() { - for (int exponent = 0; exponent < 8 * sizeof(Integer) - 1; exponent++) { - const Integer pot = static_cast(1) << exponent; - SizeUtilTestValue(pot - 1); - SizeUtilTestValue(pot); - SizeUtilTestValue(pot + 1); - SizeUtilTestValue(pot + 12); - SizeUtilTestValue(pot + 123); - } - SizeUtilTestValue(std::numeric_limits::max() - 1); - SizeUtilTestValue(std::numeric_limits::max()); -} - -TEST(SizeUtilTest, Int) { SizeUtilTest(); } - -TEST(SizeUtilTest, Long) { SizeUtilTest(); } // NOLINT - -TEST(SizeUtilTest, LongLong) { SizeUtilTest(); } // NOLINT - -TEST(SizeUtilTest, Int32) { SizeUtilTest(); } - -TEST(SizeUtilTest, Int64) { SizeUtilTest(); } - -TEST(SizeUtilTest, Ptrdiff) { SizeUtilTest(); } - -} // namespace -} // namespace ruy - -int main(int argc, char **argv) { - ::testing::InitGoogleTest(&argc, argv); - return RUN_ALL_TESTS(); -} diff --git a/tensorflow/lite/experimental/ruy/ruy/spec.h b/tensorflow/lite/experimental/ruy/ruy/spec.h deleted file mode 100644 index 584d90ea047..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/spec.h +++ /dev/null @@ -1,118 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_SPEC_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_SPEC_H_ - -#include -#include - -#include "tensorflow/lite/experimental/ruy/ruy/cpu_cache_size.h" -#include "tensorflow/lite/experimental/ruy/ruy/matrix.h" - -namespace ruy { - -// Our 'general' loop structure (the default) involves multi-threading and -// complicated loops aiming to optimize cache-friendliness. One may opt out of -// this and pick the 'simple' loop structure instead, which only performs well -// for small matrix sizes and only allows using one thread, in exchange for -// smaller code size. -enum class LoopStructure { kGeneral, kSimple, kAuto }; - -// In general we allow zero_point's to have any Scalar value. This is called -// 'asymmetric' quantization. We do take advantage of the optimization -// opportunities when zero_points happen at runtime to be 'symmetric' (e.g. the -// int8 value 0 or the uint8 value 128), but we still generate code to handle -// the general asymmetric case. By choosing kSymmetric here, one opts out of -// this and supports only the symmetric case, in exchange for smaller code size. -enum class ZeroPointSupport { kGeneral, kSymmetric }; - -// In general we allow all Layout's, even if we may use slow paths for some -// kinds of layouts. By choosing kRCC, one may opt out of this and -// only keep support for the simplest and most efficient combination of -// Layout's, in exchange for smaller code size. The case covered by -// kRCC is where the storage orders are exactly the following: -// - LHS is RowMajor -// - RHS is ColMajor -// - Destination is ColMajor -enum class LayoutSupport { kGeneral, kRCC }; - -// A Spec describes all about a matrix multiplication operation that isn't -// encoded in the LHS, RHS and destination matrices. Some of that information -// is encoded as compile-time constants and types (for instance, the choice -// of accumulator type, AccumScalar). Some of that information is encoded as -// runtime values (for instance, the optional bias vector). -template -struct BasicSpec { - // Accumulator type. The type of accumulators used to compute the dot-products - // before being ultimately casted to the destination type. - using AccumScalar = tAccumScalar; - // The destination scalar type. - using DstScalar = tDstScalar; - // The bias vector data, if not null. - const AccumScalar* bias = nullptr; - // Only for non-floating-point cases. The fixed-point part (i.e. the mantissa) - // of the multiplier by which accumulators are multiplied before being casted - // to the destination type. - AccumScalar multiplier_fixedpoint = 0; - // Only for non-floating-point cases. The exponent part of the aforementioned - // multiplier. - int multiplier_exponent = 0; - // Per-channel variant of multiplier_fixedpoint. If not nullptr, this must - // point to a buffer of as many values as there are rows in the destination - // matrix. Each row of the destination matrix will use the corresponding - // buffer element instead of multiplier_fixedpoint. 
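As an illustrative sketch of how the BasicSpec fields are typically filled for a quantized multiplication (not part of the original header; the multiplier values are placeholders, and the int32 fixed-point mantissa convention is assumed rather than documented here):

    #include <cstdint>

    #include "tensorflow/lite/experimental/ruy/ruy/spec.h"

    // Hypothetical setup: int32 accumulators, uint8 destination values.
    void FillQuantizedSpec(const std::int32_t* bias_data,
                           ruy::BasicSpec<std::int32_t, std::uint8_t>* spec) {
      spec->bias = bias_data;                 // one value per destination row
      spec->multiplier_fixedpoint = 1 << 30;  // placeholder fixed-point mantissa
      spec->multiplier_exponent = -3;         // placeholder power-of-two exponent
      spec->clamp_min = 0;                    // e.g. ReLU-style clamping
      spec->clamp_max = 255;
    }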
- const AccumScalar* multiplier_fixedpoint_perchannel = nullptr; - // Per-channel variant of multiplier_exponent. If not nullptr, this must - // point to a buffer of as many values as there are rows in the destination - // matrix. Each row of the destination matrix will use the corresponding - // buffer element instead of multiplier_exponent. - // - // Either none or both of multiplier_exponent_perchannel and - // multiplier_fixedpoint_perchannel must be nullptr. - const int* multiplier_exponent_perchannel = nullptr; - // min clamp bound of destination values. - DstScalar clamp_min = std::is_floating_point::value - ? -std::numeric_limits::infinity() - : std::numeric_limits::lowest(); - // max clamp bound of destination values. - DstScalar clamp_max = std::is_floating_point::value - ? std::numeric_limits::infinity() - : std::numeric_limits::max(); - // See above enum LoopStructure - static constexpr LoopStructure kLoopStructure = LoopStructure::kAuto; - // See above enum LayoutSupport - static constexpr LayoutSupport kLayoutSupport = LayoutSupport::kGeneral; - // See above enum ZeroPointSupport - static constexpr ZeroPointSupport kZeroPointSupport = - ZeroPointSupport::kGeneral; - // Testing-only, not meant to be used by actual users: - // Used for testing of various kernel layouts. - using StandardCppKernelLhsLayout = FixedKernelLayout; - using StandardCppKernelRhsLayout = FixedKernelLayout; - // Returns (a reasonable estimate of) the local CPU cache size. - // See ruy::LocalDataCacheSize() which returns some coarse, sane default for - // each CPU architecture. - // This may be overridden, either to provide more accurate/runtime values, - // or to test with other values to let testcases have more coverage. - static int local_data_cache_size() { return LocalDataCacheSize(); } - // Same as local_data_cache_size but for the total data cache size accessible - // to each CPU core. See ruy::SharedDataCacheSize(). - static int shared_data_cache_size() { return SharedDataCacheSize(); } -}; - -} // namespace ruy - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_SPEC_H_ diff --git a/tensorflow/lite/experimental/ruy/ruy/test.h b/tensorflow/lite/experimental/ruy/ruy/test.h deleted file mode 100644 index 305b5a844fa..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/test.h +++ /dev/null @@ -1,2125 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_TEST_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_TEST_H_ - -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include // IWYU pragma: export -#include "tensorflow/lite/experimental/ruy/ruy/matrix.h" // IWYU pragma: export -#include "tensorflow/lite/experimental/ruy/ruy/platform.h" -#include "tensorflow/lite/experimental/ruy/ruy/pmu.h" -#include "tensorflow/lite/experimental/ruy/ruy/ruy.h" -#include "tensorflow/lite/experimental/ruy/ruy/ruy_advanced.h" -#include "tensorflow/lite/experimental/ruy/ruy/spec.h" // IWYU pragma: export -#include "tensorflow/lite/experimental/ruy/ruy/time.h" - -#ifdef RUY_TEST_EXTERNAL_PATHS -#define EIGEN_USE_THREADS -#define EIGEN_USE_CUSTOM_THREAD_POOL -#include "third_party/eigen3/Eigen/Core" -#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" -#include "public/gemmlowp.h" -#include "third_party/lapack/blas.h" -#endif - -#ifdef RUY_PROFILER -#include "tensorflow/lite/experimental/ruy/ruy/profiler/profiler.h" -#endif - -namespace ruy { - -const float kClampRatio = 0.1f; - -enum class ExternalPath { kNone, kGemmlowp, kEigen, kEigenTensor, kOpenBlas }; - -inline std::vector* CoveredPaths() { - static std::vector covered_paths; - return &covered_paths; -} - -inline const char* PathName(Path path) { -#define RUY_PATHNAME_CASE(NAME) \ - case Path::NAME: \ - return #NAME; - switch (path) { - RUY_PATHNAME_CASE(kReference) - RUY_PATHNAME_CASE(kStandardCpp) -#if RUY_PLATFORM(NEON) - RUY_PATHNAME_CASE(kNeon) - RUY_PATHNAME_CASE(kNeonDotprod) -#elif RUY_PLATFORM(X86) - RUY_PATHNAME_CASE(kSse42) - RUY_PATHNAME_CASE(kAvx2) - RUY_PATHNAME_CASE(kAvx512) - RUY_PATHNAME_CASE(kAvxVnni) -#endif - default: - RUY_CHECK(false); - return nullptr; - } -#undef RUY_PATHNAME_CASE -} - -inline const char* TuningName(Tuning tuning) { -#define RUY_SUBPATHNAME_CASE(NAME) \ - case Tuning::NAME: \ - return #NAME; - switch (tuning) { - RUY_SUBPATHNAME_CASE(kInOrder) - RUY_SUBPATHNAME_CASE(kOutOfOrder) - default: - RUY_CHECK(false); - return nullptr; - } -#undef RUY_SUBPATHNAME_CASE -} - -inline const char* PathName(ExternalPath path) { -#define RUY_PATHNAME_CASE(NAME) \ - case ExternalPath::NAME: \ - return #NAME; - switch (path) { - RUY_PATHNAME_CASE(kGemmlowp) - RUY_PATHNAME_CASE(kEigen) - RUY_PATHNAME_CASE(kEigenTensor) - RUY_PATHNAME_CASE(kOpenBlas) - default: - RUY_CHECK(false); - return nullptr; - } -#undef RUY_PATHNAME_CASE -} - -inline std::ostream& operator<<(std::ostream& stream, Path path) { - return stream << PathName(path); -} - -inline std::ostream& operator<<(std::ostream& stream, - ExternalPath external_path) { - return stream << PathName(external_path); -} - -template -std::string Join(const ContainerType& container) { - if (container.empty()) { - return ""; - } - std::ostringstream stream; - auto it = container.begin(); - stream << *it++; - for (; it != container.end(); ++it) { - stream << ", "; - stream << *it; - } - return stream.str(); -} - -struct LogCoveredPathsOnDestruction final { - ~LogCoveredPathsOnDestruction() { - std::cerr << "Covered paths: " << Join(*CoveredPaths()) << std::endl; - - // When testing on ARM64 ChromiumOS emulator, make sure that we covered - // the dotprod path. 
We're getting such coverage at the moment thanks to - // using a sufficiently recent emulator, and we don't want to regress that. -#if RUY_PLATFORM(ARM_64) && defined RUY_TESTING_ON_CHROMIUMOS - bool found_dotprod = false; - for (const std::string& covered_path : *CoveredPaths()) { - if (covered_path == "kNeonDotprod") { - found_dotprod = true; - } - } - if (!found_dotprod) { - std::cerr - << "Error: we haven't tested the kNeonDotprod path as we should " - "have. At the moment, this is required on ChromiumOS as this is " - "what we run emulator tests in, that currently supports " - "dot-product " - "instructions, and we care very much about not regressing that. " - "If this test was run in an emulator, please upgrade to a newer " - "emulator version. If this test was run on an actual device, and " - "you need to be able to run ruy tests on devices not supporting " - "dot-product instructions, get in touch with us.\n" - << std::endl; - abort(); - } -#endif - } - static void Singleton() { static LogCoveredPathsOnDestruction singleton; } -}; - -enum class RandomRange { - kGeneral, - kAvoidMinValue, - kOffCenterAvoidMinValue, - kReasonableSrcZeroPoint, - kReasonableDstZeroPoint, - kBias -}; - -template ::value> -struct RandomRangeBounds {}; - -template -struct RandomRangeBounds { - static Scalar GetMinBound(RandomRange range) { - switch (range) { - case RandomRange::kGeneral: - return -1; - case RandomRange::kAvoidMinValue: - return -1; - case RandomRange::kOffCenterAvoidMinValue: - return -1; - case RandomRange::kReasonableSrcZeroPoint: - return 0; - case RandomRange::kReasonableDstZeroPoint: - return 0; - case RandomRange::kBias: - return -1; - default: - RUY_CHECK(false); - return 0; - } - } - static Scalar GetMaxBound(RandomRange range) { - switch (range) { - case RandomRange::kGeneral: - return 1; - case RandomRange::kAvoidMinValue: - return 1; - case RandomRange::kOffCenterAvoidMinValue: - return 1; - case RandomRange::kReasonableSrcZeroPoint: - return 0; - case RandomRange::kReasonableDstZeroPoint: - return 0; - case RandomRange::kBias: - return 1; - default: - RUY_CHECK(false); - return 0; - } - } -}; - -template -Scalar WeightedSum(Scalar s1, float weight1, Scalar s2, float weight2) { - float sum = s1 * weight1 + s2 * weight2; - float clamped = std::min( - std::numeric_limits::max(), - std::max(std::numeric_limits::lowest(), sum)); - return static_cast(clamped); -} - -template -Scalar Parametrized(float param) { - return WeightedSum(std::numeric_limits::max(), param, - std::numeric_limits::lowest(), 1 - param); -} - -template -struct RandomRangeBounds { - static Scalar GetMinBound(RandomRange range) { - static constexpr double offcenteredness = - 0.02; // Shift lower limit by about 5 for range of 255. - switch (range) { - case RandomRange::kGeneral: - return std::numeric_limits::lowest(); - case RandomRange::kAvoidMinValue: - return 1 + std::numeric_limits::lowest(); - case RandomRange::kOffCenterAvoidMinValue: - return 1 + std::numeric_limits::lowest() + - static_cast( - offcenteredness * std::numeric_limits::max() - - offcenteredness * - (std::numeric_limits::lowest() + 1)); - case RandomRange::kReasonableSrcZeroPoint: - return std::numeric_limits::lowest(); - case RandomRange::kReasonableDstZeroPoint: - return Parametrized(0.4); - case RandomRange::kBias: - return std::is_same::value - ? 
static_cast(-10000) - : 0; - default: - RUY_CHECK(false); - return 0; - } - } - static Scalar GetMaxBound(RandomRange range) { - switch (range) { - case RandomRange::kGeneral: - return std::numeric_limits::max(); - case RandomRange::kAvoidMinValue: - return std::numeric_limits::max(); - case RandomRange::kOffCenterAvoidMinValue: - return std::numeric_limits::max(); - case RandomRange::kReasonableSrcZeroPoint: - return std::numeric_limits::max(); - case RandomRange::kReasonableDstZeroPoint: - return Parametrized(0.6); - case RandomRange::kBias: - return std::is_same::value - ? static_cast(10000) - : 0; - default: - RUY_CHECK(false); - return 0; - } - } -}; - -inline std::default_random_engine& global_random_engine() { - static std::default_random_engine engine; - return engine; -} - -template -struct UniformRandomDistribution { - UniformRandomDistribution(RandomRange range) - : dist(RandomRangeBounds::GetMinBound(range), - RandomRangeBounds::GetMaxBound(range)) {} - Scalar Get() { return dist(global_random_engine()); } - // std::uniform_int_distribution is specified not to support char types, - // only short and wider types. MSVC actually generates an error on - // std::uniform_int_distribution. - using StdDistType = typename std::conditional< - std::is_floating_point::value, - std::uniform_real_distribution, - std::uniform_int_distribution>::type; - StdDistType dist; -}; - -template -void MakeRandomScalar(UniformRandomDistribution* uniform_dist, - Scalar* dst) { - *dst = uniform_dist->Get(); -} - -template -void MakeRandomVector(UniformRandomDistribution* uniform_dist, int size, - std::vector* dst) { - dst->resize(size); - for (auto& x : *dst) { - MakeRandomScalar(uniform_dist, &x); - } -} - -template -void MakeRandomScalar(RandomRange range, Scalar* dst) { - UniformRandomDistribution dist(range); - *dst = dist.Get(); - if (range == RandomRange::kReasonableDstZeroPoint || - range == RandomRange::kReasonableSrcZeroPoint) { - if (global_random_engine()() & 1) { - *dst = SymmetricZeroPoint(); - } - } -} - -template -void MakeRandomVector(RandomRange range, int size, std::vector* dst) { - UniformRandomDistribution dist(range); - dst->resize(size); - for (auto& x : *dst) { - MakeRandomScalar(&dist, &x); - } -} - -enum class LayoutStyle { kPackedLinear, kLinear }; - -inline void MakeLayout(int rows, int cols, Order order, - LayoutStyle layout_style, Layout* layout) { - layout->rows = rows; - layout->cols = cols; - layout->order = order; - - const int packed_stride = order == Order::kColMajor ? 
rows : cols; - - RUY_CHECK(layout_style == LayoutStyle::kPackedLinear || - layout_style == LayoutStyle::kLinear); - if (layout_style == LayoutStyle::kPackedLinear) { - layout->stride = packed_stride; - } else { - layout->stride = packed_stride + 1; - } -} - -template -struct StorageMatrix { - StorageMatrix() = default; - StorageMatrix(const StorageMatrix&) = delete; - void operator=(const StorageMatrix&) = delete; - std::vector data; - Matrix matrix; -}; - -template -void VerifyConsistentFields(const StorageMatrix& storage_matrix) { - if (storage_matrix.data.empty()) { - RUY_CHECK_EQ(storage_matrix.matrix.data.get(), nullptr); - RUY_CHECK_EQ(storage_matrix.matrix.layout.rows, 0); - RUY_CHECK_EQ(storage_matrix.matrix.layout.cols, 0); - } else { - RUY_CHECK_EQ(storage_matrix.matrix.data.get(), storage_matrix.data.data()); - RUY_CHECK_EQ(FlatSize(storage_matrix.matrix.layout), - storage_matrix.data.size()); - } -} - -template -void MakeRandom(int rows, int cols, Order order, Scalar zero_point, - LayoutStyle layout_style, RandomRange range, - StorageMatrix* storage_matrix) { - MakeLayout(rows, cols, order, layout_style, &storage_matrix->matrix.layout); - storage_matrix->matrix.zero_point = zero_point; - UniformRandomDistribution data_dist(range); - MakeRandomVector(&data_dist, FlatSize(storage_matrix->matrix.layout), - &storage_matrix->data); - storage_matrix->matrix.data = storage_matrix->data.data(); - VerifyConsistentFields(*storage_matrix); -} - -template -struct TestResult { - void operator=(const TestResult&) = delete; - void operator=(const TestResult&&) = delete; - StorageMatrix storage_matrix; - Path path = Path::kNone; - Tuning tuning = Tuning::kAuto; - ExternalPath external_path = ExternalPath::kNone; - float latency; - float l1_refill_rate; - float l2_refill_rate; - float l3_refill_rate; - float l1tlb_refill_rate; - float l2tlb_refill_rate; - float mispred_rate; - float frontend_stall_rate; - float backend_stall_rate; - - // Per-path data for pre-packing. - // This is not used by external paths or by Path::kReference. 
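// Aside on MakeLayout above: for a column-major matrix the tightly packed
// stride is the number of rows, for a row-major matrix it is the number of
// columns, and the non-packed "linear" layout style pads that stride by one
// element so strided code paths get exercised. A minimal standalone sketch
// (the Order enum and helpers below are illustrative stand-ins, not ruy's):
#include <cassert>

namespace example_layout {
enum class Order { kColMajor, kRowMajor };

// Tightly packed stride for the given shape and storage order.
inline int PackedStride(int rows, int cols, Order order) {
  return order == Order::kColMajor ? rows : cols;
}

// Stride actually used: packed, or packed + 1 for the padded "linear" style.
inline int StrideFor(int rows, int cols, Order order, bool packed) {
  const int packed_stride = PackedStride(rows, cols, order);
  return packed ? packed_stride : packed_stride + 1;
}

inline void StrideExample() {
  assert(StrideFor(4, 7, Order::kColMajor, /*packed=*/true) == 4);
  assert(StrideFor(4, 7, Order::kRowMajor, /*packed=*/false) == 8);
}
}  // namespace example_layout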
- Allocator allocator; - PrepackedMatrix prepacked_lhs; - PrepackedMatrix prepacked_rhs; - bool use_prepacked_lhs = false; - bool use_prepacked_rhs = false; -}; - -template -std::string PathName(const TestResult& result) { - std::string pathname; - if (result.path != Path::kNone) { - pathname.assign(PathName(result.path)); - } else if (result.external_path != ExternalPath::kNone) { - pathname.assign(PathName(result.external_path)); - } else { - RUY_CHECK(false); - } - if (result.tuning != Tuning::kAuto) { - pathname.append("/"); - pathname.append(TuningName(result.tuning)); - } - return pathname; -} - -enum class ExpectedOutcome { kSuccess, kDeath }; - -template -struct TestSet final { - using LhsScalar = tLhsScalar; - using RhsScalar = tRhsScalar; - using AccumScalar = typename SpecType::AccumScalar; - using DstScalar = typename SpecType::DstScalar; - using Spec = SpecType; - using TestResultType = TestResult; - - void Run() { - MakeZeroPoints(); - MakeLhsRhs(); - MakeSpec(); - MakeOtherParams(); - MakeResultPaths(); - MakePrepackedMatrices(); - Eval(); - Verify(); - } - - private: - void MakeZeroPoints(); - void MakeLhsRhs(); - void MakeSpec(); - void MakeResultPaths(); - void MakePrepackedMatrices(); - void MakeOtherParams(); - void EvalAndVerify(); - void Eval(); - void Verify(); - - void EvalResult(TestResultType* result); - void EvalRuy(TestResultType* result); - void DoMul(TestResultType* result); - void Benchmark(TestResultType* result); - void VerifyTestResults() const; - - public: - enum class LifeStage { - kInitial, - kHasZeroPoints, - kHasLhsRhs, - kHasSpec, - kHasOtherParams, - kHasResultPaths, - kHasPrepackedMatrices, - kEvaluated, - kFinal - }; - - ~TestSet() { - RUY_CHECK_EQ(life_stage, LifeStage::kFinal); - LogCoveredPathsOnDestruction::Singleton(); - } - - LifeStage life_stage = LifeStage::kInitial; - - int rows = 0; - int cols = 0; - int depth = 0; - Order lhs_order = Order::kRowMajor; - Order rhs_order = Order::kColMajor; - Order dst_order = Order::kColMajor; - LayoutStyle layout_style = LayoutStyle::kPackedLinear; - ExpectedOutcome expected_outcome = ExpectedOutcome::kSuccess; - - bool use_specified_zero_points = false; - LhsScalar lhs_zero_point = 0; - RhsScalar rhs_zero_point = 0; - DstScalar dst_zero_point = 0; - - std::vector per_channel_multiplier_fixedpoint; - std::vector per_channel_multiplier_exponent; - - StorageMatrix lhs; - StorageMatrix rhs; - Spec spec; - std::vector bias_data; - std::vector> results; - - std::vector paths; - std::vector external_paths; - - bool benchmark = false; - bool perchannel = false; - int max_num_threads = 0; - bool benchmark_prepack_lhs = false; - bool benchmark_prepack_rhs = false; -}; - -inline PmuEvents& GlobalPmuEvents() { - static PmuEvents pmu; - return pmu; -} - -inline Context& GlobalContext() { - // Ensure that GlobalPmuEvents is constructed before we create any context. - // This ensures that pmu counters are opened before we create any worker - // thread, which is necessary to count events from worker threads. 
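// Aside on the construction-ordering trick used by GlobalContext() here:
// calling one accessor at the top of another guarantees that the first
// function-local static is constructed before, and destroyed after, the
// second. A standalone sketch with two hypothetical types (Logger and Engine
// are placeholders, not ruy types):
#include <iostream>

namespace example_init_order {
struct Logger {
  Logger() { std::cout << "Logger up\n"; }
  ~Logger() { std::cout << "Logger down\n"; }
};
struct Engine {
  Engine() { std::cout << "Engine up\n"; }
  ~Engine() { std::cout << "Engine down\n"; }
};

inline Logger& GlobalLogger() {
  static Logger logger;
  return logger;
}

inline Engine& GlobalEngine() {
  // Touch the logger first so it is built before the engine and outlives it.
  GlobalLogger();
  static Engine engine;
  return engine;
}
}  // namespace example_init_order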
- GlobalPmuEvents(); - - static Context context; - return context; -} - -#if defined(__has_feature) -#if __has_feature(thread_sanitizer) -#define RUY_TSAN -#endif -#if __has_feature(address_sanitizer) -#define RUY_ASAN -#endif -#endif // defined(__has_feature) - -template -void TestSet::DoMul(TestResultType* result) { - Context* context = &GlobalContext(); - - if (!result->use_prepacked_lhs && !result->use_prepacked_rhs) { - Mul(lhs.matrix, rhs.matrix, spec, context, - &result->storage_matrix.matrix); - return; - } - - // If we prepacked an input matrix, null out its data pointer to check - // that we don't access any data through it. - Matrix null_data_lhs = lhs.matrix; - Matrix null_data_rhs = rhs.matrix; - if (result->use_prepacked_lhs) { - null_data_lhs.data = nullptr; - } - if (result->use_prepacked_rhs) { - null_data_rhs.data = nullptr; - } - - // Do the multiplication with pre-packed matrices. - PrepackedMatrix* prepacked_lhs_ptr = - result->use_prepacked_lhs ? &result->prepacked_lhs : nullptr; - PrepackedMatrix* prepacked_rhs_ptr = - result->use_prepacked_rhs ? &result->prepacked_rhs : nullptr; - MulWithPrepacked(null_data_lhs, null_data_rhs, spec, context, - &result->storage_matrix.matrix, prepacked_lhs_ptr, - prepacked_rhs_ptr); -} - -// When building for WAsm, ASSERT_DEATH is not defined. -#ifdef ASSERT_DEATH -#define RUY_ASSERT_DEATH(CONDITION, MESSAGE) ASSERT_DEATH(CONDITION, MESSAGE) -#else -#define RUY_ASSERT_DEATH(CONDITION, MESSAGE) -#endif - -template -void TestSet::EvalRuy(TestResultType* result) { - GlobalContext().explicit_tuning = result->tuning; - if (max_num_threads) { - GlobalContext().max_num_threads = max_num_threads; - } else if (benchmark) { - GlobalContext().max_num_threads = 1; - } else { - GlobalContext().max_num_threads = 1 + global_random_engine()() % 8; - } - GlobalContext().SetRuntimeEnabledPaths(result->path); - if (expected_outcome == ExpectedOutcome::kSuccess) { - DoMul(result); - RUY_CHECK_EQ(GlobalContext().last_taken_path, result->path); - } else if (expected_outcome == ExpectedOutcome::kDeath) { - // TODO(benoitjacob) TSan and ASan seem to be breaking ASSERT_DEATH. - // Report a bug? -#if (!defined NDEBUG) && (!defined RUY_ASAN) && (!defined RUY_TSAN) - RUY_ASSERT_DEATH(DoMul(result), ""); -#endif - } else { - RUY_CHECK(false); - } - GlobalContext().explicit_tuning = Tuning::kAuto; - GlobalContext().max_num_threads = 1; -} - -#ifdef RUY_TEST_EXTERNAL_PATHS - -template -void WrapGemmlowp(const Matrix& src, - gemmlowp::MatrixMap* dst) { - RUY_CHECK(src.layout.order == (tOrder == gemmlowp::MapOrder::ColMajor - ? Order::kColMajor - : Order::kRowMajor)); - *dst = gemmlowp::MatrixMap( - src.data.get(), src.layout.rows, src.layout.cols, src.layout.stride); -} - -template -void WrapGemmlowpMutable(Matrix* src, - gemmlowp::MatrixMap* dst) { - RUY_CHECK(src->layout.order == (tOrder == gemmlowp::MapOrder::ColMajor - ? 
Order::kColMajor - : Order::kRowMajor)); - *dst = gemmlowp::MatrixMap( - src->data.get(), src->layout.rows, src->layout.cols, src->layout.stride); -} - -template -struct GemmlowpOrder {}; - -template <> -struct GemmlowpOrder { - static constexpr gemmlowp::MapOrder kValue = gemmlowp::MapOrder::ColMajor; -}; - -template <> -struct GemmlowpOrder { - static constexpr gemmlowp::MapOrder kValue = gemmlowp::MapOrder::RowMajor; -}; - -inline gemmlowp::GemmContext& GlobalGemmlowpContext() { - static gemmlowp::GemmContext context; - return context; -} - -template -void EvalGemmlowp(const Matrix& lhs, const Matrix& rhs, - const Spec& spec, int max_num_threads, - Matrix* dst) { - static constexpr gemmlowp::MapOrder kGemmlowpLhsOrder = - GemmlowpOrder::kValue; - static constexpr gemmlowp::MapOrder kGemmlowpRhsOrder = - GemmlowpOrder::kValue; - static constexpr gemmlowp::MapOrder kGemmlowpDstOrder = - GemmlowpOrder::kValue; - gemmlowp::MatrixMap gemmlowp_lhs; - gemmlowp::MatrixMap gemmlowp_rhs; - gemmlowp::MatrixMap gemmlowp_dst; - WrapGemmlowp(lhs, &gemmlowp_lhs); - WrapGemmlowp(rhs, &gemmlowp_rhs); - WrapGemmlowpMutable(dst, &gemmlowp_dst); - - gemmlowp::OutputStageScaleInt32ByFixedPointAndExponent quantize_down_stage; - quantize_down_stage.result_offset_after_shift = dst->zero_point; - quantize_down_stage.result_fixedpoint_multiplier = spec.multiplier_fixedpoint; - quantize_down_stage.result_exponent = spec.multiplier_exponent; - gemmlowp::OutputStageScaleInt32ByFixedPointAndExponentPC< - gemmlowp::VectorShape::Col> - quantize_down_stage_pc; - quantize_down_stage_pc.result_offset_after_shift = dst->zero_point; - using ColVectorMap = - gemmlowp::VectorMap; - quantize_down_stage_pc.result_fixedpoint_multiplier = - ColVectorMap(spec.multiplier_fixedpoint_perchannel, lhs.layout.rows); - quantize_down_stage_pc.result_exponent = - ColVectorMap(spec.multiplier_exponent_perchannel, lhs.layout.rows); - - gemmlowp::OutputStageClamp clamp_stage; - clamp_stage.min = spec.clamp_min; - clamp_stage.max = spec.clamp_max; - using OutputStageSaturatingCast = typename std::conditional< - std::is_same::value, - gemmlowp::OutputStageSaturatingCastToUint8, - gemmlowp::OutputStageSaturatingCastToInt16>::type; - OutputStageSaturatingCast saturating_cast_stage; - - GlobalGemmlowpContext().set_max_num_threads(max_num_threads ? 
max_num_threads - : 1); - if (spec.bias) { - using ColVectorMap = - gemmlowp::VectorMap; - gemmlowp::OutputStageBiasAddition bias_add_stage; - bias_add_stage.bias_vector = ColVectorMap(spec.bias, dst->layout.rows); -#ifndef GEMMLOWP_SSE4 // gemmlowp perchannel stuff does not build on SSE - if (spec.multiplier_exponent_perchannel) { - const auto& output_pipeline = - std::make_tuple(bias_add_stage, quantize_down_stage_pc, clamp_stage, - saturating_cast_stage); - gemmlowp::GemmWithOutputPipeline< - LhsScalar, DstScalar, gemmlowp::L8R8WithLhsNonzeroBitDepthParams>( - &GlobalGemmlowpContext(), gemmlowp_lhs, gemmlowp_rhs, &gemmlowp_dst, - -lhs.zero_point, -rhs.zero_point, output_pipeline); - } else // NOLINT[readability/braces] -#endif - { - const auto& output_pipeline = - std::make_tuple(bias_add_stage, quantize_down_stage, clamp_stage, - saturating_cast_stage); - gemmlowp::GemmWithOutputPipeline< - LhsScalar, DstScalar, gemmlowp::L8R8WithLhsNonzeroBitDepthParams>( - &GlobalGemmlowpContext(), gemmlowp_lhs, gemmlowp_rhs, &gemmlowp_dst, - -lhs.zero_point, -rhs.zero_point, output_pipeline); - } - } else { -#ifndef GEMMLOWP_SSE4 // gemmlowp perchannel stuff does not build on SSE - if (spec.multiplier_exponent_perchannel) { - const auto& output_pipeline = std::make_tuple( - quantize_down_stage_pc, clamp_stage, saturating_cast_stage); - gemmlowp::GemmWithOutputPipeline< - LhsScalar, DstScalar, gemmlowp::L8R8WithLhsNonzeroBitDepthParams>( - &GlobalGemmlowpContext(), gemmlowp_lhs, gemmlowp_rhs, &gemmlowp_dst, - -lhs.zero_point, -rhs.zero_point, output_pipeline); - } else // NOLINT[readability/braces] -#endif - { - const auto& output_pipeline = std::make_tuple( - quantize_down_stage, clamp_stage, saturating_cast_stage); - gemmlowp::GemmWithOutputPipeline< - LhsScalar, DstScalar, gemmlowp::L8R8WithLhsNonzeroBitDepthParams>( - &GlobalGemmlowpContext(), gemmlowp_lhs, gemmlowp_rhs, &gemmlowp_dst, - -lhs.zero_point, -rhs.zero_point, output_pipeline); - } - } -} - -inline constexpr int Mash(Order LhsOrder, Order RhsOrder, Order DstOrder) { - return (LhsOrder == Order::kRowMajor ? 4 : 0) + - (RhsOrder == Order::kRowMajor ? 2 : 0) + - (DstOrder == Order::kRowMajor ? 
1 : 0); -} - -template -void EvalGemmlowp(const Matrix& lhs, const Matrix& rhs, - const Spec& spec, int max_num_threads, - Matrix* dst) { - int index = Mash(lhs.layout.order, rhs.layout.order, dst->layout.order); - switch (index) { -#define EVALGEMMLOWP_CASE3(LHS, RHS, DST) \ - case Mash(LHS, RHS, DST): \ - return EvalGemmlowp(lhs, rhs, spec, max_num_threads, dst); -#define EVALGEMMLOWP_CASE2(LHS, RHS) \ - EVALGEMMLOWP_CASE3(LHS, RHS, Order::kColMajor) \ - EVALGEMMLOWP_CASE3(LHS, RHS, Order::kRowMajor) -#define EVALGEMMLOWP_CASE1(LHS) \ - EVALGEMMLOWP_CASE2(LHS, Order::kColMajor) \ - EVALGEMMLOWP_CASE2(LHS, Order::kRowMajor) - - EVALGEMMLOWP_CASE1(Order::kColMajor) - EVALGEMMLOWP_CASE1(Order::kRowMajor) - -#undef EVALGEMMLOWP_CASE1 -#undef EVALGEMMLOWP_CASE2 -#undef EVALGEMMLOWP_CASE3 - - default: - RUY_CHECK(false); - } -} - -template -struct EigenOrder {}; - -template <> -struct EigenOrder { - static constexpr int kValue = Eigen::ColMajor; -}; - -template <> -struct EigenOrder { - static constexpr int kValue = Eigen::RowMajor; -}; - -template -void EvalEigen(const Matrix& lhs, const Matrix& rhs, - const Spec& spec, int max_num_threads, Matrix* dst) { - RUY_CHECK_EQ(lhs.zero_point, 0); - RUY_CHECK_EQ(rhs.zero_point, 0); - RUY_CHECK_EQ(dst->zero_point, 0); - RUY_CHECK_EQ(spec.multiplier_fixedpoint, 0); - RUY_CHECK_EQ(spec.multiplier_exponent, 0); - - static constexpr int kEigenLhsOrder = EigenOrder::kValue; - static constexpr int kEigenRhsOrder = EigenOrder::kValue; - static constexpr int kEigenDstOrder = EigenOrder::kValue; - - using EigenLhsType = typename Eigen::Matrix:: - template StridedConstMapType>::type; - using EigenRhsType = typename Eigen::Matrix:: - template StridedConstMapType>::type; - using EigenDstType = typename Eigen::Matrix:: - template StridedMapType>::type; - using EigenBiasType = - typename Eigen::Matrix::ConstMapType; - - EigenLhsType eigen_lhs(lhs.data.get(), lhs.layout.rows, lhs.layout.cols, - Eigen::OuterStride(lhs.layout.stride)); - EigenRhsType eigen_rhs(rhs.data.get(), rhs.layout.rows, rhs.layout.cols, - Eigen::OuterStride(rhs.layout.stride)); - EigenDstType eigen_dst( - dst->data.get(), dst->layout.rows, dst->layout.cols, - Eigen::OuterStride(dst->layout.stride)); - Eigen::setNbThreads(max_num_threads ? 
max_num_threads : 1); - - if (spec.bias) { - EigenBiasType eigen_bias(spec.bias, dst->layout.rows); - if (spec.clamp_max == std::numeric_limits::infinity() && - spec.clamp_min == -std::numeric_limits::infinity()) { - eigen_dst.noalias() = (eigen_lhs * eigen_rhs).colwise() + eigen_bias; - } else { - eigen_dst.noalias() = ((eigen_lhs * eigen_rhs).colwise() + eigen_bias) - .cwiseMin(spec.clamp_max) - .cwiseMax(spec.clamp_min); - } - } else { - if (spec.clamp_max == std::numeric_limits::infinity() && - spec.clamp_min == -std::numeric_limits::infinity()) { - eigen_dst.noalias() = eigen_lhs * eigen_rhs; - } else { - eigen_dst.noalias() = (eigen_lhs * eigen_rhs) - .cwiseMin(spec.clamp_max) - .cwiseMax(spec.clamp_min); - } - } -} - -template -void EvalEigen(const Matrix& lhs, const Matrix& rhs, - const Spec& spec, int max_num_threads, Matrix* dst) { - int index = Mash(lhs.layout.order, rhs.layout.order, dst->layout.order); - switch (index) { -#define EVALEIGEN_CASE3(LHS, RHS, DST) \ - case Mash(LHS, RHS, DST): \ - return EvalEigen(lhs, rhs, spec, max_num_threads, dst); -#define EVALEIGEN_CASE2(LHS, RHS) \ - EVALEIGEN_CASE3(LHS, RHS, Order::kColMajor) \ - EVALEIGEN_CASE3(LHS, RHS, Order::kRowMajor) -#define EVALEIGEN_CASE1(LHS) \ - EVALEIGEN_CASE2(LHS, Order::kColMajor) \ - EVALEIGEN_CASE2(LHS, Order::kRowMajor) - - EVALEIGEN_CASE1(Order::kColMajor) - EVALEIGEN_CASE1(Order::kRowMajor) - -#undef EVALEIGEN_CASE1 -#undef EVALEIGEN_CASE2 -#undef EVALEIGEN_CASE3 - - default: - RUY_CHECK(false); - } -} - -template -void EvalEigenTensor(const Matrix& lhs, const Matrix& rhs, - const Spec& spec, int max_num_threads, - Matrix* dst) { - RUY_CHECK_EQ(lhs.zero_point, 0); - RUY_CHECK_EQ(rhs.zero_point, 0); - RUY_CHECK_EQ(dst->zero_point, 0); - RUY_CHECK_EQ(spec.multiplier_fixedpoint, 0); - RUY_CHECK_EQ(spec.multiplier_exponent, 0); - - // Eigen::TensorMap only supports packed layouts - RUY_CHECK(IsPacked(lhs.layout)); - RUY_CHECK(IsPacked(rhs.layout)); - RUY_CHECK(IsPacked(dst->layout)); - - using TensorLhsType = - Eigen::TensorMap>; - using TensorRhsType = - Eigen::TensorMap>; - using TensorDstType = - Eigen::TensorMap>; - using TensorBiasType = - Eigen::TensorMap>; - - const bool tr = DstOrder == Order::kRowMajor; - const auto& contract_lhs = tr ? rhs : lhs; - const auto& contract_rhs = tr ? lhs : rhs; - - TensorLhsType tensor_lhs( - contract_lhs.data.get(), - LhsOrder == Order::kColMajor ? contract_lhs.layout.rows - : contract_lhs.layout.cols, - LhsOrder == Order::kColMajor ? contract_lhs.layout.cols - : contract_lhs.layout.rows); - TensorRhsType tensor_rhs( - contract_rhs.data.get(), - RhsOrder == Order::kColMajor ? contract_rhs.layout.rows - : contract_rhs.layout.cols, - RhsOrder == Order::kColMajor ? contract_rhs.layout.cols - : contract_rhs.layout.rows); - TensorDstType tensor_dst( - dst->data.get(), - DstOrder == Order::kColMajor ? dst->layout.rows : dst->layout.cols, - DstOrder == Order::kColMajor ? dst->layout.cols : dst->layout.rows); - using DimPair = - typename Eigen::Tensor::DimensionPair; - Eigen::array contract_dims( - {DimPair((LhsOrder == Order::kColMajor) ? 1 : 0, - (RhsOrder == Order::kColMajor) ? 0 : 1)}); - Eigen::array shuffle(DstOrder == Order::kColMajor ? 0 : 1, - DstOrder == Order::kColMajor ? 1 : 0); - static Eigen::ThreadPool pool(max_num_threads ? max_num_threads : 1); - static Eigen::ThreadPoolDevice device(&pool, pool.NumThreads()); - if (spec.bias) { - TensorBiasType tensor_bias(spec.bias, dst->layout.rows); - Eigen::array bias_2d_shape(tr ? 1 : dst->layout.rows, - tr ? 
dst->layout.rows : 1); - Eigen::array bcast(tr ? dst->layout.cols : 1, - tr ? 1 : dst->layout.cols); - if (spec.clamp_max == std::numeric_limits::infinity() && - spec.clamp_min == -std::numeric_limits::infinity()) { - tensor_dst.device(device) = - tensor_lhs.contract(tensor_rhs, contract_dims); - } else { - tensor_dst.device(device) = - (tensor_lhs.contract(tensor_rhs, contract_dims) + - tensor_bias.reshape(bias_2d_shape).broadcast(bcast)) - .cwiseMin(spec.clamp_max) - .cwiseMax(spec.clamp_min); - } - } else { - if (spec.clamp_max == std::numeric_limits::infinity() && - spec.clamp_min == -std::numeric_limits::infinity()) { - tensor_dst.device(device) = - tensor_lhs.contract(tensor_rhs, contract_dims); - } else { - tensor_dst.device(device) = tensor_lhs.contract(tensor_rhs, contract_dims) - .cwiseMin(spec.clamp_max) - .cwiseMax(spec.clamp_min); - } - } -} - -template -void EvalEigenTensor(const Matrix& lhs, const Matrix& rhs, - const Spec& spec, int max_num_threads, - Matrix* dst) { - int index = Mash(lhs.layout.order, rhs.layout.order, dst->layout.order); - switch (index) { -#define EVALEIGENTENSOR_CASE3(LHS, RHS, DST) \ - case Mash(LHS, RHS, DST): \ - return EvalEigenTensor(lhs, rhs, spec, max_num_threads, dst); -#define EVALEIGENTENSOR_CASE2(LHS, RHS) \ - EVALEIGENTENSOR_CASE3(LHS, RHS, Order::kColMajor) \ - EVALEIGENTENSOR_CASE3(LHS, RHS, Order::kRowMajor) -#define EVALEIGENTENSOR_CASE1(LHS) \ - EVALEIGENTENSOR_CASE2(LHS, Order::kColMajor) \ - EVALEIGENTENSOR_CASE2(LHS, Order::kRowMajor) - - EVALEIGENTENSOR_CASE1(Order::kColMajor) - EVALEIGENTENSOR_CASE1(Order::kRowMajor) - -#undef EVALEIGENTENSOR_CASE1 -#undef EVALEIGENTENSOR_CASE2 -#undef EVALEIGENTENSOR_CASE3 - - default: - RUY_CHECK(false); - } -} - -template -struct GenericBlasGemm {}; - -template <> -struct GenericBlasGemm { - static void Run(char* transa, char* transb, lapack::integer* m, - lapack::integer* n, lapack::integer* k, - lapack::doublereal* alpha, lapack::doublereal* a, - lapack::integer* lda, lapack::doublereal* b, - lapack::integer* ldb, lapack::doublereal* beta, - lapack::doublereal* c, lapack::integer* ldc) { - dgemm_(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); - } -}; - -template <> -struct GenericBlasGemm { - static void Run(char* transa, char* transb, lapack::integer* m, - lapack::integer* n, lapack::integer* k, lapack::real* alpha, - lapack::real* a, lapack::integer* lda, lapack::real* b, - lapack::integer* ldb, lapack::real* beta, lapack::real* c, - lapack::integer* ldc) { - sgemm_(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); - } -}; - -template -void EvalOpenBlas(const Matrix& lhs, const Matrix& rhs, - const Spec& spec, int max_num_threads, Matrix* dst) { - RUY_CHECK_EQ(lhs.zero_point, 0); - RUY_CHECK_EQ(rhs.zero_point, 0); - RUY_CHECK_EQ(dst->zero_point, 0); - RUY_CHECK_EQ(spec.multiplier_fixedpoint, 0); - RUY_CHECK_EQ(spec.multiplier_exponent, 0); - - Matrix gemm_lhs; - Matrix gemm_rhs; - Matrix gemm_dst; - gemm_dst = *dst; - - // Use Transpose to reduce to the all-column-major case. - // Notice that ruy::Matrix merely holds a pointer, does not own data, - // so Transpose is cheap -- no actual matrix data is being transposed here. 
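// Aside on why Transpose() is cheap in EvalOpenBlas: the matrix type is only a
// view (pointer plus layout), so transposing swaps the row/column counts and
// flips the storage order without touching any element. Together with the
// identity C = A*B  <=>  C^T = B^T * A^T, that reduces every case to a
// column-major GEMM. MatrixView below is a hypothetical stand-in for
// illustration, not ruy's Matrix:
#include <utility>

namespace example_transpose {
enum class Order { kColMajor, kRowMajor };

template <typename T>
struct MatrixView {
  T* data = nullptr;  // Not owned; no element data lives in the view.
  int rows = 0;
  int cols = 0;
  int stride = 0;
  Order order = Order::kColMajor;
};

// Metadata-only transpose: O(1), no element is moved.
template <typename T>
void Transpose(MatrixView<T>* m) {
  std::swap(m->rows, m->cols);
  m->order =
      m->order == Order::kColMajor ? Order::kRowMajor : Order::kColMajor;
}
}  // namespace example_transpose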
- if (dst->layout.order == Order::kColMajor) { - gemm_lhs = lhs; - gemm_rhs = rhs; - } else { - gemm_lhs = rhs; - gemm_rhs = lhs; - Transpose(&gemm_lhs); - Transpose(&gemm_rhs); - Transpose(&gemm_dst); - } - bool transposed_lhs = false; - bool transposed_rhs = false; - - if (gemm_lhs.layout.order == Order::kRowMajor) { - Transpose(&gemm_lhs); - transposed_lhs = true; - } - if (gemm_rhs.layout.order == Order::kRowMajor) { - Transpose(&gemm_rhs); - transposed_rhs = true; - } - - RUY_CHECK_EQ(gemm_lhs.layout.order, Order::kColMajor); - RUY_CHECK_EQ(gemm_rhs.layout.order, Order::kColMajor); - RUY_CHECK_EQ(gemm_dst.layout.order, Order::kColMajor); - - char transa = transposed_lhs ? 'T' : 'N'; - char transb = transposed_rhs ? 'T' : 'N'; - int m = gemm_lhs.layout.rows; - int n = gemm_rhs.layout.cols; - int k = gemm_lhs.layout.cols; - float alpha = 1; - Scalar* a = gemm_lhs.data.get(); - int lda = gemm_lhs.layout.stride; - Scalar* b = gemm_rhs.data.get(); - int ldb = gemm_rhs.layout.stride; - float beta = 0; - Scalar* c = gemm_dst.data.get(); - int ldc = gemm_dst.layout.stride; - GenericBlasGemm::Run(&transa, &transb, &m, &n, &k, &alpha, a, &lda, b, - &ldb, &beta, c, &ldc); - - // BLAS does not allow us to express the bias-addition and clamping, so - // we use Eigen for that. - - using EigenDstType = - typename Eigen::Matrix:: - template StridedMapType>::type; - using EigenBiasType = - typename Eigen::Matrix::ConstMapType; - - EigenDstType eigen_dst( - gemm_dst.data.get(), gemm_dst.layout.rows, gemm_dst.layout.cols, - Eigen::OuterStride(gemm_dst.layout.stride)); - Eigen::setNbThreads(max_num_threads ? max_num_threads : 1); - - if (spec.bias) { - EigenBiasType eigen_bias(spec.bias, dst->layout.rows); - if (spec.clamp_max == std::numeric_limits::infinity() && - spec.clamp_min == -std::numeric_limits::infinity()) { - eigen_dst.noalias() = eigen_dst.colwise() + eigen_bias; - } else { - eigen_dst.noalias() = (eigen_dst.colwise() + eigen_bias) - .cwiseMin(spec.clamp_max) - .cwiseMax(spec.clamp_min); - } - } else { - if (spec.clamp_max == std::numeric_limits::infinity() && - spec.clamp_min == -std::numeric_limits::infinity()) { - } else { - eigen_dst.noalias() = - eigen_dst.cwiseMin(spec.clamp_max).cwiseMax(spec.clamp_min); - } - } -} - -template -struct SupportsGemmlowp { - static constexpr bool kValue = - std::is_same::value && - std::is_same::value; -}; - -template -struct UsesSingleScalarType { - static constexpr bool kValue = - std::is_same::value && - std::is_same::value && - std::is_same::value; -}; - -template ::value, - bool EnableGemmlowp = SupportsGemmlowp::kValue, - bool SingleScalarType = UsesSingleScalarType::kValue> -struct EvalExternalPathImpl { - using DstScalar = typename TestSetType::DstScalar; - static void Run(TestSetType*, TestResult*) { RUY_CHECK(false); } -}; - -template -struct EvalExternalPathImpl { - using DstScalar = typename TestSetType::DstScalar; - static void Run(TestSetType* test_set, TestResult* test_result) { - if (test_result->external_path == ExternalPath::kEigen) { - EvalEigen(test_set->lhs.matrix, test_set->rhs.matrix, test_set->spec, - test_set->max_num_threads, &test_result->storage_matrix.matrix); - } else if (test_result->external_path == ExternalPath::kEigenTensor) { - EvalEigenTensor(test_set->lhs.matrix, test_set->rhs.matrix, - test_set->spec, test_set->max_num_threads, - &test_result->storage_matrix.matrix); - } else if (test_result->external_path == ExternalPath::kOpenBlas) { - EvalOpenBlas(test_set->lhs.matrix, test_set->rhs.matrix, test_set->spec, - 
test_set->max_num_threads, - &test_result->storage_matrix.matrix); - } else { - RUY_CHECK(false); - } - } -}; - -template -struct EvalExternalPathImpl { - using DstScalar = typename TestSetType::DstScalar; - static void Run(TestSetType* test_set, TestResult* test_result) { - if (test_result->external_path == ExternalPath::kGemmlowp) { - EvalGemmlowp(test_set->lhs.matrix, test_set->rhs.matrix, test_set->spec, - test_set->max_num_threads, - &test_result->storage_matrix.matrix); - } else { - RUY_CHECK(false); - } - } -}; - -template -void EvalExternalPath( - TestSetType* test_set, - TestResult* test_result) { - EvalExternalPathImpl::Run(test_set, test_result); -} - -#endif // RUY_TEST_EXTERNAL_PATHS - -template -bool Agree(const Matrix& matrix1, const Matrix& matrix2, - int depth) { - RUY_CHECK_EQ(matrix1.layout.rows, matrix2.layout.rows); - RUY_CHECK_EQ(matrix1.layout.cols, matrix2.layout.cols); - RUY_CHECK_EQ(matrix1.zero_point, matrix2.zero_point); - const int size = matrix1.layout.rows * matrix1.layout.cols; - double tolerated_max_diff = 0; - double tolerated_mean_diff = 0; - if (std::is_floating_point::value) { - // TODO: replace hardcoded 100 by something more sensible, probably - // roughly sqrt(depth) based on central limit theorem. - double max_abs_val = 0; - for (int row = 0; row < matrix1.layout.rows; row++) { - for (int col = 0; col < matrix1.layout.cols; col++) { - max_abs_val = - std::max(max_abs_val, - std::abs(static_cast(Element(matrix1, row, col)))); - max_abs_val = - std::max(max_abs_val, - std::abs(static_cast(Element(matrix2, row, col)))); - } - } - tolerated_max_diff = max_abs_val * std::numeric_limits::epsilon() * - 64 * std::sqrt(static_cast(depth)); - tolerated_mean_diff = tolerated_max_diff / std::sqrt(size); - } else if (RUY_OPT_ENABLED(RUY_OPT_NATIVE_ROUNDING)) { - tolerated_max_diff = 1; - // totally empirical - tolerated_mean_diff = std::min(1.0, 2.0 * std::pow(size, -0.2)); - } - double sum_diff = 0; - for (int row = 0; row < matrix1.layout.rows; row++) { - for (int col = 0; col < matrix1.layout.cols; col++) { - double elem1 = Element(matrix1, row, col); - double elem2 = Element(matrix2, row, col); - double diff = elem1 - elem2; - - sum_diff += diff; - // Test (std::abs(diff) > tolerated_max_diff), but also true if diff is - // NaN. 
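// Aside on the check just below: !(std::abs(diff) <= tol) is deliberately not
// written as std::abs(diff) > tol, because every comparison against NaN is
// false, so the negated form also rejects NaN differences. Standalone sketch
// of that predicate plus the same empirical 64 * sqrt(depth) float tolerance
// used above (names here are illustrative, not ruy APIs):
#include <cmath>
#include <limits>

namespace example_agree {
// Tolerance grows like sqrt(depth), matching a random-walk model of the
// rounding error accumulated while summing `depth` products.
inline double FloatTolerance(double max_abs_val, int depth) {
  return max_abs_val * std::numeric_limits<float>::epsilon() * 64 *
         std::sqrt(static_cast<double>(depth));
}

// True if the difference is out of tolerance OR is NaN.
inline bool OutOfTolerance(double diff, double tol) {
  return !(std::abs(diff) <= tol);
}

inline bool Demo() {
  const double nan = std::numeric_limits<double>::quiet_NaN();
  return OutOfTolerance(nan, 1.0);  // true: a NaN difference never passes.
}
}  // namespace example_agree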
- if (!(std::abs(diff) <= tolerated_max_diff)) { - return false; - } - } - } - double mean_diff = sum_diff / size; - if (std::abs(mean_diff) > tolerated_mean_diff) { - return false; - } - return true; -} - -template -bool Agree(const StorageMatrix& storage_matrix1, - const StorageMatrix& storage_matrix2, int depth) { - VerifyConsistentFields(storage_matrix1); - VerifyConsistentFields(storage_matrix2); - return Agree(storage_matrix1.matrix, storage_matrix2.matrix, depth); -} - -template -bool Agree(const TestResult& result1, const TestResult& result2, - int depth) { - return Agree(result1.storage_matrix, result2.storage_matrix, depth); -} - -struct Stats { - double median; - double mean; - double min; - double max; -}; - -inline std::string StatsAsString(const Stats& stats) { - char buf[256]; - snprintf(buf, sizeof(buf), "(median = %g, mean = %g, min = %g, max = %g)", - stats.median, stats.mean, stats.min, stats.max); - return std::string(buf); -} - -template -void GetMatrixStats(const Matrix& matrix, Stats* stats) { - double min = std::numeric_limits::infinity(); - double max = -std::numeric_limits::infinity(); - double sum = 0; - std::vector allvals; - for (int row = 0; row < matrix.layout.rows; row++) { - for (int col = 0; col < matrix.layout.cols; col++) { - double val = Element(matrix, row, col); - min = std::min(min, val); - max = std::max(max, val); - sum += val; - allvals.push_back(val); - } - } - std::sort(allvals.begin(), allvals.end()); - stats->min = min; - stats->max = max; - stats->mean = sum / allvals.size(); - stats->median = allvals[allvals.size() / 2]; -} - -struct ErrorAnalysis { - Stats stats_good; - Stats stats_bad; - // The below is to help document departure from bit exactness. It's probably - // not going to be relevant to floating-point. 
- std::set error_rows; - std::set error_cols; - int row_of_first_error = 0; - int col_of_first_error = 0; - double first_error_good_value = 0; - double first_error_bad_value = 0; -}; - -template -void AnalyzeTestError(const TestSetType& test_set, int first_bad_result_index, - ErrorAnalysis* error_analysis) { - const auto& good_matrix = test_set.results[0]->storage_matrix.matrix; - const auto& bad_matrix = - test_set.results[first_bad_result_index]->storage_matrix.matrix; - GetMatrixStats(good_matrix, &error_analysis->stats_good); - GetMatrixStats(bad_matrix, &error_analysis->stats_bad); - bool found_first_error = false; - for (int row = 0; row < good_matrix.layout.rows; row++) { - for (int col = 0; col < good_matrix.layout.cols; col++) { - if (Element(good_matrix, row, col) != Element(bad_matrix, row, col)) { - if (!found_first_error) { - found_first_error = true; - error_analysis->row_of_first_error = row; - error_analysis->col_of_first_error = col; - error_analysis->first_error_good_value = - Element(good_matrix, row, col); - error_analysis->first_error_bad_value = Element(bad_matrix, row, col); - } - error_analysis->error_rows.insert(row); - error_analysis->error_cols.insert(col); - } - } - } -} - -template -void ComputeReasonableMultiplier( - const Matrix& lhs, - const Matrix& rhs, double* multiplier) { - using LhsScalar = typename TestSetType::LhsScalar; - using RhsScalar = typename TestSetType::RhsScalar; - using DstScalar = typename TestSetType::DstScalar; - if (std::is_floating_point::value || - std::is_same::value) { - *multiplier = 0; - return; - } - *multiplier = static_cast(std::numeric_limits::max()) / - (static_cast(lhs.layout.cols) * - std::numeric_limits::max() * - std::numeric_limits::max()); -} - -inline void QuantizeMultiplier(double multiplier_double, - std::int32_t* multiplier_fixedpoint, - int* multiplier_exponent) { - RUY_CHECK_GT(multiplier_double, 0); - if (multiplier_double == 0.) { - *multiplier_fixedpoint = 0; - *multiplier_exponent = 0; - return; - } - const double q = std::frexp(multiplier_double, multiplier_exponent); - auto q_fixed = static_cast(std::round(q * (1ll << 31))); - RUY_CHECK_LE(q_fixed, (1ll << 31)); - if (q_fixed == (1ll << 31)) { - q_fixed /= 2; - ++*multiplier_exponent; - } - RUY_CHECK_LE(q_fixed, std::numeric_limits::max()); - *multiplier_fixedpoint = static_cast(q_fixed); -} - -template -void SwitchMultiplierToPerChannel(TestSetType* test_set) { - test_set->per_channel_multiplier_fixedpoint.resize(test_set->rows); - test_set->per_channel_multiplier_exponent.resize(test_set->rows); - for (int i = 0; i < test_set->rows; i++) { - // multipliers typically range in [2^30 ; 2^31 - 1]. - // Values in [0, 2^30 - 1] are normally unused, but harmless. - // Thus a good way to randomize multipliers is to subtract from them - // a random value smaller than 2^30 but still significant compared to it. 
- std::int32_t nudged_multiplier = test_set->spec.multiplier_fixedpoint - - (global_random_engine()() % (1 << 26)); - int nudged_exponent = - test_set->spec.multiplier_exponent - 1 + (global_random_engine()() % 4); - test_set->per_channel_multiplier_fixedpoint[i] = nudged_multiplier; - test_set->per_channel_multiplier_exponent[i] = nudged_exponent; - } - test_set->spec.multiplier_fixedpoint_perchannel = - test_set->per_channel_multiplier_fixedpoint.data(); - test_set->spec.multiplier_exponent_perchannel = - test_set->per_channel_multiplier_exponent.data(); - test_set->spec.multiplier_fixedpoint = 0; - test_set->spec.multiplier_exponent = 0; -} - -template < - typename TestSetType, - bool IsApplicable = - std::is_same::value && - !std::is_same::value> -struct MakeSpecMultiplierFieldsImpl {}; - -template -struct MakeSpecMultiplierFieldsImpl { - static void Run(TestSetType* test_set) { - double multiplier; - ComputeReasonableMultiplier(test_set->lhs.matrix, - test_set->rhs.matrix, &multiplier); - QuantizeMultiplier(multiplier, &test_set->spec.multiplier_fixedpoint, - &test_set->spec.multiplier_exponent); - if (!test_set->benchmark) { - test_set->perchannel = global_random_engine()() & 1; - } - if (test_set->perchannel) { - SwitchMultiplierToPerChannel(test_set); - } - } -}; - -template -struct MakeSpecMultiplierFieldsImpl { - static void Run(TestSetType* test_set) { - test_set->spec.multiplier_fixedpoint = 0; - test_set->spec.multiplier_exponent = 0; - } -}; - -template -void MakeSpecClampFields(Spec* spec) { - using AccumScalar = typename Spec::AccumScalar; - using DstScalar = typename Spec::DstScalar; - - if (std::is_same::value) { - // Returning raw accumulators, clamping is not supported. - spec->clamp_min = std::numeric_limits::lowest(); - spec->clamp_max = std::numeric_limits::max(); - return; - } - - if (getenv("BENCHMARK_ONLY_MATMUL")) { - if (std::is_floating_point::value) { - spec->clamp_min = -std::numeric_limits::infinity(); - spec->clamp_max = std::numeric_limits::infinity(); - } else { - spec->clamp_min = std::numeric_limits::lowest(); - spec->clamp_max = std::numeric_limits::max(); - } - return; - } - - spec->clamp_min = std::numeric_limits::lowest() + 1; - spec->clamp_max = std::numeric_limits::max() - 1; -} - -template -void TestSet::MakeZeroPoints() { - RUY_CHECK_EQ(life_stage, LifeStage::kInitial); - if (!benchmark && !use_specified_zero_points) { - MakeRandomScalar(RandomRange::kReasonableSrcZeroPoint, &lhs_zero_point); - MakeRandomScalar(RandomRange::kReasonableSrcZeroPoint, &rhs_zero_point); - // If destination is std::int32_t, no dst_zero_point is necessary. 
- if (std::is_same::value) { - dst_zero_point = 0; - } else { - MakeRandomScalar(RandomRange::kReasonableDstZeroPoint, &dst_zero_point); - } - } - life_stage = LifeStage::kHasZeroPoints; -} - -template -void TestSet::MakeLhsRhs() { - RUY_CHECK_EQ(life_stage, LifeStage::kHasZeroPoints); - MakeRandom(rows, depth, lhs_order, lhs_zero_point, layout_style, - RandomRange::kOffCenterAvoidMinValue, &lhs); - MakeRandom(depth, cols, rhs_order, rhs_zero_point, layout_style, - RandomRange::kGeneral, &rhs); - life_stage = LifeStage::kHasLhsRhs; -} - -template -void TestSet::MakeSpec() { - RUY_CHECK_EQ(life_stage, LifeStage::kHasLhsRhs); - - if (!getenv("BENCHMARK_ONLY_MATMUL") && - (benchmark || (global_random_engine()() & 1))) { - MakeRandomVector(RandomRange::kBias, rows, &bias_data); - spec.bias = bias_data.data(); - } - if (lhs.matrix.zero_point == std::numeric_limits::lowest() && - rhs.matrix.zero_point == std::numeric_limits::lowest()) { - lhs.matrix.zero_point += 1; - } - MakeSpecMultiplierFieldsImpl::Run(this); - MakeSpecClampFields(&spec); - life_stage = LifeStage::kHasSpec; -} - -inline int GetIntEnvVarOrZero(const char* name) { - const char* val = getenv(name); - if (!val) { - return 0; - } - return std::stoi(val); -} - -inline float GetFloatEnvVarOrZero(const char* name) { - const char* val = getenv(name); - if (!val) { - return 0; - } - return std::stof(val); -} - -inline int GetHexIntEnvVarOrZero(const char* name) { - const char* val = getenv(name); - if (!val) { - return 0; - } - return std::stoi(val, nullptr, 16); -} - -inline bool GetBoolEnvVarOrFalse(const char* name) { - return static_cast(GetIntEnvVarOrZero(name)); -} - -template -void TestSet::MakeOtherParams() { - RUY_CHECK_EQ(life_stage, LifeStage::kHasSpec); - if (max_num_threads == 0) { - max_num_threads = GetIntEnvVarOrZero("THREADS"); - } - life_stage = LifeStage::kHasOtherParams; -} - -inline std::vector PathsBitfieldAsVector(Path paths_bitfield) { - std::vector result; - std::uint32_t remaining_paths = static_cast(paths_bitfield); - std::uint32_t test_bit = 1; - while (remaining_paths) { - if (remaining_paths & test_bit) { - result.push_back(static_cast(test_bit)); - } - remaining_paths &= ~test_bit; - test_bit <<= 1; - } - return result; -} - -inline std::vector EnumerateTuningsForPath(Path path, bool benchmark) { - if (benchmark) { - return {Tuning::kAuto}; - } -#if RUY_PLATFORM(ARM) - if (path == Path::kNeon || path == Path::kNeonDotprod) { - return {Tuning::kInOrder, Tuning::kOutOfOrder, Tuning::kAuto}; - } -#endif - return {Tuning::kAuto}; -} - -template -void TestSet::MakePrepackedMatrices() { - RUY_CHECK_EQ(life_stage, LifeStage::kHasResultPaths); - - // Prepacked matrices are Path-dependent, so create them for each test result. - for (auto& result : results) { - // If this result uses an external path, then skip this entirely. - if (result->path == Path::kNone) { - continue; - } - // Pre-packing doesn't make sense for Path::kReference. - // TODO(silvasean): Make Path::kReference an ExternalPath? - if (result->path == Path::kReference) { - continue; - } - - // Determine whether we should create/use prepacked matrices. - if (benchmark) { - // For benchmarking, do as requested. - result->use_prepacked_lhs = benchmark_prepack_lhs; - result->use_prepacked_rhs = benchmark_prepack_rhs; - } else { - // When testing, randomly pre-pack sometimes. But don't do it too often. 
- result->use_prepacked_lhs = (global_random_engine()() & 7) == 0; - result->use_prepacked_rhs = (global_random_engine()() & 7) == 0; - } - - // Create the pre-packed matrices. - PrepackedMatrix* prepacked_lhs_ptr = - result->use_prepacked_lhs ? &result->prepacked_lhs : nullptr; - PrepackedMatrix* prepacked_rhs_ptr = - result->use_prepacked_rhs ? &result->prepacked_rhs : nullptr; - auto alloc_fn = [&result](std::size_t num_bytes) { - return result->allocator.AllocateBytes(num_bytes); - }; - // Use a dst with a null data pointer to check that the pre-packing - // invocation doesn't write into it. - Matrix null_data_dst = result->storage_matrix.matrix; - null_data_dst.data = nullptr; - GlobalContext().SetRuntimeEnabledPaths(result->path); - PrePackForMul(lhs.matrix, rhs.matrix, spec, &GlobalContext(), - &null_data_dst, prepacked_lhs_ptr, - prepacked_rhs_ptr, alloc_fn); - RUY_CHECK_EQ(GlobalContext().last_taken_path, result->path); - } - - life_stage = LifeStage::kHasPrepackedMatrices; -} - -template -void TestSet::MakeResultPaths() { - RUY_CHECK_EQ(life_stage, LifeStage::kHasOtherParams); - - Path paths_bitfield = static_cast(GetHexIntEnvVarOrZero("PATHS")); - - if (paths_bitfield == Path::kNone) { - // Use a dummy Context just to perform the resolution of specific runtime - // enabled paths. - Context context; - paths_bitfield = context.GetRuntimeEnabledPaths(); - } - - // Trim bits that don't correspond to a compiled path, - // to allow specifying e.g. ffff to mean 'all paths' regardless of whether all - // those bits exist as actual paths. - paths_bitfield = paths_bitfield & kAllPaths; - RUY_CHECK_NE(paths_bitfield, Path::kNone); - paths = PathsBitfieldAsVector(paths_bitfield); - -#ifdef RUY_TEST_EXTERNAL_PATHS - - using TestSetType = TestSet; - - if (!GetBoolEnvVarOrFalse("NOEXT")) { - if (SupportsGemmlowp::kValue) { -#ifdef GEMMLOWP_SSE4 - const bool gemmlowp_supported = !spec.multiplier_fixedpoint_perchannel; -#else - const bool gemmlowp_supported = true; -#endif - if (gemmlowp_supported) { - external_paths.push_back(ExternalPath::kGemmlowp); - } - } - if (UsesSingleScalarType::kValue && - std::is_floating_point::value) { - external_paths.push_back(ExternalPath::kEigen); - if (layout_style == LayoutStyle::kPackedLinear) { - external_paths.push_back(ExternalPath::kEigenTensor); - } -// We link against a generic BLAS target that only maps to OpenBLAS on specific -// architectures. -#if RUY_PLATFORM(ARM_32) || RUY_PLATFORM(ARM_64) - // OpenBLAS multi-threading is disabled, so avoid mixing single-threaded - // and multi-threaded benchmark results. 
- if (max_num_threads == 1 && !getenv("NO_OPENBLAS")) { - external_paths.push_back(ExternalPath::kOpenBlas); - } -#endif - } - } - -#endif // RUY_TEST_EXTERNAL_PATHS - - for (Path path : paths) { - for (Tuning tuning : EnumerateTuningsForPath(path, benchmark)) { - results.emplace_back(new TestResultType); - TestResultType& result = *results.back(); - result.path = path; - result.tuning = tuning; - MakeRandom(rows, cols, dst_order, dst_zero_point, layout_style, - RandomRange::kGeneral, &result.storage_matrix); - } - } - - for (ExternalPath external_path : external_paths) { - results.emplace_back(new TestResultType); - TestResultType& result = *results.back(); - result.external_path = external_path; - MakeRandom(rows, cols, dst_order, dst_zero_point, layout_style, - RandomRange::kGeneral, &result.storage_matrix); - } - - life_stage = LifeStage::kHasResultPaths; -} - -template -void TestSet::EvalResult( - TestResult* result) { - RUY_CHECK(result->path != Path::kNone || - result->external_path != ExternalPath::kNone); - if (result->path != Path::kNone) { - EvalRuy(result); - } else { -#ifdef RUY_TEST_EXTERNAL_PATHS - using TestSetType = TestSet; - EvalExternalPath(this, result); -#endif - } - const std::string& pathname = PathName(*result); - if (std::find(CoveredPaths()->begin(), CoveredPaths()->end(), pathname) == - CoveredPaths()->end()) { - CoveredPaths()->push_back(pathname); - } -} - -using f32 = float; -using f64 = double; -using u8 = std::uint8_t; -using i8 = std::int8_t; -using u16 = std::uint16_t; -using i16 = std::int16_t; -using u32 = std::uint32_t; -using i32 = std::int32_t; -using u64 = std::uint64_t; -using i64 = std::int64_t; - -template -const char* TypeName() { - return nullptr; -} - -#define RUY_TYPENAME(TYPE) \ - template <> \ - const char* TypeName() { \ - return #TYPE; \ - } - -RUY_TYPENAME(f32) -RUY_TYPENAME(f64) -RUY_TYPENAME(u8) -RUY_TYPENAME(i8) -RUY_TYPENAME(u16) -RUY_TYPENAME(i16) -RUY_TYPENAME(u32) -RUY_TYPENAME(i32) -RUY_TYPENAME(u64) -RUY_TYPENAME(i64) - -#undef RUY_TYPENAME - -template -const char* SymmetryName(const Matrix& matrix) { - if (matrix.zero_point == SymmetricZeroPoint()) { - return "symm"; - } else { - return "asymm"; - } -} - -template -int StorageSize(const Matrix& matrix) { - return sizeof(Scalar) * FlatSize(matrix.layout); -} - -// Helper that replicates a buffer and gives out pointers to the replicas. -// This is useful when one wants to traverse data so that it is cold in cache. -// By having a sufficiently large value of num_repeats, one can ensure that the -// working set covered by the replicas is greater than the cache size. 
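// Aside on the comment above: if each benchmark iteration touches
// bytes_per_set bytes and the rotating working set should exceed the largest
// cache, at least ceil(target / bytes_per_set) replicas are needed. The
// cold-cache benchmark below targets a 100 MiB working set; a standalone
// sketch of that arithmetic (NumReplicasFor is an illustrative name):
#include <cstddef>

namespace example_coldcache {
// Requires bytes_per_set > 0.
inline int NumReplicasFor(std::size_t bytes_per_set,
                          std::size_t target_working_set = 100u << 20) {
  return static_cast<int>((target_working_set + bytes_per_set - 1) /
                          bytes_per_set);
}
}  // namespace example_coldcache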
-template -class RepeatedBuffer { - public: - RepeatedBuffer() = default; - void Init(const T* elems, std::size_t num_elems, int num_repeats) { - buffers_.clear(); - allocator_.FreeAll(); - for (int i = 0; i < num_repeats; i++) { - T* p; - allocator_.Allocate(num_elems, &p); - memcpy(p, elems, num_elems * sizeof(T)); - buffers_.push_back(p); - } - } - T* Next() { - T* ret = buffers_[current_]; - current_ = (current_ + 1) % buffers_.size(); - return ret; - } - - private: - Allocator allocator_; - std::vector buffers_; - int current_ = 0; -}; - -template -void TestSet::Benchmark( - TestResult* result) { - using DstScalar = typename SpecType::DstScalar; - - const bool cold = getenv("RUY_BENCHMARK_COLD"); - LhsScalar* orig_lhs_data = lhs.matrix.data.get(); - RhsScalar* orig_rhs_data = rhs.matrix.data.get(); - DstScalar* orig_dst_data = result->storage_matrix.matrix.data.get(); - void* orig_prepacked_lhs_data = result->prepacked_lhs.data; - void* orig_prepacked_rhs_data = result->prepacked_rhs.data; - - int num_matmul_sets = 0; - - RepeatedBuffer cold_lhs; - RepeatedBuffer cold_rhs; - RepeatedBuffer cold_dst; - RepeatedBuffer cold_prepacked_lhs; - RepeatedBuffer cold_prepacked_rhs; - - if (cold) { - const int kWorkingSetSize = 100 << 20; - const int each_matmul_set_size = StorageSize(lhs.matrix) + - StorageSize(rhs.matrix) + - StorageSize(result->storage_matrix.matrix); - num_matmul_sets = - (kWorkingSetSize + each_matmul_set_size - 1) / each_matmul_set_size; - - cold_lhs.Init(lhs.matrix.data.get(), FlatSize(lhs.matrix.layout), - num_matmul_sets); - cold_rhs.Init(rhs.matrix.data.get(), FlatSize(rhs.matrix.layout), - num_matmul_sets); - cold_dst.Init(result->storage_matrix.matrix.data.get(), - FlatSize(result->storage_matrix.matrix.layout), - num_matmul_sets); - if (benchmark_prepack_lhs) { - cold_prepacked_lhs.Init(static_cast(result->prepacked_lhs.data), - result->prepacked_lhs.data_size, num_matmul_sets); - } - if (benchmark_prepack_rhs) { - cold_prepacked_rhs.Init(static_cast(result->prepacked_rhs.data), - result->prepacked_rhs.data_size, num_matmul_sets); - } - } - const bool record_pmu = GetBoolEnvVarOrFalse("RUY_BENCHMARK_PMU"); - int repeats = GetIntEnvVarOrZero("RUY_BENCHMARK_REPEATS"); - if (!repeats) { - repeats = 4; - } - float benchmark_min_secs = GetFloatEnvVarOrZero("RUY_BENCHMARK_MIN_SECS"); - if (!benchmark_min_secs) { - benchmark_min_secs = 0.5; - } -#ifdef RUY_PROFILER - { - const char* lhstype = TypeName(); - const char* lhssymm = SymmetryName(lhs.matrix); - const char* rhstype = TypeName(); - const char* rhssymm = SymmetryName(rhs.matrix); - - printf("Profiling path=%s shape=(%dx%dx%d) lhs=(%s,%s) rhs=(%s,%s)\n", - PathName(*result).c_str(), rows, depth, cols, lhstype, lhssymm, - rhstype, rhssymm); - ruy::profiler::ScopeProfile profile; -#endif - - float latency = std::numeric_limits::infinity(); - float l1_refill_rate = std::numeric_limits::infinity(); - float l2_refill_rate = std::numeric_limits::infinity(); - float l3_refill_rate = std::numeric_limits::infinity(); - float l1tlb_refill_rate = std::numeric_limits::infinity(); - float l2tlb_refill_rate = std::numeric_limits::infinity(); - float mispred_rate = std::numeric_limits::infinity(); - float frontend_stall_rate = std::numeric_limits::infinity(); - float backend_stall_rate = std::numeric_limits::infinity(); - - for (int repeat = 0; repeat < repeats; repeat++) { - auto& pmu_events = GlobalPmuEvents(); - if (record_pmu) { - pmu_events.StartRecording(); - } - TimePoint time_start = Now(); - TimePoint t = time_start; - 
int iters = 0; - int iters_at_a_time = 1; - while (ToFloatSeconds(t - time_start) < benchmark_min_secs) { - for (int i = 0; i < iters_at_a_time; i++) { - if (cold) { - lhs.matrix.data = cold_lhs.Next(); - rhs.matrix.data = cold_rhs.Next(); - result->storage_matrix.matrix.data = cold_dst.Next(); - if (benchmark_prepack_lhs) { - result->prepacked_lhs.data = cold_prepacked_lhs.Next(); - } - if (benchmark_prepack_rhs) { - result->prepacked_rhs.data = cold_prepacked_rhs.Next(); - } - } - EvalResult(result); - iters++; - } - iters_at_a_time *= 2; - t = Now(); - } - latency = std::min( - latency, static_cast(ToFloatSeconds(t - time_start) / iters)); - if (record_pmu) { - pmu_events.StopRecording(); - const float normalization_factor = - 1.0f / (static_cast(iters) * rows * cols * depth); - l1_refill_rate = std::min( - l1_refill_rate, pmu_events.L1RefillCount() * normalization_factor); - l2_refill_rate = std::min( - l2_refill_rate, pmu_events.L2RefillCount() * normalization_factor); - l3_refill_rate = std::min( - l3_refill_rate, pmu_events.L3RefillCount() * normalization_factor); - l1tlb_refill_rate = - std::min(l1tlb_refill_rate, - pmu_events.L1TLBRefillCount() * normalization_factor); - l2tlb_refill_rate = - std::min(l2tlb_refill_rate, - pmu_events.L2TLBRefillCount() * normalization_factor); - mispred_rate = - std::min(mispred_rate, pmu_events.BranchMispredictionCount() * - normalization_factor); - frontend_stall_rate = - std::min(frontend_stall_rate, - pmu_events.FrontendStallCount() * normalization_factor); - backend_stall_rate = - std::min(backend_stall_rate, - pmu_events.BackendStallCount() * normalization_factor); - } - } - result->latency = latency; - if (record_pmu) { - result->l1_refill_rate = l1_refill_rate; - result->l2_refill_rate = l2_refill_rate; - result->l3_refill_rate = l3_refill_rate; - result->l1tlb_refill_rate = l1tlb_refill_rate; - result->l2tlb_refill_rate = l2tlb_refill_rate; - result->mispred_rate = mispred_rate; - result->frontend_stall_rate = frontend_stall_rate; - result->backend_stall_rate = backend_stall_rate; - } - -#ifdef RUY_PROFILER - } - fflush(stdout); -#endif - - if (cold) { - lhs.matrix.data = orig_lhs_data; - rhs.matrix.data = orig_rhs_data; - memcpy(orig_dst_data, result->storage_matrix.matrix.data.get(), - StorageSize(result->storage_matrix.matrix)); - result->storage_matrix.matrix.data = orig_dst_data; - result->prepacked_lhs.data = orig_prepacked_lhs_data; - result->prepacked_rhs.data = orig_prepacked_rhs_data; - } -} - -template -void TestSet::Eval() { - RUY_CHECK_EQ(life_stage, LifeStage::kHasPrepackedMatrices); - for (auto& result : results) { - if (benchmark) { - Benchmark(result.get()); - } else { - EvalResult(result.get()); - } - } - life_stage = LifeStage::kEvaluated; -} - -template -std::string DumpRegion(const Matrix& matrix, int center_row, - int center_col) { - static constexpr int kRadius = 20; - int first_row = std::max(0, center_row - kRadius); - int last_row = std::min(matrix.layout.rows - 1, center_row + kRadius); - int first_col = std::max(0, center_col - kRadius); - int last_col = std::min(matrix.layout.cols - 1, center_col + kRadius); - std::ostringstream stream; - for (int row = first_row; row <= last_row; row++) { - for (int col = first_col; col <= last_col; col++) { - stream << static_cast(Element(matrix, row, col)) << " "; - } - stream << "\n"; - } - return stream.str(); -} - -template -void TestSet::VerifyTestResults() const { - const int depth = lhs.matrix.layout.cols; - for (int i = 0; i < results.size() - 1; i++) { - if 
(!Agree(*results[i], *results[i + 1], depth)) { - std::string paths_in_agreement; - paths_in_agreement.append(PathName(*results[0])); - for (int j = 1; j <= i; j++) { - paths_in_agreement.append(", "); - paths_in_agreement.append(PathName(*results[j])); - } - ErrorAnalysis error_analysis; - AnalyzeTestError(*this, i + 1, &error_analysis); - std::cerr << "Error: path (" << PathName(*results[i + 1]) - << ") disagrees with the other paths (" << paths_in_agreement - << "), which agree with each other." << std::endl; - std::cerr << "Shape: rows = " << rows << ", cols = " << cols - << ", depth = " << depth << std::endl; - std::cerr << "Stats of the good result matrix: " - << StatsAsString(error_analysis.stats_good) << std::endl; - std::cerr << "Stats of the bad result matrix: " - << StatsAsString(error_analysis.stats_bad) << std::endl; - if (error_analysis.error_rows.size() < rows) { - std::cerr << "Rows containing errors: " - << Join(error_analysis.error_rows) << std::endl; - } else { - std::cerr << "Errors found in ALL rows." << std::endl; - } - if (error_analysis.error_cols.size() < cols) { - std::cerr << "Cols containing errors: " - << Join(error_analysis.error_cols) << std::endl; - } else { - std::cerr << "Errors found in ALL cols." << std::endl; - } - std::cerr << "The first error occurs at row " - << error_analysis.row_of_first_error << ", col " - << error_analysis.col_of_first_error << std::endl; - std::cerr << "Good value: " << error_analysis.first_error_good_value - << std::endl; - std::cerr << "Bad value : " << error_analysis.first_error_bad_value - << std::endl; - std::cerr << "Region of Good result matrix around first error:\n\n" - << DumpRegion(results[0]->storage_matrix.matrix, - error_analysis.row_of_first_error, - error_analysis.col_of_first_error) - << std::endl; - std::cerr << "Region of Bad result matrix around first error:\n\n" - << DumpRegion(results[i + 1]->storage_matrix.matrix, - error_analysis.row_of_first_error, - error_analysis.col_of_first_error) - << std::endl; - RUY_CHECK(false); - } - } -} - -template -void TestSet::Verify() { - RUY_CHECK_EQ(life_stage, LifeStage::kEvaluated); - if (expected_outcome == ExpectedOutcome::kSuccess) { - VerifyTestResults(); - } - life_stage = LifeStage::kFinal; -} - -template -void TestRCC(int rows, int depth, int cols, ExpectedOutcome expected_outcome) { - TestSetType test_set; - test_set.rows = rows; - test_set.depth = depth; - test_set.cols = cols; - test_set.lhs_order = Order::kRowMajor; - test_set.rhs_order = Order::kColMajor; - test_set.dst_order = Order::kColMajor; - test_set.layout_style = LayoutStyle::kPackedLinear; - test_set.expected_outcome = expected_outcome; - test_set.Run(); -} - -template -void TestRCC(int rows, int depth, int cols) { - TestRCC(rows, depth, cols, ExpectedOutcome::kSuccess); -} - -template -void TestNonRCC(int rows, int depth, int cols, - ExpectedOutcome expected_outcome) { - TestSetType test_set; - test_set.rows = rows; - test_set.depth = depth; - test_set.cols = cols; - test_set.lhs_order = Order::kColMajor; - test_set.rhs_order = Order::kColMajor; - test_set.dst_order = Order::kColMajor; - test_set.layout_style = LayoutStyle::kPackedLinear; - test_set.expected_outcome = expected_outcome; - test_set.Run(); -} - -template -void TestLinearAllOrders(int rows, int depth, int cols, - ExpectedOutcome expected_outcome) { - const std::vector orders{Order::kColMajor, Order::kRowMajor}; - - for (Order lhs_order : orders) { - for (Order rhs_order : orders) { - for (Order dst_order : orders) { - TestSetType 
test_set; - test_set.rows = rows; - test_set.depth = depth; - test_set.cols = cols; - test_set.lhs_order = lhs_order; - test_set.rhs_order = rhs_order; - test_set.dst_order = dst_order; - test_set.layout_style = LayoutStyle::kLinear; - test_set.expected_outcome = expected_outcome; - test_set.Run(); - } - } - } -} - -template -void TestLinearAllOrders(int rows, int depth, int cols) { - TestLinearAllOrders(rows, depth, cols, - ExpectedOutcome::kSuccess); -} - -} // namespace ruy - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_TEST_H_ diff --git a/tensorflow/lite/experimental/ruy/ruy/test_fast.cc b/tensorflow/lite/experimental/ruy/ruy/test_fast.cc deleted file mode 100644 index 6b7026530ac..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/test_fast.cc +++ /dev/null @@ -1,110 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// This test contains cheap test cases, completes in a few seconds. - -#include - -#include "tensorflow/lite/experimental/ruy/ruy/test.h" - -namespace ruy { - -using LhsScalar = RUY_TEST_LHSSCALAR; -using RhsScalar = RUY_TEST_RHSSCALAR; -using AccumScalar = RUY_TEST_ACCUMSCALAR; -using DstScalar = RUY_TEST_DSTSCALAR; - -using TestSetType = - TestSet>; - -TEST(RuyTest, TestSquareMuls) { - const std::vector sizes{ - // small sizes - 1, - 2, - 3, - 4, - 5, - 6, - 7, - 8, - 9, - 10, - // multiplies of 16 - 16, - 32, - 48, - 64, - // pot-minus-1 sizes - 15, - 31, - 63, - // pot-plus-1 sizes - 17, - 33, - 65, - }; - - for (int size : sizes) { - TestRCC(size, size, size); - TestLinearAllOrders(size, size, size); - } -} - -TEST(RuyTest, TestMiscMuls) { - const int shapes[][3] = { - {2, 3, 4}, {7, 6, 5}, {12, 23, 6}, {19, 3, 11}, {3, 10, 17}, - {30, 21, 43}, {7, 57, 9}, {49, 69, 71}, {38, 111, 29}, {87, 98, 76}, - {16, 96, 16}, {16, 88, 16}, {16, 84, 16}, {16, 92, 16}, {16, 82, 16}, - {16, 81, 16}, {16, 95, 16}, {3, 128, 5}}; - for (const auto& shape : shapes) { - TestLinearAllOrders(shape[0], shape[1], shape[2]); - } -} - -TEST(RuyTest, TestDeepMuls) { - // TODO(b/137649322): clarify what's the max allowed matrix size. 
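// Aside on the size list in TestSquareMuls above: it is written out by hand so
// failures are easy to reproduce, but its intent is systematic -- small sizes,
// exact multiples of a typical 16-wide kernel block, and power-of-two +/- 1
// sizes that straddle block boundaries. A sketch that generates an equivalent
// set (illustrative only, not part of the test):
#include <set>

namespace example_sizes {
inline std::set<int> InterestingSizes(int max_size = 65) {
  std::set<int> sizes;
  for (int s = 1; s <= 10; ++s) sizes.insert(s);             // small sizes
  for (int s = 16; s <= max_size; s += 16) sizes.insert(s);  // multiples of 16
  for (int p = 16; p <= max_size; p *= 2) {
    sizes.insert(p - 1);  // power-of-two minus one
    sizes.insert(p + 1);  // power-of-two plus one
  }
  return sizes;
}
}  // namespace example_sizes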
- TestRCC(1, 32767, 1); - TestLinearAllOrders(5, 5001, 4); - TestLinearAllOrders(9, 1025, 10); -} - -TEST(RuyTest, TestShallowMuls) { - TestLinearAllOrders(101, 1, 103); - TestLinearAllOrders(71, 2, 53); - TestLinearAllOrders(51, 3, 73); - TestLinearAllOrders(51, 4, 43); -} - -TEST(RuyTest, TestNarrowMuls) { - for (int width : {1, 2, 3, 4, 5, 8}) { - TestLinearAllOrders(width, 12, 13); - TestLinearAllOrders(15, 19, width); - TestLinearAllOrders(width, 123, 137); - TestLinearAllOrders(158, 119, width); - } -} - -TEST(RuyTest, TestGEMV) { - for (int size = 1; size < 1024; size *= 2) { - for (int depth = 1; depth < 500; depth += 47) { - TestLinearAllOrders(size, depth, 1); - } - } - TestLinearAllOrders(5, 5001, 1); - TestLinearAllOrders(8193, 17, 1); -} - -} // namespace ruy diff --git a/tensorflow/lite/experimental/ruy/ruy/test_slow.cc b/tensorflow/lite/experimental/ruy/ruy/test_slow.cc deleted file mode 100644 index 7e7292cd503..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/test_slow.cc +++ /dev/null @@ -1,71 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// This test contains more expensive test cases. - -#include "tensorflow/lite/experimental/ruy/ruy/test.h" - -namespace ruy { - -using LhsScalar = RUY_TEST_LHSSCALAR; -using RhsScalar = RUY_TEST_RHSSCALAR; -using AccumScalar = RUY_TEST_ACCUMSCALAR; -using DstScalar = RUY_TEST_DSTSCALAR; - -using TestSetType = - TestSet>; - -TEST(RuyTest, TestBigNarrowMuls) { - for (int width : {1, 2, 3, 4, 5, 8}) { - TestRCC(width, 401, 601); - TestRCC(587, 443, width); - } - TestRCC(7, 45984, - 5); // Large enough to trigger row-sum overflows. - TestRCC(512, 256, 16); -} - -TEST(RuyTest, TestBigShallowMuls) { - TestLinearAllOrders(501, 1, 321); - TestLinearAllOrders(301, 5, 403); - TestLinearAllOrders(256, 32, 512); -} - -TEST(RuyTest, TestBigMuls) { - TestRCC(225, 303, 199); - TestLinearAllOrders(256, 192, 128); -} - -TEST(RuyTest, TestBigPowerOfTwoDepthWithAvoidAliasing) { - // Important to test some power-of-two depths: that's when the - // RUY_AVOID_ALIASING optimization kicks in and makes packed matrices - // strided, exposing bugs in kernels mixing up size and stride. - // Moreover, it's important that the test matrices be sufficiently wide - // that they will result in multiple blocks, exposing bugs in the - // computation of the base address of each block. 
- TestLinearAllOrders(70, 1024, 80); - TestLinearAllOrders(60, 2048, 70); - TestLinearAllOrders(40, 4096, 50); -} - -TEST(RuyTest, TestGEMV) { - for (int size = 1025; size <= 1409; size += 384) { - for (int depth = 350; depth < 500; depth += 47) { - TestLinearAllOrders(size, depth, 1); - } - } -} - -} // namespace ruy diff --git a/tensorflow/lite/experimental/ruy/ruy/test_special_specs.cc b/tensorflow/lite/experimental/ruy/ruy/test_special_specs.cc deleted file mode 100644 index 6f5a88c833a..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/test_special_specs.cc +++ /dev/null @@ -1,163 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// This test covers non-basic specs. - -#include "tensorflow/lite/experimental/ruy/ruy/test.h" - -namespace ruy { - -template -struct LoopStructureSpec : BasicSpec { - static constexpr LoopStructure kLoopStructure = tLoopStructure; -}; - -template -struct ZeroPointSupportSpec : BasicSpec { - static constexpr ZeroPointSupport kZeroPointSupport = tZeroPointSupport; -}; - -template -struct RCCSpec : BasicSpec { - static constexpr LayoutSupport kLayoutSupport = LayoutSupport::kRCC; -}; - -template -struct StandardCppKernelLayoutSpec : BasicSpec { - using StandardCppKernelLhsLayout = LhsKernelLayout; - using StandardCppKernelRhsLayout = RhsKernelLayout; - static int local_data_cache_size() { return 1; } - static int shared_data_cache_size() { return 1; } -}; - -using LhsScalar = RUY_TEST_LHSSCALAR; -using RhsScalar = RUY_TEST_RHSSCALAR; -using AccumScalar = RUY_TEST_ACCUMSCALAR; -using DstScalar = RUY_TEST_DSTSCALAR; - -template -void TestLoopStructure() { - using SpecType = LoopStructureSpec; - using TestSetType = TestSet; - for (int size = 1; size < 10; size++) { - TestLinearAllOrders(size, size, size); - } - TestLinearAllOrders(3, 5, 78); - TestLinearAllOrders(19, 91, 7); - TestLinearAllOrders(71, 26, 44); - TestLinearAllOrders(81, 93, 72); -} - -TEST(TestSpecialSpecs, LoopStructure) { - static_assert(BasicSpec::kLoopStructure == - LoopStructure::kAuto, - ""); - static_assert(BasicSpec::kLoopStructure == LoopStructure::kAuto, - ""); - TestLoopStructure(); - TestLoopStructure(); -} - -template -void TestZeroPointSupport(LhsScalar lhs_zero_point, RhsScalar rhs_zero_point, - DstScalar dst_zero_point, - ExpectedOutcome expected_outcome) { - using SpecType = - ZeroPointSupportSpec; - using TestSetType = TestSet; - TestSetType test_set; - test_set.rows = 11; - test_set.depth = 12; - test_set.cols = 13; - test_set.lhs_order = Order::kRowMajor; - test_set.rhs_order = Order::kColMajor; - test_set.dst_order = Order::kColMajor; - test_set.layout_style = LayoutStyle::kPackedLinear; - test_set.expected_outcome = expected_outcome; - test_set.lhs_zero_point = lhs_zero_point; - test_set.rhs_zero_point = rhs_zero_point; - test_set.dst_zero_point = dst_zero_point; - test_set.use_specified_zero_points = true; - test_set.Run(); -} - -TEST(TestSpecialSpecs, 
ZeroPointSupport) { - // Sanity check - RUY_CHECK_EQ(SymmetricZeroPoint(), 128); - RUY_CHECK_EQ(SymmetricZeroPoint(), 0); - - if (std::is_floating_point::value) { - return; - } - - TestZeroPointSupport( - SymmetricZeroPoint(), SymmetricZeroPoint(), - SymmetricZeroPoint(), ExpectedOutcome::kSuccess); - TestZeroPointSupport( - SymmetricZeroPoint() - 1, SymmetricZeroPoint(), - SymmetricZeroPoint(), ExpectedOutcome::kSuccess); - TestZeroPointSupport( - SymmetricZeroPoint(), SymmetricZeroPoint(), - SymmetricZeroPoint(), ExpectedOutcome::kSuccess); - TestZeroPointSupport( - SymmetricZeroPoint() + 1, SymmetricZeroPoint(), - SymmetricZeroPoint(), ExpectedOutcome::kDeath); - TestZeroPointSupport( - SymmetricZeroPoint(), SymmetricZeroPoint() + 1, - SymmetricZeroPoint(), ExpectedOutcome::kDeath); - TestZeroPointSupport( - SymmetricZeroPoint(), SymmetricZeroPoint(), - SymmetricZeroPoint() - 1, ExpectedOutcome::kDeath); -} - -TEST(TestSpecialSpecs, RCC) { - using RCCSpec = RCCSpec; - using RCCTestSet = TestSet; - TestRCC(81, 93, 72); - TestNonRCC(81, 93, 72, ExpectedOutcome::kDeath); -} - -template -void TestStandardCppKernelLayout() { - using SpecType = - StandardCppKernelLayoutSpec; - using TestSetType = TestSet; - for (int size = 1; size < 10; size++) { - TestLinearAllOrders(size, size, size); - } - TestLinearAllOrders(87, 34, 56); - TestLinearAllOrders(123, 234, 78); -} - -TEST(TestSpecialSpecs, StandardCppKernelLayoutTrivial1x1) { - TestStandardCppKernelLayout, - FixedKernelLayout>(); -} - -TEST(TestSpecialSpecs, StandardCppKernelLayoutSquare4x4) { - TestStandardCppKernelLayout, - FixedKernelLayout>(); -} - -TEST(TestSpecialSpecs, StandardCppKernelLayoutRectangular4x8) { - TestStandardCppKernelLayout, - FixedKernelLayout>(); -} - -} // namespace ruy diff --git a/tensorflow/lite/experimental/ruy/ruy/thread_pool.cc b/tensorflow/lite/experimental/ruy/ruy/thread_pool.cc deleted file mode 100644 index eb86a1fbf38..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/thread_pool.cc +++ /dev/null @@ -1,200 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/lite/experimental/ruy/ruy/thread_pool.h" - -#include -#include // NOLINT(build/c++11) -#include // NOLINT(build/c++11) -#include -#include -#include -#include // NOLINT(build/c++11) -#include // NOLINT(build/c++11) - -#include "tensorflow/lite/experimental/ruy/ruy/check_macros.h" -#include "tensorflow/lite/experimental/ruy/ruy/wait.h" - -namespace ruy { - -// A worker thread. -class Thread { - public: - enum class State { - Startup, // The initial state before the thread main loop runs. - Ready, // Is not working, has not yet received new work to do. - HasWork, // Has work to do. - ExitAsSoonAsPossible // Should exit at earliest convenience. 
- }; - - explicit Thread(BlockingCounter* counter_to_decrement_when_ready) - : task_(nullptr), - state_(State::Startup), - counter_to_decrement_when_ready_(counter_to_decrement_when_ready) { - thread_.reset(new std::thread(ThreadFunc, this)); - } - - ~Thread() { - ChangeState(State::ExitAsSoonAsPossible); - thread_->join(); - } - - // Changes State; may be called from either the worker thread - // or the master thread; however, not all state transitions are legal, - // which is guarded by assertions. - // - // The Task argument is to be used only with new_state==HasWork. - // It specifies the Task being handed to this Thread. - void ChangeState(State new_state, Task* task = nullptr) { - state_mutex_.lock(); - State old_state = state_.load(std::memory_order_relaxed); - RUY_DCHECK_NE(old_state, new_state); - switch (old_state) { - case State::Startup: - RUY_DCHECK_EQ(new_state, State::Ready); - break; - case State::Ready: - RUY_DCHECK(new_state == State::HasWork || - new_state == State::ExitAsSoonAsPossible); - break; - case State::HasWork: - RUY_DCHECK(new_state == State::Ready || - new_state == State::ExitAsSoonAsPossible); - break; - default: - abort(); - } - switch (new_state) { - case State::Ready: - if (task_) { - // Doing work is part of reverting to 'ready' state. - task_->Run(); - task_ = nullptr; - } - break; - case State::HasWork: - RUY_DCHECK(!task_); - task_ = task; - break; - default: - break; - } - state_.store(new_state, std::memory_order_relaxed); - state_cond_.notify_all(); - state_mutex_.unlock(); - if (new_state == State::Ready) { - counter_to_decrement_when_ready_->DecrementCount(); - } - } - - static void ThreadFunc(Thread* arg) { arg->ThreadFuncImpl(); } - - // Called by the master thead to give this thread work to do. - void StartWork(Task* task) { ChangeState(State::HasWork, task); } - - private: - // Thread entry point. - void ThreadFuncImpl() { - ChangeState(State::Ready); - - // Thread main loop - while (true) { - // In the 'Ready' state, we have nothing to do but to wait until - // we switch to another state. - const auto& condition = [this]() { - return state_.load(std::memory_order_acquire) != State::Ready; - }; - Wait(condition, &state_cond_, &state_mutex_); - - // Act on new state. - switch (state_.load(std::memory_order_acquire)) { - case State::HasWork: - // Got work to do! So do it, and then revert to 'Ready' state. - ChangeState(State::Ready); - break; - case State::ExitAsSoonAsPossible: - return; - default: - abort(); - } - } - } - - // The underlying thread. - std::unique_ptr thread_; - - // The task to be worked on. - Task* task_; - - // The condition variable and mutex guarding state changes. - std::condition_variable state_cond_; - std::mutex state_mutex_; - - // The state enum tells if we're currently working, waiting for work, etc. - // Its concurrent accesses by the thread and main threads are guarded by - // state_mutex_, and can thus use memory_order_relaxed. This still needs - // to be a std::atomic because we use WaitForVariableChange. - std::atomic state_; - - // pointer to the master's thread BlockingCounter object, to notify the - // master thread of when this thread switches to the 'Ready' state. - BlockingCounter* const counter_to_decrement_when_ready_; -}; - -void ThreadPool::ExecuteImpl(int task_count, int stride, Task* tasks) { - RUY_DCHECK_GE(task_count, 1); - - // Case of 1 thread: just run the single task on the current thread. 
- if (task_count == 1) { - (tasks + 0)->Run(); - return; - } - - // Task #0 will be run on the current thread. - CreateThreads(task_count - 1); - counter_to_decrement_when_ready_.Reset(task_count - 1); - for (int i = 1; i < task_count; i++) { - auto task_address = reinterpret_cast(tasks) + i * stride; - threads_[i - 1]->StartWork(reinterpret_cast(task_address)); - } - - // Execute task #0 immediately on the current thread. - (tasks + 0)->Run(); - - // Wait for the threads submitted above to finish. - counter_to_decrement_when_ready_.Wait(); -} - -// Ensures that the pool has at least the given count of threads. -// If any new thread has to be created, this function waits for it to -// be ready. -void ThreadPool::CreateThreads(int threads_count) { - if (threads_.size() >= threads_count) { - return; - } - counter_to_decrement_when_ready_.Reset(threads_count - threads_.size()); - while (threads_.size() < threads_count) { - threads_.push_back(new Thread(&counter_to_decrement_when_ready_)); - } - counter_to_decrement_when_ready_.Wait(); -} - -ThreadPool::~ThreadPool() { - for (auto w : threads_) { - delete w; - } -} - -} // end namespace ruy diff --git a/tensorflow/lite/experimental/ruy/ruy/thread_pool.h b/tensorflow/lite/experimental/ruy/ruy/thread_pool.h deleted file mode 100644 index 5504bd80614..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/thread_pool.h +++ /dev/null @@ -1,102 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// This file is a fork of gemmlowp's multi_thread_gemm.h, under Apache 2.0 -// license. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_THREAD_POOL_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_THREAD_POOL_H_ - -#include - -#include "tensorflow/lite/experimental/ruy/ruy/blocking_counter.h" - -namespace ruy { - -// A workload for a thread. -struct Task { - virtual ~Task() {} - virtual void Run() = 0; -}; - -class Thread; - -// A simple pool of threads, that only allows the very -// specific parallelization pattern that we use here: -// One thread, which we call the 'main thread', calls Execute, distributing -// a Task each to N threads, being N-1 'worker threads' and the main thread -// itself. After the main thread has completed its own Task, it waits for -// the worker threads to have all completed. That is the only synchronization -// performed by this ThreadPool. -// -// In particular, there is a naive 1:1 mapping of Tasks to threads. -// This ThreadPool considers it outside of its own scope to try to work -// with fewer threads than there are Tasks. The idea is that such N:M mappings -// of tasks to threads can be implemented as a higher-level feature on top of -// the present low-level 1:1 threadpool. For example, a user might have a -// Task subclass referencing a shared atomic counter indexing into a vector of -// finer-granularity subtasks. 
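// (For instance -- a hypothetical sketch, not part of this header, showing how
// such an N:M layer could look on top of the 1:1 pool:
//
//   struct SubdividedTask final : Task {
//     std::atomic<int>* next_subtask;                 // shared by all N tasks
//     std::vector<std::function<void()>>* subtasks;   // M finer-grained units
//     void Run() override {
//       int i;
//       while ((i = next_subtask->fetch_add(1)) <
//              static_cast<int>(subtasks->size())) {
//         (*subtasks)[i]();
//       }
//     }
//   };
//
// where M = subtasks->size() units of work get shared among N tasks/threads.)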
Different threads would then concurrently -// increment this atomic counter, getting each their own subtasks to work on. -// That approach is the one used in ruy's multi-thread matrix multiplication -// implementation --- see ruy's TrMulTask. -class ThreadPool { - public: - ThreadPool() {} - - ~ThreadPool(); - - // Executes task_count tasks on task_count threads. - // Grows the threadpool as needed to have at least (task_count-1) threads. - // The 0-th task is run on the thread on which Execute is called: that - // is by definition what we call the "main thread". Synchronization of all - // threads is performed before this function returns. - // - // As explained in the class comment, there is a 1:1 mapping of tasks to - // threads. If you need something smarter than that, for instance if you - // want to run an unbounded number of tasks on a bounded number of threads, - // then you need something higher-level than this ThreadPool, that can - // be layered on top of it by appropriately subclassing Tasks. - // - // TaskType must be a subclass of ruy::Task. That is implicitly guarded by - // the static_cast in this inline implementation. - template - void Execute(int task_count, TaskType* tasks) { - ExecuteImpl(task_count, sizeof(TaskType), static_cast(tasks)); - } - - private: - // Ensures that the pool has at least the given count of threads. - // If any new thread has to be created, this function waits for it to - // be ready. - void CreateThreads(int threads_count); - - // Non-templatized implementation of the public Execute method. - // See the inline implementation of Execute for how this is used. - void ExecuteImpl(int task_count, int stride, Task* tasks); - - // copy construction disallowed - ThreadPool(const ThreadPool&) = delete; - - // The threads in this pool. They are owned by the pool: - // the pool creates threads and destroys them in its destructor. - std::vector threads_; - - // The BlockingCounter used to wait for the threads. - BlockingCounter counter_to_decrement_when_ready_; -}; - -} // namespace ruy - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_THREAD_POOL_H_ diff --git a/tensorflow/lite/experimental/ruy/ruy/time.h b/tensorflow/lite/experimental/ruy/ruy/time.h deleted file mode 100644 index 9dba75eb4c5..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/time.h +++ /dev/null @@ -1,81 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_TIME_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_TIME_H_ - -#include // NOLINT(build/c++11) -#include // IWYU pragma: keep -#include // NOLINT(build/c++11) - -#ifdef __linux__ -#include -// IWYU pragma: no_include - -#include -#endif - -namespace ruy { - -using InternalDefaultClock = std::chrono::steady_clock; - -using TimePoint = InternalDefaultClock::time_point; -using Duration = InternalDefaultClock::duration; - -template -Duration DurationFromSeconds(RepresentationType representation) { - return std::chrono::duration_cast( - std::chrono::duration(representation)); -} - -template -Duration DurationFromMilliseconds(RepresentationType representation) { - return std::chrono::duration_cast( - std::chrono::duration(representation)); -} - -template -Duration DurationFromNanoseconds(RepresentationType representation) { - return std::chrono::duration_cast( - std::chrono::duration(representation)); -} - -inline float ToFloatSeconds(const Duration& duration) { - return std::chrono::duration_cast>(duration) - .count(); -} - -inline std::int64_t ToInt64Nanoseconds(const Duration& duration) { - return std::chrono::duration_cast< - std::chrono::duration>(duration) - .count(); -} - -inline TimePoint Now() { return InternalDefaultClock::now(); } - -inline TimePoint CoarseNow() { -#ifdef __linux__ - timespec t; - clock_gettime(CLOCK_MONOTONIC_COARSE, &t); - return TimePoint( - DurationFromNanoseconds(1000000000LL * t.tv_sec + t.tv_nsec)); -#else - return Now(); -#endif -} - -} // namespace ruy - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_TIME_H_ diff --git a/tensorflow/lite/experimental/ruy/ruy/trace.cc b/tensorflow/lite/experimental/ruy/ruy/trace.cc deleted file mode 100644 index 806f6ec2cf2..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/trace.cc +++ /dev/null @@ -1,325 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/lite/experimental/ruy/ruy/trace.h" - -#include -#include // IWYU pragma: keep -#include -#include -#include -#include - -#include "tensorflow/lite/experimental/ruy/ruy/check_macros.h" -#include "tensorflow/lite/experimental/ruy/ruy/side_pair.h" -#include "tensorflow/lite/experimental/ruy/ruy/time.h" - -namespace ruy { - -#ifdef RUY_TRACE - -enum class TraceEvent : std::uint8_t { - kNone, - kThreadStart, - kThreadLoopStart, - kThreadEnd, - kBlockReserved, - kBlockPackedLhs, - kBlockPackedRhs, - kBlockFinished -}; - -struct TraceEntry { - TimePoint time_point; - TraceEvent event; - // ruy-internal thread id i.e. contiguous index into array of threads, - // with 0 designating the main thread. - std::uint16_t thread_id = 0; - // Additional parameters whose meaning depends on the 'event' type. 
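  // For reference, matching how Dump() below reads these entries: for
  // kBlockReserved and kBlockFinished, params[0] holds the flat block id;
  // for kBlockPackedLhs / kBlockPackedRhs it holds the per-side block index;
  // the thread start/loop/end events leave params unused.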
- std::uint32_t params[1]; -}; - -struct Trace { - BlockMap block_map; - // During recording, to avoid having to use locks or atomics, we let - // each thread append to its own specific vector. - std::vector> thread_specific_entries; - // Global vector of entries into which we coalesce thread_specific_entries - // after recording is finished, when dumping a trace. See - // AggregateThreadSpecificEntries. - std::vector entries; - TimePoint time_start; - TimePoint time_execute; - TimePoint time_end; -}; - -namespace { - -// Coalesce Trace::thread_specific_entries into Trace::entries. -void AggregateThreadSpecificEntries(Trace* trace) { - RUY_CHECK(trace->entries.empty()); - for (auto& thread_specific_entries_vector : trace->thread_specific_entries) { - for (const TraceEntry& entry : thread_specific_entries_vector) { - trace->entries.push_back(entry); - } - thread_specific_entries_vector.clear(); - } -} - -// Sort Trace::entries by ascending time. In case of equal timepoints, -// sort by some semi-arbitrary ordering of event types. -void Sort(Trace* trace) { - std::sort(std::begin(trace->entries), std::end(trace->entries), - [](const TraceEntry& a, const TraceEntry& b) -> bool { - return a.time_point < b.time_point || - (a.time_point == b.time_point && - static_cast(a.event) < static_cast(b.event)); - }); -} - -// Dump a trace. Assumes that AggregateThreadSpecificEntries and Sort have -// already been called on it. -// -// On some architectures long long ints are not same as std::int64_t, and -// time is printed as %lld, so static_casts are necessary. -void Dump(const Trace& trace) { - const char* trace_filename = getenv("RUY_TRACE_FILE"); - FILE* trace_file = trace_filename ? fopen(trace_filename, "w") : stderr; - if (!trace_file) { - fprintf(stderr, "Failed to open %s for write, errno=%d\n", trace_filename, - errno); - RUY_CHECK(false); - } - fprintf(trace_file, "thread_count:%d\n", trace.block_map.thread_count); - fprintf(trace_file, "rows:%d\n", trace.block_map.dims[Side::kLhs]); - fprintf(trace_file, "cols:%d\n", trace.block_map.dims[Side::kRhs]); - fprintf(trace_file, "Execute: %lld\n", - static_cast( - ToInt64Nanoseconds(trace.time_execute - trace.time_start))); - for (const TraceEntry& entry : trace.entries) { - long long int time = static_cast( - ToInt64Nanoseconds(entry.time_point - trace.time_start)); - switch (entry.event) { - case TraceEvent::kThreadStart: - fprintf(trace_file, "ThreadStart: %lld, %d\n", time, entry.thread_id); - break; - case TraceEvent::kThreadLoopStart: - fprintf(trace_file, "ThreadLoopStart: %lld, %d\n", time, - entry.thread_id); - break; - case TraceEvent::kThreadEnd: - fprintf(trace_file, "ThreadEnd: %lld, %d\n", time, entry.thread_id); - break; - case TraceEvent::kBlockReserved: { - std::uint32_t block_id = entry.params[0]; - SidePair block; - GetBlockByIndex(trace.block_map, block_id, &block); - SidePair start, end; - GetBlockMatrixCoords(trace.block_map, block, &start, &end); - fprintf(trace_file, - "BlockReserved: %lld, %d, %d, %d, %d, %d, %d, %d, %d\n", time, - entry.thread_id, block_id, block[Side::kLhs], block[Side::kRhs], - start[Side::kLhs], start[Side::kRhs], end[Side::kLhs], - end[Side::kRhs]); - break; - } - case TraceEvent::kBlockPackedLhs: { - std::uint32_t block = entry.params[0]; - int start, end; - GetBlockMatrixCoords(Side::kLhs, trace.block_map, block, &start, &end); - fprintf(trace_file, "BlockPackedLhs: %lld, %d, %d, %d, %d\n", time, - entry.thread_id, block, start, end); - break; - } - case TraceEvent::kBlockPackedRhs: { - 
std::uint32_t block = entry.params[0]; - int start, end; - GetBlockMatrixCoords(Side::kRhs, trace.block_map, block, &start, &end); - fprintf(trace_file, "BlockPackedRhs: %lld, %d, %d, %d, %d\n", time, - entry.thread_id, block, start, end); - break; - } - case TraceEvent::kBlockFinished: { - std::uint32_t block_id = entry.params[0]; - SidePair block; - GetBlockByIndex(trace.block_map, block_id, &block); - fprintf(trace_file, "BlockFinished: %lld, %d, %d, %d, %d\n", time, - entry.thread_id, block_id, block[Side::kLhs], - block[Side::kRhs]); - break; - } - default: - RUY_CHECK(false); - } - } - fprintf(trace_file, "End: %lld\n", - static_cast( - ToInt64Nanoseconds(trace.time_end - trace.time_start))); - if (trace_filename) { - fclose(trace_file); - } -} - -} // anonymous namespace - -// Get a Trace object to record to, or null of tracing is not enabled. -Trace* NewTraceOrNull(TracingContext* tracing, int rows, int depth, int cols) { - if (!tracing->initialized) { - tracing->initialized = true; - tracing->enabled = getenv("RUY_TRACE"); - if (!tracing->enabled) { - return nullptr; - } - if (getenv("RUY_TRACE_FILTER_ROWS")) { - tracing->filter_shape_rows = std::stoi(getenv("RUY_TRACE_FILTER_ROWS")); - } - if (getenv("RUY_TRACE_FILTER_DEPTH")) { - tracing->filter_shape_depth = std::stoi(getenv("RUY_TRACE_FILTER_DEPTH")); - } - if (getenv("RUY_TRACE_FILTER_COLS")) { - tracing->filter_shape_cols = std::stoi(getenv("RUY_TRACE_FILTER_COLS")); - } - } - if (!tracing->enabled) { - return nullptr; - } - if (tracing->filter_shape_rows && rows != tracing->filter_shape_rows) { - return nullptr; - } - if (tracing->filter_shape_depth && depth != tracing->filter_shape_depth) { - return nullptr; - } - if (tracing->filter_shape_cols && cols != tracing->filter_shape_cols) { - return nullptr; - } - // Delete any existing trace. - delete tracing->trace; - // Create a new one. - tracing->trace = new Trace; - return tracing->trace; -} - -// The trace recorded on a context is finalized and dumped by -// this TracingContext destructor. -// -// The idea of dumping on context destructor is that typically one wants to -// run many matrix multiplications, e.g. to hit a steady state in terms of -// performance characteristics, but only trace the last repetition of the -// workload, when that steady state was attained. -TracingContext::~TracingContext() { - if (trace) { - AggregateThreadSpecificEntries(trace); - Sort(trace); - Dump(*trace); - } - delete trace; -} - -void TraceRecordStart(Trace* trace) { - if (trace) { - trace->time_start = Now(); - } -} - -void TraceRecordExecute(const BlockMap& block_map, Trace* trace) { - if (trace) { - trace->time_execute = Now(); - trace->block_map = block_map; - trace->thread_specific_entries.resize(block_map.thread_count); - for (int thread = 0; thread < block_map.thread_count; thread++) { - trace->thread_specific_entries[thread].clear(); - // Reserve some large size to avoid frequent heap allocations - // affecting the recorded timings. 
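      // (Assuming a TraceEntry of a couple of dozen bytes, 16384 entries is on
      // the order of a few hundred kilobytes per thread -- cheap compared to
      // the timing noise that mid-recording reallocations would introduce.)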
- trace->thread_specific_entries[thread].reserve(16384); - } - } -} - -void TraceRecordEnd(Trace* trace) { - if (trace) { - trace->time_end = Now(); - } -} - -void TraceRecordThreadStart(std::uint32_t thread_id, Trace* trace) { - if (trace) { - TraceEntry entry; - entry.event = TraceEvent::kThreadStart; - entry.time_point = Now(); - entry.thread_id = thread_id; - trace->thread_specific_entries[thread_id].push_back(entry); - } -} - -void TraceRecordThreadLoopStart(std::uint32_t thread_id, Trace* trace) { - if (trace) { - TraceEntry entry; - entry.event = TraceEvent::kThreadLoopStart; - entry.time_point = Now(); - entry.thread_id = thread_id; - trace->thread_specific_entries[thread_id].push_back(entry); - } -} - -void TraceRecordBlockReserved(std::uint32_t thread_id, std::uint32_t block_id, - Trace* trace) { - if (trace) { - TraceEntry entry; - entry.event = TraceEvent::kBlockReserved; - entry.time_point = Now(); - entry.thread_id = thread_id; - entry.params[0] = block_id; - trace->thread_specific_entries[thread_id].push_back(entry); - } -} - -void TraceRecordBlockPacked(std::uint32_t thread_id, Side side, int block, - Trace* trace) { - if (trace) { - TraceEntry entry; - entry.event = side == Side::kLhs ? TraceEvent::kBlockPackedLhs - : TraceEvent::kBlockPackedRhs; - entry.time_point = Now(); - entry.thread_id = thread_id; - entry.params[0] = block; - trace->thread_specific_entries[thread_id].push_back(entry); - } -} - -void TraceRecordBlockFinished(std::uint32_t thread_id, std::uint32_t block_id, - Trace* trace) { - if (trace) { - TraceEntry entry; - entry.event = TraceEvent::kBlockFinished; - entry.time_point = Now(); - entry.thread_id = thread_id; - entry.params[0] = block_id; - trace->thread_specific_entries[thread_id].push_back(entry); - } -} - -void TraceRecordThreadEnd(std::uint32_t thread_id, Trace* trace) { - if (trace) { - TraceEntry entry; - entry.event = TraceEvent::kThreadEnd; - entry.time_point = Now(); - entry.thread_id = thread_id; - trace->thread_specific_entries[thread_id].push_back(entry); - } -} - -#endif - -} // namespace ruy diff --git a/tensorflow/lite/experimental/ruy/ruy/trace.h b/tensorflow/lite/experimental/ruy/ruy/trace.h deleted file mode 100644 index 6680438c124..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/trace.h +++ /dev/null @@ -1,73 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_TRACE_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_TRACE_H_ - -#include - -#include "tensorflow/lite/experimental/ruy/ruy/block_map.h" -#include "tensorflow/lite/experimental/ruy/ruy/side_pair.h" - -namespace ruy { - -struct Trace; - -#ifdef RUY_TRACE - -struct TracingContext { - bool initialized = false; - bool enabled = false; - int filter_shape_rows = 0; - int filter_shape_cols = 0; - int filter_shape_depth = 0; - Trace* trace = nullptr; - ~TracingContext(); -}; - -Trace* NewTraceOrNull(TracingContext* context, int rows, int depth, int cols); -void TraceRecordThreadStart(std::uint32_t thread_id, Trace* trace); -void TraceRecordThreadLoopStart(std::uint32_t thread_id, Trace* trace); -void TraceRecordBlockReserved(std::uint32_t thread_id, std::uint32_t block_id, - Trace* trace); -void TraceRecordBlockPacked(std::uint32_t thread_id, Side side, int block, - Trace* trace); -void TraceRecordBlockFinished(std::uint32_t thread_id, std::uint32_t block_id, - Trace* trace); -void TraceRecordThreadEnd(std::uint32_t thread_id, Trace* trace); -void TraceRecordStart(Trace* trace); -void TraceRecordExecute(const BlockMap& block_map, Trace* trace); -void TraceRecordEnd(Trace* trace); - -#else - -struct TracingContext {}; - -inline Trace* NewTraceOrNull(TracingContext*, int, int, int) { return nullptr; } -inline void TraceRecordThreadStart(std::uint32_t, Trace*) {} -inline void TraceRecordThreadLoopStart(std::uint32_t, Trace*) {} -inline void TraceRecordBlockReserved(std::uint32_t, std::uint32_t, Trace*) {} -inline void TraceRecordBlockPacked(std::uint32_t, Side, int, Trace*) {} -inline void TraceRecordBlockFinished(std::uint32_t, std::uint32_t, Trace*) {} -inline void TraceRecordThreadEnd(std::uint32_t, Trace*) {} -inline void TraceRecordStart(Trace*) {} -inline void TraceRecordExecute(const BlockMap&, Trace*) {} -inline void TraceRecordEnd(Trace*) {} - -#endif - -} // namespace ruy - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_TRACE_H_ diff --git a/tensorflow/lite/experimental/ruy/ruy/trmul.cc b/tensorflow/lite/experimental/ruy/ruy/trmul.cc deleted file mode 100644 index c3e15a9d628..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/trmul.cc +++ /dev/null @@ -1,401 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/lite/experimental/ruy/ruy/trmul.h" - -#include -#include -#include -#include -#include - -#include "tensorflow/lite/experimental/ruy/ruy/allocator.h" -#include "tensorflow/lite/experimental/ruy/ruy/block_map.h" -#include "tensorflow/lite/experimental/ruy/ruy/check_macros.h" -#include "tensorflow/lite/experimental/ruy/ruy/common.h" -#include "tensorflow/lite/experimental/ruy/ruy/internal_matrix.h" -#include "tensorflow/lite/experimental/ruy/ruy/matrix.h" -#include "tensorflow/lite/experimental/ruy/ruy/opt_set.h" -#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" -#include "tensorflow/lite/experimental/ruy/ruy/side_pair.h" -#include "tensorflow/lite/experimental/ruy/ruy/size_util.h" -#include "tensorflow/lite/experimental/ruy/ruy/spec.h" -#include "tensorflow/lite/experimental/ruy/ruy/thread_pool.h" -#include "tensorflow/lite/experimental/ruy/ruy/trace.h" -#include "tensorflow/lite/experimental/ruy/ruy/tune.h" - -namespace ruy { - -namespace { - -enum class PackingStatus : std::uint8_t { kNotStarted, kInProgress, kFinished }; - -struct TrMulTask final : Task { - TrMulTask(TrMulParams* params_, const BlockMap& block_map_, - std::atomic* atomic_block_id_, int thread_id_, - bool need_atomics_, - SidePair*> packing_status_, - TuningResolver* tuning_resolver_, Allocator* local_allocator_, - Trace* trace_) - : params(params_), - block_map(block_map_), - atomic_block_id(atomic_block_id_), - thread_id(thread_id_), - need_atomics(need_atomics_), - packing_status(packing_status_), - tuning_resolver(tuning_resolver_), - local_allocator(local_allocator_), - trace(trace_), - local_packed{nullptr, nullptr} {} - - void Run() override { - TraceRecordThreadStart(thread_id, trace); - - for (Side side : {Side::kLhs, Side::kRhs}) { - if (!params->is_prepacked[side]) { - const int size = NumBlocksPerSide(side, block_map); - local_allocator->Allocate(size, &local_packed[side]); - memset(local_packed[side], 0, size * sizeof(bool)); - } - } - - const int num_blocks = NumBlocks(block_map); - - const Tuning tuning = tuning_resolver->Resolve(); - - TraceRecordThreadLoopStart(thread_id, trace); - - SidePair block; - SidePair start; - SidePair end; - - // Each thread starts by initially reserving the block whose id - // is the thread id. - int block_id = thread_id; - TraceRecordBlockReserved(thread_id, block_id, trace); - - while (block_id < num_blocks) { - // Reserve the next block to handle. In order to hide the latency - // (typically comparable to an access to the level of data cache that - // is shared among CPU cores, e.g. 60 cycles on an ARM CPU as of 2019) - // of this atomic operation, we structure this code so as to avoid - // immediately depending on the `next_n` result. - const int next_block_id = - atomic_block_id->fetch_add(1, std::memory_order_relaxed); - TraceRecordBlockReserved(thread_id, next_block_id, trace); - // Get coordinates of the current block to handle, in "block space". - GetBlockByIndex(block_map, block_id, &block); - // Get coordinates of the current block to handle, in matrix space. - GetBlockMatrixCoords(block_map, block, &start, &end); - // Maybe pack the current LHS/RHS block, if not already packed. 
- EnsurePacked(block, start, end, tuning); - // Actually do matrix multiplication work - params->RunKernel(tuning, start, end); - TraceRecordBlockFinished(thread_id, block_id, trace); - // Move on to the next block as obtained by the atomic increment - // at the start of this while loop iteration. - block_id = next_block_id; - } - - local_allocator->FreeAll(); - - TraceRecordThreadEnd(thread_id, trace); - } - - private: - // Tries to pack a block, without blocking. - // If the block was already packed, returns true. - // If the block was not started packing, packs it and returns true. - // If the block was being packed by another thread, returns false. - bool TryPack(Side side, int block, int start, int end, Tuning tuning) { - if (params->is_prepacked[side]) { - return true; - } - if (!local_packed[side][block]) { - if (need_atomics) { - // Explanation of this compare_exchange_strong operation: - // This atomically performs all of the following: - // 1. Read `status` with "acquire" memory order. - // * That this read uses "acquire" is because both memory orders - // specified have "acquire" as their read-component. - // 2. Compare (bitwise) with `exchanged_status`. - // 3. If equal, stores the value kInProgress to `status` with "release" - // memory order, and returns true, so we take this 'if' branch. - // * That this store uses "release" is because of the _rel part in - // memory_order_acq_rel passed as the first memory order argument. - // 4. If not equal, stores the loaded value of `status` to - // `exchanged_status` with "relaxed" semantics, and returns false, - // so we take the 'else' branch. - // * That this store uses "relaxed" is because the second memory - // order argument, memory_order_acquire, implies no particular - // store semantics. "relaxed" is acceptable here because this - // stores to a local stack variable. - // - // Rationale for compare_exchange_strong as opposed to - // compare_exchange_weak: - // The spurious-failure case with compare_exchange_weak will actually - // happen a lot here, because the atomic 'status' bytes are stored - // contiguously in arrays and neighboring values will be accessed - // by multiple threads concurrently. On a typical ARM CPU, an exclusives - // reservation granule is 64 bytes, so a lot of false-sharing may - // happen. Using compare_exchange_weak would thus result in often having - // TryPack return 'false' when it could instead have done the packing - // work and returned 'true'. Heuristically, that is not a good thing. - // Moreover, this changes the TryPack contract, loosening it and making - // it harder for the caller to reason about. Finally, the overhead of - // atomic operations is mitigated by the enclosing check on - // local_packed, so maybe the overhead of compare_exchange_strong isn't - // such a problem. But we don't really know for sure, that would be - // interesting to experiment more with. - PackingStatus exchanged_status = PackingStatus::kNotStarted; - std::atomic& status = packing_status[side][block]; - if (status.compare_exchange_strong( - exchanged_status, PackingStatus::kInProgress, - std::memory_order_acq_rel, std::memory_order_acquire)) { - // In this branch, the status was kNotStarted and we just atomically - // changed it to kInProgress as we are about to handle the packing - // ourselves. 
- params->RunPack(side, tuning, start, end); - TraceRecordBlockPacked(thread_id, side, block, trace); - status.store(PackingStatus::kFinished, std::memory_order_release); - } else if (exchanged_status == PackingStatus::kInProgress) { - // Another thread is currently packing this block. - return false; - } - RUY_DCHECK(status.load(std::memory_order_acquire) == - PackingStatus::kFinished); - } else { - // Single-threaded case: no need for expensive atomics, local_packed - // is the truth already. - params->RunPack(side, tuning, start, end); - TraceRecordBlockPacked(thread_id, side, block, trace); - } - local_packed[side][block] = true; - } - return true; - } - - // Ensures that both the LHS and RHS blocks required by the specified block - // are packed. In the event that they are already being packed on another - // threads, this function may perform the packing of some other block while - // waiting for that other thread to finish packing the requested block. - void EnsurePacked(const SidePair& block, const SidePair& start, - const SidePair& end, Tuning tuning) { -#if RUY_OPT_ENABLED(RUY_OPT_PACK_AHEAD) - SidePair next_runahead_block{block[Side::kLhs] + 1, - block[Side::kRhs] + 1}; - Side next_runahead_side = Side::kLhs; -#endif - while (true) { - bool both_sides_packed = true; - for (Side side : {Side::kLhs, Side::kRhs}) { - both_sides_packed &= - TryPack(side, block[side], start[side], end[side], tuning); - } - if (both_sides_packed) { - break; - } -#if RUY_OPT_ENABLED(RUY_OPT_PACK_AHEAD) - const Side runahead_side = next_runahead_side; - const int runahead_block = next_runahead_block[runahead_side]; - next_runahead_side = - next_runahead_side == Side::kLhs ? Side::kRhs : Side::kLhs; - if (runahead_block >= NumBlocksPerSide(runahead_side, block_map)) { - continue; - } - int runahead_block_start, runahead_block_end; - GetBlockMatrixCoords(runahead_side, block_map, runahead_block, - &runahead_block_start, &runahead_block_end); - TryPack(runahead_side, runahead_block, runahead_block_start, - runahead_block_end, tuning); - next_runahead_block[runahead_side] = runahead_block + 1; -#endif - } - } - - TrMulParams* params; - const BlockMap& block_map; - std::atomic* atomic_block_id; - int thread_id; - bool need_atomics; - SidePair*> packing_status; - TuningResolver* tuning_resolver; - Allocator* local_allocator; - Trace* trace; - - // Local indicators of packedness to avoid the overhead of atomic ops. - SidePair local_packed; -}; - -void AllocatePMatrix(Allocator* allocator, PMatrix* packed) { - packed->data = allocator->AllocateBytes(DataSize(*packed)); - packed->sums = allocator->AllocateBytes(SumsSize(*packed)); -} - -int GetThreadCount(Context* context, int rows, int cols, int depth) { -#if RUY_PLATFORM(EMSCRIPTEN) - // b/139927184, std::thread constructor raises exception - return 1; -#endif - // Empirically determined rule for reasonable number of - // threads to use. This is proportional to the number of arithmetic ops - // in this Mul (product of the 3 sizes). 
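  // Worked example of the rule below: with rows = cols = depth = 512, each
  // ceil_log2 term is 9, so guess_log2 = max(0, 27 - 15) = 12, i.e. a guess of
  // 2^12 = 4096 threads, which std::min then clamps to max_num_threads.
  // Shapes whose ceil_log2 values sum to at most kDivisorLog2 (e.g. 32x32x32,
  // 5 + 5 + 5 = 15) yield guess_log2 = 0, i.e. single-threaded execution.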
- static constexpr int kDivisorLog2 = 15; - const int guess_log2 = std::max( - 0, ceil_log2(rows) + ceil_log2(cols) + ceil_log2(depth) - kDivisorLog2); - return std::min(1 << guess_log2, context->max_num_threads); -} - -LoopStructure GetLoopStructure(int tentative_thread_count, int rows, int cols, - int depth, int lhs_scalar_size, - int rhs_scalar_size, int local_data_cache_size, - int shared_data_cache_size) { - if (tentative_thread_count == 1) { - const BlockMapTraversalOrder traversal_order = - GetTraversalOrder(rows, cols, depth, lhs_scalar_size, rhs_scalar_size, - local_data_cache_size, shared_data_cache_size); - // If we are in the GEMV case or the block_map would be using linear - // traversal anyway, use the simple loop. - if ((cols == 1) || traversal_order == BlockMapTraversalOrder::kLinear) { - return LoopStructure::kSimple; - } - } - return LoopStructure::kGeneral; -} - -} // namespace - -void TrMul(TrMulParams* params, Context* context) { - profiler::ScopeLabel label( - "TrMul (Path=0x%x, max_num_threads=%d, is_prepacked=(%d,%d))", - static_cast(params->path), context->max_num_threads, - params->is_prepacked[Side::kLhs], params->is_prepacked[Side::kRhs]); - - PMatrix& packed_lhs = params->packed[Side::kLhs]; - PMatrix& packed_rhs = params->packed[Side::kRhs]; - DMatrix& lhs = params->src[Side::kLhs]; - DMatrix& rhs = params->src[Side::kRhs]; - - const int rows = lhs.layout.cols; - const int cols = rhs.layout.cols; - const int depth = lhs.layout.rows; - - const int tentative_thread_count = GetThreadCount(context, rows, cols, depth); - const auto loop_structure = GetLoopStructure( - tentative_thread_count, rows, cols, depth, lhs.data_type.size, - rhs.data_type.size, params->local_data_cache_size, - params->shared_data_cache_size); - Allocator* allocator = context->GetMainAllocator(); - - // Allocate packed matrices - for (Side side : {Side::kLhs, Side::kRhs}) { - if (!params->is_prepacked[side]) { - AllocatePMatrix(allocator, ¶ms->packed[side]); - } - } - - // Case of running this TrMul as a simple loop. - // This is a good place to start reading this function: all the rest - // of this function is just an optimized, but functionally equivalent, - // version of that. - if (loop_structure == LoopStructure::kSimple) { - profiler::ScopeLabel label_simple("TrMulImpl, simple loop"); - Tuning tuning = context->GetMainThreadTuning(); - - const SidePair origin{0, 0}; - const SidePair rounded_dims{packed_lhs.layout.cols, - packed_rhs.layout.cols}; - for (Side side : {Side::kLhs, Side::kRhs}) { - if (!params->is_prepacked[side]) { - params->RunPack(side, tuning, origin[side], rounded_dims[side]); - } - } - params->RunKernel(tuning, origin, rounded_dims); - - allocator->FreeAll(); - return; - } - - profiler::ScopeLabel label_general("TrMulImpl, general case"); - - auto* trace = NewTraceOrNull(&context->tracing, rows, depth, cols); - TraceRecordStart(trace); - - // Initialize block map. - BlockMap block_map; - MakeBlockMap(packed_lhs.layout.cols, packed_rhs.layout.cols, depth, - packed_lhs.layout.kernel.cols, packed_rhs.layout.kernel.cols, - packed_lhs.data_type.size, packed_rhs.data_type.size, - tentative_thread_count, params->path, - params->local_data_cache_size, params->shared_data_cache_size, - &block_map); - - // Initialize per-thread state. 
- const int thread_count = block_map.thread_count; - const bool need_atomics = thread_count > 1; - context->EnsureNPerThreadStates(thread_count); - for (auto& per_thread_state : context->per_thread_states) { - per_thread_state->tuning_resolver.SetTuning(context->explicit_tuning); - } - - // In the need_atomics case, allocate and initialize atomic values tracking - // the packing status of blocks. - SidePair*> packing_status{nullptr, nullptr}; - if (need_atomics) { - for (Side side : {Side::kLhs, Side::kRhs}) { - if (!params->is_prepacked[side]) { - const int size = NumBlocksPerSide(side, block_map); - allocator->Allocate(size, &packing_status[side]); - for (int i = 0; i < size; i++) { - packing_status[side][i].store(PackingStatus::kNotStarted, - std::memory_order_relaxed); - } - } - } - } - - // Create the atomic block id, allocate it using Allocator so that - // we get the alignment ensuring that it sits alone in its exclusives - // reservation granule. - std::atomic* atomic_block_id; - allocator->Allocate(1, &atomic_block_id); - - // Create task objects. - TrMulTask* tasks; - allocator->Allocate(thread_count, &tasks); - - atomic_block_id->store(thread_count); - - for (int i = 0; i < thread_count; i++) { - new (tasks + i) TrMulTask(params, block_map, atomic_block_id, i, - need_atomics, packing_status, - &context->per_thread_states[i]->tuning_resolver, - &context->per_thread_states[i]->allocator, trace); - } - - // Do the computation. - TraceRecordExecute(block_map, trace); - context->workers_pool.Execute(thread_count, tasks); - - // Finish up. - for (int i = 0; i < thread_count; i++) { - tasks[i].~TrMulTask(); - } - - allocator->FreeAll(); - TraceRecordEnd(trace); -} - -} // namespace ruy diff --git a/tensorflow/lite/experimental/ruy/ruy/trmul.h b/tensorflow/lite/experimental/ruy/ruy/trmul.h deleted file mode 100644 index 9786b7f6180..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/trmul.h +++ /dev/null @@ -1,38 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// As a matrix multiplication library, Ruy offers a Mul entry point, performing -// matrix multiplication. For implementation purposes, it is much nicer to -// be dealing with the transpose-and-multiply operation, doing -// Destination = Transpose(LHS) * RHS -// Indeed, the latter is performing dot-products between the *columns* of LHS -// and the columns of RHS, whereas a plain matrix multiplication is performing -// dot-products between the *rows* of LHS and the columns of RHS. -// That is why TrMul is nicer to implement, allowing for a more symmetric -// treatment of LHS and RHS. 
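// As an illustration (reference semantics only, ignoring strides, quantization
// and accumulation details), TrMul computes
//
//   for (int i = 0; i < lhs.cols; ++i)
//     for (int j = 0; j < rhs.cols; ++j)
//       for (int k = 0; k < lhs.rows; ++k)
//         dst(i, j) += lhs(k, i) * rhs(k, j);
//
// i.e. dot-products between a *column* of LHS and a *column* of RHS, so both
// operands are traversed the same way. A plain Mul(LHS, RHS) is then obtained
// as TrMul(Transpose(LHS), RHS).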
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_TRMUL_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_TRMUL_H_ - -#include "tensorflow/lite/experimental/ruy/ruy/context.h" -#include "tensorflow/lite/experimental/ruy/ruy/trmul_params.h" - -namespace ruy { - -void TrMul(TrMulParams* params, Context* context); - -} // namespace ruy - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_TRMUL_H_ diff --git a/tensorflow/lite/experimental/ruy/ruy/trmul_params.h b/tensorflow/lite/experimental/ruy/ruy/trmul_params.h deleted file mode 100644 index c694f16b938..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/trmul_params.h +++ /dev/null @@ -1,67 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_TRMUL_PARAMS_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_TRMUL_PARAMS_H_ - -#include "tensorflow/lite/experimental/ruy/ruy/internal_matrix.h" -#include "tensorflow/lite/experimental/ruy/ruy/side_pair.h" -#include "tensorflow/lite/experimental/ruy/ruy/tune.h" - -namespace ruy { - -using RunKernelFn = void(Tuning, const SidePair&, void*, - const SidePair&, const SidePair&, DMatrix*); - -using RunPackFn = void(Tuning, const DMatrix&, PMatrix*, int, int); - -// Type-erased data needed for implementing TrMul. -struct TrMulParams { - TrMulParams() : run_pack{nullptr, nullptr}, is_prepacked{false, false} {} - // Helper functions for invoking the function pointers. - void RunPack(Side side, Tuning tuning, int start, int end) { - run_pack[side](tuning, src[side], &packed[side], start, end); - } - void RunKernel(Tuning tuning, const SidePair& start, - const SidePair& end) { - run_kernel(tuning, packed, spec, start, end, &dst); - } - - // path id, can be useful info for some fine-tuning, e.g. to guess reasonable - // cache sizes when not runtime-detectable. - Path path; - - // See Spec::local_data_cache_size(). - int local_data_cache_size = 0; - // See Spec::shared_data_cache_size(). - int shared_data_cache_size = 0; - - // Function pointers to type-erased entry points for kernels and packers. - SidePair run_pack; - RunKernelFn* run_kernel = nullptr; - - // Matrices and packed matrices. - SidePair src; - DMatrix dst; - SidePair packed; - SidePair is_prepacked; - - // Type-erased Spec. - void* spec = nullptr; -}; - -} // namespace ruy - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_TRMUL_PARAMS_H_ diff --git a/tensorflow/lite/experimental/ruy/ruy/tune.cc b/tensorflow/lite/experimental/ruy/ruy/tune.cc deleted file mode 100644 index 63fa0338d6d..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/tune.cc +++ /dev/null @@ -1,161 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/lite/experimental/ruy/ruy/tune.h" - -#include -#include - -namespace ruy { - -#ifdef RUY_IMPLEMENT_TUNING - -namespace { - -void PoorlyOrderedKernel(int iters) { - asm volatile( - "mov w0, %w[iters]\n" - "1:\n" - "subs w0, w0, #1\n" - "mul v0.4s, v0.4s, v0.4s\n" - "mul v0.4s, v0.4s, v0.4s\n" - "mul v0.4s, v0.4s, v0.4s\n" - "mul v0.4s, v0.4s, v0.4s\n" - "mul v1.4s, v1.4s, v1.4s\n" - "mul v1.4s, v1.4s, v1.4s\n" - "mul v1.4s, v1.4s, v1.4s\n" - "mul v1.4s, v1.4s, v1.4s\n" - "mul v2.4s, v2.4s, v2.4s\n" - "mul v2.4s, v2.4s, v2.4s\n" - "mul v2.4s, v2.4s, v2.4s\n" - "mul v2.4s, v2.4s, v2.4s\n" - "mul v3.4s, v3.4s, v3.4s\n" - "mul v3.4s, v3.4s, v3.4s\n" - "mul v3.4s, v3.4s, v3.4s\n" - "mul v3.4s, v3.4s, v3.4s\n" - "bne 1b\n" ::[iters] "r"(iters) - : "cc", "x0", "v0", "v1", "v2", "v3"); -} - -void NicelyOrderedKernel(int iters) { - asm volatile( - "mov w0, %w[iters]\n" - "1:\n" - "subs w0, w0, #1\n" - "mul v0.4s, v0.4s, v0.4s\n" - "mul v1.4s, v1.4s, v1.4s\n" - "mul v2.4s, v2.4s, v2.4s\n" - "mul v3.4s, v3.4s, v3.4s\n" - "mul v0.4s, v0.4s, v0.4s\n" - "mul v1.4s, v1.4s, v1.4s\n" - "mul v2.4s, v2.4s, v2.4s\n" - "mul v3.4s, v3.4s, v3.4s\n" - "mul v0.4s, v0.4s, v0.4s\n" - "mul v1.4s, v1.4s, v1.4s\n" - "mul v2.4s, v2.4s, v2.4s\n" - "mul v3.4s, v3.4s, v3.4s\n" - "mul v0.4s, v0.4s, v0.4s\n" - "mul v1.4s, v1.4s, v1.4s\n" - "mul v2.4s, v2.4s, v2.4s\n" - "mul v3.4s, v3.4s, v3.4s\n" - "bne 1b\n" ::[iters] "r"(iters) - : "cc", "x0", "v0", "v1", "v2", "v3"); -} - -} // namespace - -float TuningResolver::EvalRatio() { - // With the current settings, 400 iterations and 4 repeats, this test has - // a latency of roughly 80 microseconds on a Cortex-A53 at 1.4 GHz. - static constexpr int kLoopIters = 400; - static constexpr int kRepeats = 4; - - Duration timing_poorly_ordered = Duration::max(); - Duration timing_nicely_ordered = Duration::max(); - - for (int r = 0; r < kRepeats; r++) { - TimePoint t0 = Now(); - PoorlyOrderedKernel(kLoopIters); - TimePoint t1 = Now(); - NicelyOrderedKernel(kLoopIters); - TimePoint t2 = Now(); - timing_poorly_ordered = std::min(timing_poorly_ordered, t1 - t0); - timing_nicely_ordered = std::min(timing_nicely_ordered, t2 - t1); - } - - return ToFloatSeconds(timing_nicely_ordered) / - ToFloatSeconds(timing_poorly_ordered); -} - -float TuningResolver::ThresholdRatio() { - // Empirically (see :tune_tool) determined threshold to distinguish in-order - // Cortex-A53/A55 cores from out-of-order Cortex-A57/A73/A75/A76 cores. 
Based - // on these experimental results, which were obtained with much lower - // (kLoopIters=1000, kRepeats=1) so as to make them resilient to noise, we - // have: - // - // CPU core type | in/out of order | observed ratio - // --------------+-----------------+----------------------------------------- - // Cortex-A53 | in-order | 0.32 -- 0.329 - // Cortex-A55 | in-order | 0.319 -- 0.325 - // Cortex-A55r1 | in-order | 0.319 -- 0.325 - // Cortex-A57 | out-of-order | 0.99 -- 1.01 - // Cortex-A73 | out-of-order | 0.922 -- 0.927 - // Cortex-A75 | out-of-order | 0.921 -- 0.93 - // Cortex-A76 | out-of-order | 1 - // Kryo (pixel1) | out-of-order | 0.73 -- 0.76 - // - // Thus the allowable range for the threshold is [0.35 .. 0.70]. - // We pick a value closer to the upper bound because really any out-of-order - // CPU should by definition produce a ratio close to 1. - return 0.65f; -} - -Tuning TuningResolver::ResolveNow() { - const bool is_probably_inorder = EvalRatio() < ThresholdRatio(); - return is_probably_inorder ? Tuning::kInOrder : Tuning::kOutOfOrder; -} - -#else // not defined RUY_IMPLEMENT_TUNING - -float TuningResolver::EvalRatio() { return 0; } -float TuningResolver::ThresholdRatio() { return 0; } - -Tuning TuningResolver::ResolveNow() { return Tuning::kOutOfOrder; } - -#endif - -TuningResolver::TuningResolver() - : expiry_duration_(DurationFromMilliseconds(250)) {} - -Tuning TuningResolver::Resolve() { -#ifdef RUY_IMPLEMENT_TUNING - if (unresolved_tuning_ != Tuning::kAuto) { - return unresolved_tuning_; - } - TimePoint new_timepoint = CoarseNow(); - if (last_resolved_tuning_ != Tuning::kAuto && - (new_timepoint - last_resolved_timepoint_) < expiry_duration_) { - return last_resolved_tuning_; - } - last_resolved_timepoint_ = new_timepoint; - last_resolved_tuning_ = ResolveNow(); - return last_resolved_tuning_; -#else - return Tuning::kOutOfOrder; -#endif -} - -} // namespace ruy diff --git a/tensorflow/lite/experimental/ruy/ruy/tune.h b/tensorflow/lite/experimental/ruy/ruy/tune.h deleted file mode 100644 index 3471604e37a..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/tune.h +++ /dev/null @@ -1,163 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Library doing minimal CPU detection to decide what to tune asm code for. -// -// # Tuning vs Path -// -// Tunings are merely local variations of optimized code paths, that are -// drop-in replacements for each other --- the input and output data layouts -// are identical. By contrast, what ruy calls a Path dictates its own -// data layouts. For example, Path::kNeonDotprod will use different -// layouts compared to Path::kNeon; but within each, different tunings -// will share that same layout. -// -// # Tuning is for now only based on 1 bit: OutOfOrder / InOrder -// -// In practice, each of our asm code paths only needs one bit information to -// decide on tuning: whether the CPU is out-of-order or in-order. 
-// That is because out-of-order CPUs are by definition relatively insensitive -// to small-scale asm details (which is what "tuning" is about); and for each -// asm code path, there tends to be one main in-order CPU architecture that -// we focus our tuning effort on. Examples: -// * For Path::kNeon, the main in-order CPU is Cortex-A53/A55 (pre-dotprod) -// * For Path::kNeonDotprod, the main in-order CPU is Cortex-A55r1 (dotprod) -// -// Because having tuned code paths is a compromise of efficiency gains -// versus implementation effort and code size, we are happy to stop at just this -// single bit of information, OutOfOrder/InOrder, at least in the current CPU -// landscape. This could change in the future. -// -// # Implementation notes and alternatives. -// -// The current implementation uses a nano-benchmark, see tune.cc. -// That is why it's quite expensive, making caching / -// statefulness necessary (see TuningResolver class comment). -// -// An interesting alternative, which was explained to us by Marat Dukhan -// (maratek@) after this was implemented, would be to use the -// getcpu(2) system call on Linux. This returns a -// numeric CPU identifier that could be mapped to a OutOfOrder/InOrder -// classification given additional information about the CPU. Such -// additional information could be obtained by the cpuinfo library, -// https://github.com/pytorch/cpuinfo -// which obtains this information mainly from parsing /proc/cpuinfo. -// Pros: -// * Would remove the need for the relatively expensive nano-benchmark -// (dozens of microseconds, which have to be reevaluated again several -// times per second). -// * Would conceivably be more reliable. -// Cons: -// * Linux-specific. -// * Modest binary size increase (Marat mentioned the cpuinfo lib is 20k). -// * Won't support exactly 100% of devices (nonstandard /proc/cpuinfo etc). -// -// We could also have both: -// * Maybe by trying getcpu first if supported, then falling back to a -// nano-benchmark. -// * Maybe using getcpu in conjunction with the nano-benchmark to cache -// per-CPU-id nano-benchmark results. -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_TUNE_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_TUNE_H_ - -#include "tensorflow/lite/experimental/ruy/ruy/opt_set.h" -#include "tensorflow/lite/experimental/ruy/ruy/platform.h" -#include "tensorflow/lite/experimental/ruy/ruy/time.h" - -// Tuning only implemented on NEON_64 at the moment (see assembly code -// in the nano-benchmark) and not on Apple (some Apple CPUs produce incorrect -// results on in-order-tuned kernels combining ARM and NEON load instructions -// and NEON `ins` instructions). -// -// When tuning is not implemented, we simply always use Tuning::kOutOfOrder. -#if RUY_OPT_ENABLED(RUY_OPT_TUNING) && RUY_PLATFORM(NEON_64) && \ - !RUY_PLATFORM(APPLE) -#define RUY_IMPLEMENT_TUNING -#endif - -namespace ruy { - -enum class Tuning { - // kAuto means please use auto-detection. It's the default in the - // user-visible parts (see Context). It's meant to be resolved to an - // actual tuning at some point by means of TuningResolver. - kAuto, - // Target an out-order CPU. Example: ARM Cortex-A75. - kOutOfOrder, - // Target an in-order CPU. Example: ARM Cortex-A55. - kInOrder -}; - -// Why a TuningResolver class? 
-// -// Ideally, this Library would offer a single function, -// Tuning GetCurrentCPUTuning(); -// -// However, determining information about the current CPU is not necessarily, -// cheap, so we currently cache that and only invalidate/reevaluate after -// a fixed amount of time. This need to store state is why this library -// has to expose a class, TuningResolver, not just a function. -class TuningResolver { - public: - TuningResolver(); - - // Allows the user to specify an explicit Tuning value, bypassing auto - // detection; or to specify Tuning::kAuto, reverting to auto detection. - void SetTuning(Tuning tuning) { unresolved_tuning_ = tuning; } - - // Get an actual tuning --- that is the function that this class wanted to be. - Tuning Resolve(); - - private: - TuningResolver(const TuningResolver&) = delete; - - // TuningTool is a demo/tool used to tweak the tuning implementation to - // specific devices. It needs to access some finer granularity information - // than just the Tuning returned by Resolve. Nothing else should need - // access to that. - friend class TuneTool; - // Actually runs a nano-benchmark, producing a real number called 'ratio' - // whose meaning is generally opaque / implementation defined. Typically, - // this would be the ratio between the latencies of two different - // pieces of asm code differing only by the ordering of instructions, - // revealing whether the CPU cares about such ordering details. - // An implementation may just return a dummy value if it is not based on - // such nanobenchmarking / ratio evaluation. - float EvalRatio(); - // Empirically determined threshold on ratio values delineating - // out-of-order (ratios closer to 1) from in-order (ratios farther from 1). - // An implementation may just return a dummy value if it is not based on - // such nanobenchmarking / ratio evaluation. - float ThresholdRatio(); - // Perform the tuning resolution now. That may typically use EvalRatio and - // ThresholdRatio, but an implementation may use a different approach instead. - Tuning ResolveNow(); - - // The tuning as specified by the user, before actual resolution happens - // i.e. before querying any specifics of the current CPU. - // The default value kAuto means try to auto-detect. Other values mean - // bypass auto-detect, use explicit value instead. See SetTuning(). - Tuning unresolved_tuning_ = Tuning::kAuto; - // Cached last resolved tuning. - Tuning last_resolved_tuning_ = Tuning::kAuto; - // Timepoint of cached last resolved tuning, for invalidation purposes. - TimePoint last_resolved_timepoint_; - // Cached last resolved tunings that are older than this age are invalid. - const Duration expiry_duration_; -}; - -} // namespace ruy - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_TUNE_H_ diff --git a/tensorflow/lite/experimental/ruy/ruy/tune_test.cc b/tensorflow/lite/experimental/ruy/ruy/tune_test.cc deleted file mode 100644 index 0b00e645195..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/tune_test.cc +++ /dev/null @@ -1,53 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/lite/experimental/ruy/ruy/tune.h" - -#include // NOLINT(build/c++11) -#include // NOLINT(build/c++11) - -#include - -namespace ruy { -namespace { - -TEST(TuneTest, TuneTest) { - TuningResolver tuning_resolver; - ASSERT_FALSE(tuning_resolver.Resolve() == Tuning::kAuto); - // 1 second is likely higher than TuningResolver's internal cache expiry, - // exercising the logic invalidating earlier tuning resolutions. - std::this_thread::sleep_for(std::chrono::seconds(1)); - ASSERT_FALSE(tuning_resolver.Resolve() == Tuning::kAuto); - - tuning_resolver.SetTuning(Tuning::kAuto); - -#ifdef RUY_IMPLEMENT_TUNING - for (auto tuning : {Tuning::kOutOfOrder, Tuning::kInOrder}) { - tuning_resolver.SetTuning(tuning); - ASSERT_TRUE(tuning_resolver.Resolve() == tuning); - // See above comment about 1 second. - std::this_thread::sleep_for(std::chrono::seconds(1)); - ASSERT_TRUE(tuning_resolver.Resolve() == tuning); - } -#endif -} - -} // namespace -} // namespace ruy - -int main(int argc, char** argv) { - ::testing::InitGoogleTest(&argc, argv); - return RUN_ALL_TESTS(); -} diff --git a/tensorflow/lite/experimental/ruy/ruy/tune_tool.cc b/tensorflow/lite/experimental/ruy/ruy/tune_tool.cc deleted file mode 100644 index 04cfa6d6b89..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/tune_tool.cc +++ /dev/null @@ -1,56 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Self-contained tool used to tune the tune code --- see the -// threshold ratios used in tune.cc. - -#include // NOLINT(build/c++11) -#include -#include // NOLINT(build/c++11) - -#include "tensorflow/lite/experimental/ruy/ruy/tune.h" - -#ifdef _WIN32 -#define getpid() 0 -#else -#include -#endif - -namespace ruy { - -class TuneTool { - public: - static void Query(float* eval, float* threshold) { - TuningResolver resolver; - *eval = resolver.EvalRatio(); - *threshold = resolver.ThresholdRatio(); - } -}; - -} // namespace ruy - -int main() { - // Infinite loop: the user can hit Ctrl-C - while (true) { - float eval; - float threshold; - ruy::TuneTool::Query(&eval, &threshold); - printf("[%d] eval=%.3f %c threshold=%.3f ==> probably %s...\n", getpid(), - eval, eval < threshold ? '<' : '>', threshold, - eval < threshold ? "in-order" : "out-of-order"); - fflush(stdout); - std::this_thread::sleep_for(std::chrono::seconds(1)); - } -} diff --git a/tensorflow/lite/experimental/ruy/ruy/wait.cc b/tensorflow/lite/experimental/ruy/ruy/wait.cc deleted file mode 100644 index 7d91b6ebce6..00000000000 --- a/tensorflow/lite/experimental/ruy/ruy/wait.cc +++ /dev/null @@ -1,69 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. 
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/lite/experimental/ruy/ruy/wait.h"
-
-#include <condition_variable>  // NOLINT(build/c++11)
-
-namespace ruy {
-
-void Wait(const std::function<bool()>& condition, const Duration& spin_duration,
-          std::condition_variable* condvar, std::mutex* mutex) {
-  // First, trivial case where the `condition` is already true;
-  if (condition()) {
-    return;
-  }
-
-  // Then try busy-waiting.
-  const TimePoint wait_start = Now();
-  while (Now() - wait_start < spin_duration) {
-    if (condition()) {
-      return;
-    }
-  }
-
-  // Finally, do real passive waiting.
-  std::unique_lock<std::mutex> lock(*mutex);
-  condvar->wait(lock, condition);
-}
-
-void Wait(const std::function<bool()>& condition,
-          std::condition_variable* condvar, std::mutex* mutex) {
-  // This value was empirically derived with some microbenchmark, we don't have
-  // high confidence in it.
-  //
-  // TODO(b/135595069): make this value configurable at runtime.
-  // I almost wanted to file another bug to ask for experimenting in a more
-  // principled way to tune this value better, but this would have to be tuned
-  // on real end-to-end applications and we'd expect different applications to
-  // require different tunings. So the more important point is the need for
-  // this to be controllable by the application.
-  //
-  // That this value means that we may be sleeping substantially longer
-  // than a scheduler timeslice's duration is not necessarily surprising. The
-  // idea is to pick up quickly new work after having finished the previous
-  // workload. When it's new work within the same GEMM as the previous work, the
-  // time interval that we might be busy-waiting is very small, so for that
-  // purpose it would be more than enough to sleep for 1 ms.
-  // That is all what we would observe on a GEMM benchmark. However, in a real
-  // application, after having finished a GEMM, we might do unrelated work for
-  // a little while, then start on a new GEMM. In that case the wait interval
-  // may be a little longer. There may also not be another GEMM for a long time,
-  // in which case we'll end up passively waiting below.
-  const Duration spin_duration = DurationFromMilliseconds(2);
-  Wait(condition, spin_duration, condvar, mutex);
-}
-
-}  // namespace ruy
diff --git a/tensorflow/lite/experimental/ruy/ruy/wait.h b/tensorflow/lite/experimental/ruy/ruy/wait.h
deleted file mode 100644
index a3cd26282af..00000000000
--- a/tensorflow/lite/experimental/ruy/ruy/wait.h
+++ /dev/null
@@ -1,73 +0,0 @@
-/* Copyright 2019 Google LLC. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_WAIT_H_
-#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_WAIT_H_
-
-#include <condition_variable>  // NOLINT(build/c++11)
-#include <functional>
-#include <mutex>  // NOLINT(build/c++11)
-
-#include "tensorflow/lite/experimental/ruy/ruy/time.h"
-
-namespace ruy {
-
-// Waits until some evaluation of `condition` has returned true.
-//
-// There is no guarantee that calling `condition` again after this function
-// has returned would still return true. The only
-// contract is that at some point during the execution of that function,
-// `condition` has returned true.
-//
-// First does some spin-waiting for the specified `spin_duration`,
-// then falls back to passive waiting for the given condvar, guarded
-// by the given mutex. At this point it will acquire the mutex lock
-// around the waiting on the condition variable.
-// Therefore, this function expects that the calling thread hasn't already
-// locked the mutex before calling it.
-// This function will always release the mutex lock before returning.
-//
-// The idea of doing some initial spin-waiting is to help get
-// better and more consistent multithreading benefits for small GEMM sizes.
-// Spin-waiting helps ensure that if we need to wake up soon after having
-// started waiting, then we can wake up quickly (as opposed to, say,
-// having to wait to be scheduled again by the OS). On the other hand,
-// we must still eventually revert to passive waiting for longer waits
-// (e.g. worker threads having finished a GEMM and waiting until the next GEMM)
-// so as to avoid permanently spinning.
-//
-// In situations where other threads might have more useful things to do with
-// these CPU cores than our spin-waiting, it may be best to reduce the value
-// of `spin_duration`. Setting it to zero disables the spin-waiting entirely.
-//
-// There is a risk that the std::function used here might use a heap allocation
-// to store its context. The expected usage pattern is that these functions'
-// contexts will consist of a single pointer value (typically capturing only
-// [this]), and that in this case the std::function implementation will use
-// inline storage, avoiding a heap allocation. However, we can't effectively
-// guard that assumption, and that's not a big concern anyway because the
-// latency of a small heap allocation is probably low compared to the intrinsic
-// latency of what this Wait function does.
-void Wait(const std::function<bool()>& condition, const Duration& spin_duration,
-          std::condition_variable* condvar, std::mutex* mutex);
-
-// Convenience overload using a default `spin_duration`.
-// TODO(benoitjacob): let this be controlled from the ruy API.
-void Wait(const std::function<bool()>& condition,
-          std::condition_variable* condvar, std::mutex* mutex);
-
-}  // namespace ruy
-
-#endif  // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_WAIT_H_
diff --git a/tensorflow/lite/experimental/ruy/ruy/wait_test.cc b/tensorflow/lite/experimental/ruy/ruy/wait_test.cc
deleted file mode 100644
index b1b7558583d..00000000000
--- a/tensorflow/lite/experimental/ruy/ruy/wait_test.cc
+++ /dev/null
@@ -1,117 +0,0 @@
-/* Copyright 2019 Google LLC. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/lite/experimental/ruy/ruy/wait.h"
-
-#include <atomic>
-#include <condition_variable>  // NOLINT(build/c++11)
-#include <mutex>  // NOLINT(build/c++11)
-#include <thread>  // NOLINT(build/c++11)
-
-#include <gtest/gtest.h>
-#include "tensorflow/lite/experimental/ruy/ruy/platform.h"
-
-namespace ruy {
-namespace {
-
-// Thread taking a `value` atomic counter and incrementing it until it equals
-// `end_value`, then notifying the condition variable as long as
-// `value == end_value`. If `end_value` is increased, it will then resume
-// incrementing `value`, etc. Terminates if `end_value == -1`.
-class ThreadCountingUpToValue {
- public:
-  ThreadCountingUpToValue(const std::atomic<int>& end_value,
-                          std::atomic<int>* value,
-                          std::condition_variable* condvar, std::mutex* mutex)
-      : end_value_(end_value),
-        value_(value),
-        condvar_(condvar),
-        mutex_(mutex) {}
-  void operator()() {
-    // end_value_==-1 is how the master thread will tell us it's OK to terminate
-    while (end_value_.load() != -1) {
-      // wait until end_value is set to a higher value
-      while (value_->load() == end_value_.load()) {
-      }
-      // increment value as long as it's lower than end_value
-      while (value_->fetch_add(1) < end_value_.load() - 1) {
-      }
-      // when value has reached end_value, notify the master thread.
-      while (value_->load() == end_value_.load()) {
-        std::lock_guard<std::mutex> lock(*mutex_);
-        condvar_->notify_all();
-      }
-    }
-  }
-
- private:
-  const std::atomic<int>& end_value_;
-  std::atomic<int>* value_;
-  std::condition_variable* condvar_;
-  std::mutex* mutex_;
-};
-
-void WaitTest(const Duration& spin_duration, const Duration& delay) {
-#if RUY_PLATFORM(EMSCRIPTEN)
-  // b/139927184, std::thread constructor raises exception
-  return;
-#endif
-  std::condition_variable condvar;
-  std::mutex mutex;
-  std::atomic<int> value(0);
-  std::atomic<int> end_value(0);
-  ThreadCountingUpToValue thread_callable(end_value, &value, &condvar, &mutex);
-  std::thread thread(thread_callable);
-  std::this_thread::sleep_for(delay);
-  for (int i = 1; i < 10; i++) {
-    end_value.store(1000 * i);
-    const auto& condition = [&value, &end_value]() {
-      return value.load() == end_value.load();
-    };
-    ruy::Wait(condition, spin_duration, &condvar, &mutex);
-    EXPECT_EQ(value.load(), end_value.load());
-  }
-  end_value.store(-1);
-  thread.join();
-}
-
-TEST(WaitTest, WaitTestNoSpin) {
-  WaitTest(DurationFromSeconds(0), DurationFromSeconds(0));
-}
-
-TEST(WaitTest, WaitTestSpinOneMicrosecond) {
-  WaitTest(DurationFromSeconds(1e-6), DurationFromSeconds(0));
-}
-
-TEST(WaitTest, WaitTestSpinOneMillisecond) {
-  WaitTest(DurationFromSeconds(1e-3), DurationFromSeconds(0));
-}
-
-TEST(WaitTest, WaitTestSpinOneSecond) {
-  WaitTest(DurationFromSeconds(1), DurationFromSeconds(0));
-}
-
-// Testcase to consistently reproduce the hang in b/139062384.
-TEST(WaitTest, WaitTestNoSpinWithDelayBug139062384) { - WaitTest(DurationFromSeconds(0), DurationFromSeconds(1)); -} - -} // namespace -} // namespace ruy - -int main(int argc, char** argv) { - ::testing::InitGoogleTest(&argc, argv); - return RUN_ALL_TESTS(); -} diff --git a/tensorflow/lite/kernels/BUILD b/tensorflow/lite/kernels/BUILD index 28eefb2895f..a4d188f34da 100644 --- a/tensorflow/lite/kernels/BUILD +++ b/tensorflow/lite/kernels/BUILD @@ -274,7 +274,7 @@ cc_library( # For now this unconditionally depends on both ruy and gemmlowp. # See the comment inside class CpuBackendContext on the # gemmlowp_context_ and ruy_context_ members. - "//tensorflow/lite/experimental/ruy/ruy:context", + "@ruy//ruy:context", "@gemmlowp", "//tensorflow/lite:external_cpu_backend_context", ], @@ -295,8 +295,8 @@ cc_library( # We only need to depend on gemmlowp when tflite_with_ruy # is false, but putting these dependencies in a select() seems to # defeat copybara's rewriting rules. - "//tensorflow/lite/experimental/ruy/ruy:context", - "//tensorflow/lite/experimental/ruy/ruy:thread_pool", + "@ruy//ruy:context", + "@ruy//ruy:thread_pool", "@gemmlowp", ], ) @@ -334,9 +334,9 @@ cc_library( ":cpu_backend_threadpool", # Depend on ruy regardless of `tflite_with_ruy`. See the comment in # cpu_backend_gemm.h about why ruy is the generic path. - "//tensorflow/lite/experimental/ruy/ruy", - "//tensorflow/lite/experimental/ruy/ruy:path", - "//tensorflow/lite/experimental/ruy/ruy/profiler:instrumentation", + "@ruy//ruy", + "@ruy//ruy:path", + "@ruy//ruy/profiler:instrumentation", # We only need to depend on gemmlowp and Eigen when tflite_with_ruy # is false, but putting these dependencies in a select() seems to # defeat copybara's rewriting rules. @@ -355,7 +355,7 @@ cc_test( "@com_google_googletest//:gtest", # ruy's reference path provides the reference implementation # that this test compares against. - "//tensorflow/lite/experimental/ruy/ruy", + "@ruy//ruy", ], ) @@ -596,11 +596,11 @@ cc_library( "//tensorflow/lite:context", "//tensorflow/lite/c:common", "//tensorflow/lite/experimental/kernels:hashtable_op_kernels", - "//tensorflow/lite/experimental/ruy/ruy/profiler:instrumentation", "//tensorflow/lite/kernels/internal:kernel_utils", "//tensorflow/lite/kernels/internal:tensor", "//third_party/fft2d:fft2d_headers", "@fft2d", + "@ruy//ruy/profiler:instrumentation", ], ) @@ -613,13 +613,13 @@ cc_library( ":cpu_backend_context", ":op_macros", "//tensorflow/lite/c:common", - "//tensorflow/lite/experimental/ruy/ruy/profiler:instrumentation", "//tensorflow/lite/kernels/internal:common", "//tensorflow/lite/kernels/internal:compatibility", "//tensorflow/lite/kernels/internal:kernel_utils", "//tensorflow/lite/kernels/internal:quantization_util", "//tensorflow/lite/kernels/internal:tensor", "//tensorflow/lite/kernels/internal:tensor_utils", + "@ruy//ruy/profiler:instrumentation", ], ) diff --git a/tensorflow/lite/kernels/cpu_backend_context.cc b/tensorflow/lite/kernels/cpu_backend_context.cc index dfeea5d0a64..51284214ee4 100644 --- a/tensorflow/lite/kernels/cpu_backend_context.cc +++ b/tensorflow/lite/kernels/cpu_backend_context.cc @@ -16,7 +16,7 @@ limitations under the License. 
#include "tensorflow/lite/kernels/cpu_backend_context.h" #include "public/gemmlowp.h" -#include "tensorflow/lite/experimental/ruy/ruy/context.h" +#include "ruy/context.h" // from @ruy #include "tensorflow/lite/kernels/op_macros.h" namespace { diff --git a/tensorflow/lite/kernels/cpu_backend_context.h b/tensorflow/lite/kernels/cpu_backend_context.h index eafae75fc47..46abcd5e90f 100644 --- a/tensorflow/lite/kernels/cpu_backend_context.h +++ b/tensorflow/lite/kernels/cpu_backend_context.h @@ -19,7 +19,7 @@ limitations under the License. #include #include "public/gemmlowp.h" -#include "tensorflow/lite/experimental/ruy/ruy/context.h" +#include "ruy/context.h" // from @ruy #include "tensorflow/lite/external_cpu_backend_context.h" namespace tflite { diff --git a/tensorflow/lite/kernels/cpu_backend_gemm_custom_gemv.h b/tensorflow/lite/kernels/cpu_backend_gemm_custom_gemv.h index 6fde100a4bf..f85a1715af2 100644 --- a/tensorflow/lite/kernels/cpu_backend_gemm_custom_gemv.h +++ b/tensorflow/lite/kernels/cpu_backend_gemm_custom_gemv.h @@ -35,7 +35,7 @@ limitations under the License. #include #include -#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" +#include "ruy/profiler/instrumentation.h" // from @ruy #include "tensorflow/lite/kernels/cpu_backend_context.h" #include "tensorflow/lite/kernels/cpu_backend_gemm_params.h" #include "tensorflow/lite/kernels/cpu_backend_threadpool.h" diff --git a/tensorflow/lite/kernels/cpu_backend_gemm_gemmlowp.h b/tensorflow/lite/kernels/cpu_backend_gemm_gemmlowp.h index 253c035688f..ad9bbb75ae5 100644 --- a/tensorflow/lite/kernels/cpu_backend_gemm_gemmlowp.h +++ b/tensorflow/lite/kernels/cpu_backend_gemm_gemmlowp.h @@ -22,7 +22,7 @@ limitations under the License. #include #include "public/gemmlowp.h" -#include "tensorflow/lite/experimental/ruy/ruy/ruy.h" +#include "ruy/ruy.h" // from @ruy #include "tensorflow/lite/kernels/cpu_backend_context.h" #include "tensorflow/lite/kernels/cpu_backend_gemm_params.h" #include "tensorflow/lite/kernels/cpu_backend_gemm_ruy.h" diff --git a/tensorflow/lite/kernels/cpu_backend_gemm_ruy.h b/tensorflow/lite/kernels/cpu_backend_gemm_ruy.h index c02dce2b773..d038c03ac04 100644 --- a/tensorflow/lite/kernels/cpu_backend_gemm_ruy.h +++ b/tensorflow/lite/kernels/cpu_backend_gemm_ruy.h @@ -16,8 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_RUY_H_ #define TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_RUY_H_ -#include "tensorflow/lite/experimental/ruy/ruy/path.h" -#include "tensorflow/lite/experimental/ruy/ruy/ruy.h" +#include "ruy/path.h" // from @ruy +#include "ruy/ruy.h" // from @ruy #include "tensorflow/lite/kernels/cpu_backend_context.h" #include "tensorflow/lite/kernels/cpu_backend_gemm_params.h" diff --git a/tensorflow/lite/kernels/cpu_backend_gemm_test.cc b/tensorflow/lite/kernels/cpu_backend_gemm_test.cc index d26df809c97..75181a979eb 100644 --- a/tensorflow/lite/kernels/cpu_backend_gemm_test.cc +++ b/tensorflow/lite/kernels/cpu_backend_gemm_test.cc @@ -25,7 +25,7 @@ limitations under the License. 
#include #include -#include "tensorflow/lite/experimental/ruy/ruy/ruy.h" +#include "ruy/ruy.h" // from @ruy #include "tensorflow/lite/kernels/cpu_backend_context.h" #include "tensorflow/lite/kernels/cpu_backend_gemm_params.h" diff --git a/tensorflow/lite/kernels/cpu_backend_threadpool.h b/tensorflow/lite/kernels/cpu_backend_threadpool.h index b924826a07c..ff03d372d5e 100644 --- a/tensorflow/lite/kernels/cpu_backend_threadpool.h +++ b/tensorflow/lite/kernels/cpu_backend_threadpool.h @@ -20,8 +20,8 @@ limitations under the License. #include "tensorflow/lite/kernels/internal/compatibility.h" #ifdef TFLITE_WITH_RUY -#include "tensorflow/lite/experimental/ruy/ruy/context.h" -#include "tensorflow/lite/experimental/ruy/ruy/thread_pool.h" +#include "ruy/context.h" // from @ruy +#include "ruy/thread_pool.h" // from @ruy #else #include "public/gemmlowp.h" #endif diff --git a/tensorflow/lite/kernels/internal/BUILD b/tensorflow/lite/kernels/internal/BUILD index 952073ef02a..373fffd8c24 100644 --- a/tensorflow/lite/kernels/internal/BUILD +++ b/tensorflow/lite/kernels/internal/BUILD @@ -249,7 +249,7 @@ cc_library( ":transpose_utils", "//third_party/eigen3", "@gemmlowp//:fixedpoint", - "//tensorflow/lite/experimental/ruy/ruy/profiler:instrumentation", + "@ruy//ruy/profiler:instrumentation", "//tensorflow/lite/c:common", "//tensorflow/lite/kernels:cpu_backend_context", "//tensorflow/lite/kernels:cpu_backend_threadpool", @@ -301,7 +301,7 @@ cc_library( "//tensorflow/lite/kernels:cpu_backend_context", "//tensorflow/lite/kernels:cpu_backend_threadpool", "//tensorflow/lite/kernels:cpu_backend_gemm", - "//tensorflow/lite/experimental/ruy/ruy/profiler:instrumentation", + "@ruy//ruy/profiler:instrumentation", ] + select({ ":haswell": tflite_deps_intel, ":ios_x86_64": tflite_deps_intel, @@ -477,7 +477,7 @@ cc_library( "//third_party/eigen3", "//tensorflow/lite/c:common", "//tensorflow/lite/kernels:op_macros", - "//tensorflow/lite/experimental/ruy/ruy/profiler:instrumentation", + "@ruy//ruy/profiler:instrumentation", "//tensorflow/lite/tools/optimize/sparsity:format_converter", ] + select({ ":haswell": tflite_deps_intel, @@ -542,7 +542,7 @@ cc_library( "@gemmlowp", "//tensorflow/lite/c:common", "//tensorflow/lite/kernels:op_macros", - "//tensorflow/lite/experimental/ruy/ruy/profiler:instrumentation", + "@ruy//ruy/profiler:instrumentation", "//tensorflow/lite/tools/optimize/sparsity:format_converter", ] + select({ ":haswell": tflite_deps_intel, @@ -626,10 +626,10 @@ cc_library( ":cpu_check", ":portable_tensor_utils", "//tensorflow/lite/c:common", - "//tensorflow/lite/experimental/ruy/ruy", - "//tensorflow/lite/experimental/ruy/ruy:detect_arm", "//tensorflow/lite/kernels:cpu_backend_context", "//tensorflow/lite/kernels:cpu_backend_gemm", + "@ruy//ruy", + "@ruy//ruy:detect_arm", ], ) @@ -822,10 +822,10 @@ cc_test( ":reference_base", ":test_util", ":types", - "//tensorflow/lite/experimental/ruy/ruy:context", "//tensorflow/lite/kernels:cpu_backend_context", "@com_google_absl//absl/strings", "@com_google_googletest//:gtest_main", + "@ruy//ruy:context", ], ) diff --git a/tensorflow/lite/kernels/internal/depthwiseconv_quantized_test.cc b/tensorflow/lite/kernels/internal/depthwiseconv_quantized_test.cc index 4f8ceb33595..1f2f6d57b9a 100644 --- a/tensorflow/lite/kernels/internal/depthwiseconv_quantized_test.cc +++ b/tensorflow/lite/kernels/internal/depthwiseconv_quantized_test.cc @@ -24,7 +24,7 @@ limitations under the License. 
#include #include -#include "tensorflow/lite/experimental/ruy/ruy/context.h" +#include "ruy/context.h" // from @ruy #include "tensorflow/lite/kernels/cpu_backend_context.h" #include "tensorflow/lite/kernels/internal/optimized/cpu_check.h" #include "tensorflow/lite/kernels/internal/optimized/depthwiseconv_3x3_filter_common.h" diff --git a/tensorflow/lite/kernels/internal/optimized/depthwiseconv_3x3_filter_common.h b/tensorflow/lite/kernels/internal/optimized/depthwiseconv_3x3_filter_common.h index 2768344696d..916edd561ff 100644 --- a/tensorflow/lite/kernels/internal/optimized/depthwiseconv_3x3_filter_common.h +++ b/tensorflow/lite/kernels/internal/optimized/depthwiseconv_3x3_filter_common.h @@ -15,7 +15,7 @@ limitations under the License. #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_3X3_FILTER_COMMON_H_ #define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_3X3_FILTER_COMMON_H_ -#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" +#include "ruy/profiler/instrumentation.h" // from @ruy #include "tensorflow/lite/kernels/internal/optimized/cpu_check.h" #include "tensorflow/lite/kernels/internal/reference/depthwiseconv_uint8.h" #include "tensorflow/lite/kernels/internal/types.h" diff --git a/tensorflow/lite/kernels/internal/optimized/depthwiseconv_float.h b/tensorflow/lite/kernels/internal/optimized/depthwiseconv_float.h index af763377763..a8f41d5a108 100644 --- a/tensorflow/lite/kernels/internal/optimized/depthwiseconv_float.h +++ b/tensorflow/lite/kernels/internal/optimized/depthwiseconv_float.h @@ -15,7 +15,7 @@ limitations under the License. #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_FLOAT_H_ #define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_FLOAT_H_ -#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" +#include "ruy/profiler/instrumentation.h" // from @ruy #include "tensorflow/lite/kernels/internal/optimized/cpu_check.h" #include "tensorflow/lite/kernels/internal/types.h" diff --git a/tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8.h b/tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8.h index 1b86d91fb42..3f93a491862 100644 --- a/tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8.h +++ b/tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8.h @@ -17,7 +17,7 @@ limitations under the License. #include -#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" +#include "ruy/profiler/instrumentation.h" // from @ruy #include "tensorflow/lite/kernels/internal/optimized/cpu_check.h" #include "tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h" #include "tensorflow/lite/kernels/internal/reference/depthwiseconv_uint8.h" diff --git a/tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h b/tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h index 293fd4248f2..73acbcf707b 100644 --- a/tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h +++ b/tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h @@ -17,7 +17,7 @@ limitations under the License. 
#include -#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" +#include "ruy/profiler/instrumentation.h" // from @ruy #include "tensorflow/lite/kernels/internal/optimized/cpu_check.h" #include "tensorflow/lite/kernels/internal/optimized/depthwiseconv_3x3_filter_common.h" #include "tensorflow/lite/kernels/internal/reference/depthwiseconv_uint8.h" diff --git a/tensorflow/lite/kernels/internal/optimized/im2col_utils.h b/tensorflow/lite/kernels/internal/optimized/im2col_utils.h index e3a9b9acdc6..42aa4825771 100644 --- a/tensorflow/lite/kernels/internal/optimized/im2col_utils.h +++ b/tensorflow/lite/kernels/internal/optimized/im2col_utils.h @@ -15,7 +15,7 @@ limitations under the License. #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_IM2COL_UTILS_H_ #define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_IM2COL_UTILS_H_ -#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" +#include "ruy/profiler/instrumentation.h" // from @ruy #include "tensorflow/lite/kernels/internal/types.h" namespace tflite { diff --git a/tensorflow/lite/kernels/internal/optimized/integer_ops/add.h b/tensorflow/lite/kernels/internal/optimized/integer_ops/add.h index 8db98cf1bdc..a9dae4feac5 100644 --- a/tensorflow/lite/kernels/internal/optimized/integer_ops/add.h +++ b/tensorflow/lite/kernels/internal/optimized/integer_ops/add.h @@ -15,7 +15,7 @@ limitations under the License. #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_INTEGER_OPS_ADD_H_ #define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_INTEGER_OPS_ADD_H_ -#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" +#include "ruy/profiler/instrumentation.h" // from @ruy #include "tensorflow/lite/kernels/internal/common.h" #include "tensorflow/lite/kernels/internal/optimized/cpu_check.h" #include "tensorflow/lite/kernels/internal/reference/integer_ops/add.h" diff --git a/tensorflow/lite/kernels/internal/optimized/integer_ops/conv.h b/tensorflow/lite/kernels/internal/optimized/integer_ops/conv.h index 6c1abaeff82..61f848c888e 100644 --- a/tensorflow/lite/kernels/internal/optimized/integer_ops/conv.h +++ b/tensorflow/lite/kernels/internal/optimized/integer_ops/conv.h @@ -15,7 +15,7 @@ limitations under the License. #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_INTEGER_OPS_CONV_H_ #define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_INTEGER_OPS_CONV_H_ -#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" +#include "ruy/profiler/instrumentation.h" // from @ruy #include "tensorflow/lite/kernels/cpu_backend_context.h" #include "tensorflow/lite/kernels/cpu_backend_gemm.h" #include "tensorflow/lite/kernels/cpu_backend_gemm_params.h" diff --git a/tensorflow/lite/kernels/internal/optimized/integer_ops/depthwise_conv.h b/tensorflow/lite/kernels/internal/optimized/integer_ops/depthwise_conv.h index d44cfabe3c3..ffc7ea84340 100644 --- a/tensorflow/lite/kernels/internal/optimized/integer_ops/depthwise_conv.h +++ b/tensorflow/lite/kernels/internal/optimized/integer_ops/depthwise_conv.h @@ -15,7 +15,7 @@ limitations under the License. 
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_INTEGER_OPS_DEPTHWISE_CONV_H_ #define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_INTEGER_OPS_DEPTHWISE_CONV_H_ -#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" +#include "ruy/profiler/instrumentation.h" // from @ruy #include "tensorflow/lite/kernels/cpu_backend_context.h" #include "tensorflow/lite/kernels/cpu_backend_threadpool.h" #include "tensorflow/lite/kernels/internal/optimized/cpu_check.h" diff --git a/tensorflow/lite/kernels/internal/optimized/integer_ops/depthwise_conv_3x3_filter.h b/tensorflow/lite/kernels/internal/optimized/integer_ops/depthwise_conv_3x3_filter.h index 97039e2e462..0cb1a23e556 100644 --- a/tensorflow/lite/kernels/internal/optimized/integer_ops/depthwise_conv_3x3_filter.h +++ b/tensorflow/lite/kernels/internal/optimized/integer_ops/depthwise_conv_3x3_filter.h @@ -17,7 +17,7 @@ limitations under the License. #include -#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" +#include "ruy/profiler/instrumentation.h" // from @ruy #include "tensorflow/lite/kernels/internal/optimized/cpu_check.h" #include "tensorflow/lite/kernels/internal/optimized/depthwiseconv_3x3_filter_common.h" #include "tensorflow/lite/kernels/internal/types.h" diff --git a/tensorflow/lite/kernels/internal/optimized/integer_ops/depthwise_conv_hybrid.h b/tensorflow/lite/kernels/internal/optimized/integer_ops/depthwise_conv_hybrid.h index 153a2252f39..37e9261b04a 100644 --- a/tensorflow/lite/kernels/internal/optimized/integer_ops/depthwise_conv_hybrid.h +++ b/tensorflow/lite/kernels/internal/optimized/integer_ops/depthwise_conv_hybrid.h @@ -15,7 +15,7 @@ limitations under the License. #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_INTEGER_OPS_DEPTHWISE_CONV_HYBRID_H_ #define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_INTEGER_OPS_DEPTHWISE_CONV_HYBRID_H_ -#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" +#include "ruy/profiler/instrumentation.h" // from @ruy #include "tensorflow/lite/kernels/cpu_backend_context.h" #include "tensorflow/lite/kernels/cpu_backend_threadpool.h" #include "tensorflow/lite/kernels/internal/optimized/cpu_check.h" diff --git a/tensorflow/lite/kernels/internal/optimized/integer_ops/depthwise_conv_hybrid_3x3_filter.h b/tensorflow/lite/kernels/internal/optimized/integer_ops/depthwise_conv_hybrid_3x3_filter.h index fa96ce94a6e..51f3d2559db 100644 --- a/tensorflow/lite/kernels/internal/optimized/integer_ops/depthwise_conv_hybrid_3x3_filter.h +++ b/tensorflow/lite/kernels/internal/optimized/integer_ops/depthwise_conv_hybrid_3x3_filter.h @@ -17,7 +17,7 @@ limitations under the License. #include -#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" +#include "ruy/profiler/instrumentation.h" // from @ruy #include "tensorflow/lite/kernels/internal/optimized/cpu_check.h" #include "tensorflow/lite/kernels/internal/optimized/depthwiseconv_3x3_filter_common.h" #include "tensorflow/lite/kernels/internal/types.h" diff --git a/tensorflow/lite/kernels/internal/optimized/integer_ops/fully_connected.h b/tensorflow/lite/kernels/internal/optimized/integer_ops/fully_connected.h index fdd3135097b..8de99c1a564 100644 --- a/tensorflow/lite/kernels/internal/optimized/integer_ops/fully_connected.h +++ b/tensorflow/lite/kernels/internal/optimized/integer_ops/fully_connected.h @@ -15,7 +15,7 @@ limitations under the License. 
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_INTEGER_OPS_FULLY_CONNECTED_H_ #define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_INTEGER_OPS_FULLY_CONNECTED_H_ -#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" +#include "ruy/profiler/instrumentation.h" // from @ruy #include "tensorflow/lite/kernels/cpu_backend_context.h" #include "tensorflow/lite/kernels/cpu_backend_gemm.h" #include "tensorflow/lite/kernels/cpu_backend_threadpool.h" diff --git a/tensorflow/lite/kernels/internal/optimized/integer_ops/mul.h b/tensorflow/lite/kernels/internal/optimized/integer_ops/mul.h index 952415593a5..18aeef4c8b5 100644 --- a/tensorflow/lite/kernels/internal/optimized/integer_ops/mul.h +++ b/tensorflow/lite/kernels/internal/optimized/integer_ops/mul.h @@ -15,7 +15,7 @@ limitations under the License. #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_INTEGER_OPS_MUL_H_ #define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_INTEGER_OPS_MUL_H_ -#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" +#include "ruy/profiler/instrumentation.h" // from @ruy #include "tensorflow/lite/kernels/internal/common.h" #include "tensorflow/lite/kernels/internal/optimized/cpu_check.h" #include "tensorflow/lite/kernels/internal/reference/integer_ops/mul.h" diff --git a/tensorflow/lite/kernels/internal/optimized/integer_ops/pooling.h b/tensorflow/lite/kernels/internal/optimized/integer_ops/pooling.h index fb4642e7f0d..060845f4a10 100644 --- a/tensorflow/lite/kernels/internal/optimized/integer_ops/pooling.h +++ b/tensorflow/lite/kernels/internal/optimized/integer_ops/pooling.h @@ -28,7 +28,7 @@ limitations under the License. #include #include "fixedpoint/fixedpoint.h" -#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" +#include "ruy/profiler/instrumentation.h" // from @ruy #include "tensorflow/lite/kernels/internal/cppmath.h" #include "tensorflow/lite/kernels/internal/optimized/cpu_check.h" #include "tensorflow/lite/kernels/internal/optimized/im2col_utils.h" diff --git a/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.cc b/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.cc index af9ffba2c7c..dc2204e3a60 100644 --- a/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.cc +++ b/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.cc @@ -23,8 +23,8 @@ limitations under the License. #include #include -#include "tensorflow/lite/experimental/ruy/ruy/detect_arm.h" -#include "tensorflow/lite/experimental/ruy/ruy/ruy.h" +#include "ruy/detect_arm.h" // from @ruy +#include "ruy/ruy.h" // from @ruy #include "tensorflow/lite/kernels/cpu_backend_context.h" #include "tensorflow/lite/kernels/cpu_backend_gemm.h" #include "tensorflow/lite/kernels/cpu_backend_gemm_params.h" diff --git a/tensorflow/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/lite/kernels/internal/optimized/optimized_ops.h index d98c51d1a2f..6b50be31bfb 100644 --- a/tensorflow/lite/kernels/internal/optimized/optimized_ops.h +++ b/tensorflow/lite/kernels/internal/optimized/optimized_ops.h @@ -38,8 +38,8 @@ limitations under the License. 
#include "third_party/eigen3/Eigen/Core" #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "fixedpoint/fixedpoint.h" +#include "ruy/profiler/instrumentation.h" // from @ruy #include "tensorflow/lite/c/common.h" -#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" #include "tensorflow/lite/kernels/cpu_backend_context.h" #include "tensorflow/lite/kernels/cpu_backend_gemm.h" #include "tensorflow/lite/kernels/cpu_backend_gemm_params.h" diff --git a/tensorflow/lite/kernels/internal/reference/integer_ops/mul.h b/tensorflow/lite/kernels/internal/reference/integer_ops/mul.h index 20571110005..a815c3f5252 100644 --- a/tensorflow/lite/kernels/internal/reference/integer_ops/mul.h +++ b/tensorflow/lite/kernels/internal/reference/integer_ops/mul.h @@ -16,7 +16,7 @@ limitations under the License. #define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_MUL_H_ #include "fixedpoint/fixedpoint.h" -#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" +#include "ruy/profiler/instrumentation.h" // from @ruy #include "tensorflow/lite/kernels/internal/common.h" namespace tflite { diff --git a/tensorflow/lite/kernels/internal/reference/reduce.h b/tensorflow/lite/kernels/internal/reference/reduce.h index 46448b2a646..17dfd8557ae 100644 --- a/tensorflow/lite/kernels/internal/reference/reduce.h +++ b/tensorflow/lite/kernels/internal/reference/reduce.h @@ -15,7 +15,7 @@ limitations under the License. #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_REDUCE_H_ #define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_REDUCE_H_ -#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" +#include "ruy/profiler/instrumentation.h" // from @ruy #include "tensorflow/lite/kernels/internal/common.h" #include "tensorflow/lite/kernels/internal/cppmath.h" #include "tensorflow/lite/kernels/internal/quantization_util.h" diff --git a/tensorflow/lite/kernels/internal/reference/reference_ops.h b/tensorflow/lite/kernels/internal/reference/reference_ops.h index 56443bb2139..acf3c701d69 100644 --- a/tensorflow/lite/kernels/internal/reference/reference_ops.h +++ b/tensorflow/lite/kernels/internal/reference/reference_ops.h @@ -28,8 +28,8 @@ limitations under the License. #include "third_party/eigen3/Eigen/Core" #include "fixedpoint/fixedpoint.h" +#include "ruy/profiler/instrumentation.h" // from @ruy #include "tensorflow/lite/c/common.h" -#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" #include "tensorflow/lite/kernels/internal/common.h" #include "tensorflow/lite/kernels/internal/quantization_util.h" #include "tensorflow/lite/kernels/internal/reference/add.h" diff --git a/tensorflow/lite/kernels/internal/reference/requantize.h b/tensorflow/lite/kernels/internal/reference/requantize.h index 8233be9ebae..32e32ed0d5b 100644 --- a/tensorflow/lite/kernels/internal/reference/requantize.h +++ b/tensorflow/lite/kernels/internal/reference/requantize.h @@ -15,7 +15,7 @@ limitations under the License. 
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_REQUANTIZE_H_ #define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_REQUANTIZE_H_ -#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" +#include "ruy/profiler/instrumentation.h" // from @ruy #include "tensorflow/lite/kernels/internal/common.h" #include "tensorflow/lite/kernels/internal/types.h" diff --git a/tensorflow/lite/kernels/internal/reference/sub.h b/tensorflow/lite/kernels/internal/reference/sub.h index a9ed3a675fd..48d03de02ee 100644 --- a/tensorflow/lite/kernels/internal/reference/sub.h +++ b/tensorflow/lite/kernels/internal/reference/sub.h @@ -16,7 +16,7 @@ limitations under the License. #define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_SUB_H_ #include "fixedpoint/fixedpoint.h" -#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" +#include "ruy/profiler/instrumentation.h" // from @ruy #include "tensorflow/lite/kernels/internal/common.h" namespace tflite { diff --git a/tensorflow/lite/kernels/lstm_eval.cc b/tensorflow/lite/kernels/lstm_eval.cc index 1db812b251f..9895c9183ec 100644 --- a/tensorflow/lite/kernels/lstm_eval.cc +++ b/tensorflow/lite/kernels/lstm_eval.cc @@ -17,9 +17,9 @@ limitations under the License. #include #include +#include "ruy/profiler/instrumentation.h" // from @ruy #include "tensorflow/lite/c/builtin_op_data.h" #include "tensorflow/lite/c/common.h" -#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" #include "tensorflow/lite/kernels/cpu_backend_context.h" #include "tensorflow/lite/kernels/internal/compatibility.h" #include "tensorflow/lite/kernels/internal/kernel_utils.h" diff --git a/tensorflow/lite/kernels/rfft2d.cc b/tensorflow/lite/kernels/rfft2d.cc index c0554c5e39b..fa201153daf 100644 --- a/tensorflow/lite/kernels/rfft2d.cc +++ b/tensorflow/lite/kernels/rfft2d.cc @@ -14,9 +14,9 @@ limitations under the License. ==============================================================================*/ #include "third_party/fft2d/fft2d.h" +#include "ruy/profiler/instrumentation.h" // from @ruy #include "tensorflow/lite/c/builtin_op_data.h" #include "tensorflow/lite/c/common.h" -#include "tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h" #include "tensorflow/lite/kernels/internal/tensor.h" #include "tensorflow/lite/kernels/kernel_util.h" #include "tensorflow/lite/kernels/op_macros.h" diff --git a/tensorflow/lite/micro/examples/person_detection/Makefile.inc b/tensorflow/lite/micro/examples/person_detection/Makefile.inc index ca95f736cd4..a295bb83f71 100644 --- a/tensorflow/lite/micro/examples/person_detection/Makefile.inc +++ b/tensorflow/lite/micro/examples/person_detection/Makefile.inc @@ -1,4 +1,5 @@ $(eval $(call add_third_party_download,$(PERSON_MODEL_URL),$(PERSON_MODEL_MD5),person_model_grayscale,)) +$(eval $(call add_third_party_download,$(RUY_URL),$(RUY_MD5),ruy,)) person_detection_MODEL_SRCS := \ tensorflow/lite/micro/examples/person_detection/model_settings.cc \ diff --git a/tensorflow/lite/micro/tools/make/Makefile b/tensorflow/lite/micro/tools/make/Makefile index bb5f3b74eef..e8b44bcbea6 100644 --- a/tensorflow/lite/micro/tools/make/Makefile +++ b/tensorflow/lite/micro/tools/make/Makefile @@ -48,20 +48,23 @@ INCLUDES := \ -I. \ -I$(MAKEFILE_DIR)/downloads/ \ -I$(MAKEFILE_DIR)/downloads/gemmlowp \ --I$(MAKEFILE_DIR)/downloads/flatbuffers/include +-I$(MAKEFILE_DIR)/downloads/flatbuffers/include \ +-I$(MAKEFILE_DIR)/downloads/ruy # Same list of paths, but now relative to the generated project files. 
GENERATED_PROJECT_INCLUDES := \ -I. \ -I./third_party/gemmlowp \ --I./third_party/flatbuffers/include +-I./third_party/flatbuffers/include \ +-I./third_party/ruy # Same list of paths, but now in the format the generate_keil_project.py # script expects them. PROJECT_INCLUDES := \ . \ third_party/gemmlowp \ -third_party/flatbuffers/include +third_party/flatbuffers/include \ +third_party/ruy TEST_SCRIPT := tensorflow/lite/micro/testing/test_linux_binary.sh @@ -132,7 +135,6 @@ tensorflow/lite/core/api/error_reporter.h \ tensorflow/lite/core/api/flatbuffer_conversions.h \ tensorflow/lite/core/api/op_resolver.h \ tensorflow/lite/core/api/tensor_utils.h \ -tensorflow/lite/experimental/ruy/ruy/profiler/instrumentation.h \ tensorflow/lite/kernels/internal/common.h \ tensorflow/lite/kernels/internal/compatibility.h \ tensorflow/lite/kernels/internal/optimized/neon_check.h \ @@ -192,7 +194,9 @@ third_party/gemmlowp/LICENSE \ third_party/flatbuffers/include/flatbuffers/base.h \ third_party/flatbuffers/include/flatbuffers/stl_emulation.h \ third_party/flatbuffers/include/flatbuffers/flatbuffers.h \ -third_party/flatbuffers/LICENSE.txt +third_party/flatbuffers/LICENSE.txt \ +third_party/ruy/ruy/profiler/instrumentation.h + MAKE_PROJECT_FILES := \ Makefile \ diff --git a/tensorflow/lite/micro/tools/make/third_party_downloads.inc b/tensorflow/lite/micro/tools/make/third_party_downloads.inc index c4ff652a0ff..189d758eb96 100644 --- a/tensorflow/lite/micro/tools/make/third_party_downloads.inc +++ b/tensorflow/lite/micro/tools/make/third_party_downloads.inc @@ -48,6 +48,9 @@ SIFIVE_FE310_LIB_MD5 := "06ee24c4956f8e21670ab3395861fe64" KISSFFT_URL="https://github.com/mborgerding/kissfft/archive/v130.zip" KISSFFT_MD5="438ba1fef5783cc5f5f201395cc477ca" +RUY_URL="https://github.com/google/ruy/archive/91d62808498cea7ccb48aa59181e218b4ad05701.zip" +RUY_MD5="5e653ae8863408ede2a0ca104fea5b1e" + PERSON_MODEL_URL := "https://storage.googleapis.com/download.tensorflow.org/data/tf_lite_micro_person_data_grayscale_2019_11_21.zip" PERSON_MODEL_MD5 := "fe2934bd0788f1dcc7af3f0a954542ab" diff --git a/tensorflow/lite/tools/benchmark/BUILD b/tensorflow/lite/tools/benchmark/BUILD index 1dd7e928c20..d10c1acb95d 100644 --- a/tensorflow/lite/tools/benchmark/BUILD +++ b/tensorflow/lite/tools/benchmark/BUILD @@ -154,7 +154,7 @@ cc_library( "@com_google_absl//absl/strings", "//tensorflow/lite:framework", "//tensorflow/lite:string_util", - "//tensorflow/lite/experimental/ruy/ruy/profiler", + "@ruy//ruy/profiler", "//tensorflow/lite/kernels:builtin_ops", "//tensorflow/lite/profiling:platform_profiler", "//tensorflow/lite/profiling:profiler", diff --git a/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc b/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc index 617976991e1..7759aa7f53b 100644 --- a/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc +++ b/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc @@ -28,7 +28,7 @@ limitations under the License. 
#include "absl/base/attributes.h" #include "absl/strings/numbers.h" -#include "tensorflow/lite/experimental/ruy/ruy/profiler/profiler.h" +#include "ruy/profiler/profiler.h" // from @ruy #include "tensorflow/lite/kernels/register.h" #include "tensorflow/lite/model.h" #include "tensorflow/lite/op_resolver.h" diff --git a/tensorflow/lite/tools/make/Makefile b/tensorflow/lite/tools/make/Makefile index 9043d494235..f5b86d0ff25 100644 --- a/tensorflow/lite/tools/make/Makefile +++ b/tensorflow/lite/tools/make/Makefile @@ -36,6 +36,7 @@ INCLUDES := \ -I$(MAKEFILE_DIR)/downloads/eigen \ -I$(MAKEFILE_DIR)/downloads/absl \ -I$(MAKEFILE_DIR)/downloads/gemmlowp \ +-I$(MAKEFILE_DIR)/downloads/ruy \ -I$(MAKEFILE_DIR)/downloads/neon_2_sse \ -I$(MAKEFILE_DIR)/downloads/farmhash/src \ -I$(MAKEFILE_DIR)/downloads/flatbuffers/include \ @@ -118,8 +119,7 @@ $(wildcard tensorflow/lite/*.c) \ $(wildcard tensorflow/lite/c/*.c) \ $(wildcard tensorflow/lite/core/*.cc) \ $(wildcard tensorflow/lite/core/api/*.cc) \ -$(wildcard tensorflow/lite/experimental/resource/*.cc) \ -$(wildcard tensorflow/lite/experimental/ruy/ruy/*.cc) +$(wildcard tensorflow/lite/experimental/resource/*.cc) ifneq ($(BUILD_TYPE),micro) CORE_CC_ALL_SRCS += \ $(wildcard tensorflow/lite/kernels/*.cc) \ diff --git a/tensorflow/lite/tools/make/download_dependencies.sh b/tensorflow/lite/tools/make/download_dependencies.sh index 2156feafef0..ea8e41b37d6 100755 --- a/tensorflow/lite/tools/make/download_dependencies.sh +++ b/tensorflow/lite/tools/make/download_dependencies.sh @@ -37,6 +37,8 @@ EIGEN_URL="$(grep -o 'https.*gitlab.com/libeigen/eigen/-/archive/.*tar\.gz' "${B EIGEN_SHA="$(eval echo $(grep '# SHARED_EIGEN_SHA' "${BZL_FILE_PATH}" | grep -o '\".*\"'))" GEMMLOWP_URL="$(grep -o 'https://storage.googleapis.com/mirror.tensorflow.org/github.com/google/gemmlowp/.*zip' "${BZL_FILE_PATH}" | head -n1)" GEMMLOWP_SHA="$(eval echo $(grep '# SHARED_GEMMLOWP_SHA' "${BZL_FILE_PATH}" | grep -o '\".*\"'))" +RUY_URL="https://github.com/google/ruy/archive/91d62808498cea7ccb48aa59181e218b4ad05701.zip" +RUY_SHA="ac6d71df496a20043252f451d82a01636bb8bba9c3d6b5dc9fadadaffa392751" GOOGLETEST_URL="https://github.com/google/googletest/archive/release-1.8.0.tar.gz" GOOGLETEST_SHA="58a6f4277ca2bc8565222b3bbd58a177609e9c488e8a72649359ba51450db7d8" ABSL_URL="$(grep -o 'https://github.com/abseil/abseil-cpp/.*tar.gz' "${BZL_FILE_PATH}" | head -n1)" diff --git a/tensorflow/tools/ci_build/ci_sanity.sh b/tensorflow/tools/ci_build/ci_sanity.sh index 4f7ce00eb53..c49ba608fc0 100755 --- a/tensorflow/tools/ci_build/ci_sanity.sh +++ b/tensorflow/tools/ci_build/ci_sanity.sh @@ -366,6 +366,7 @@ do_external_licenses_check(){ -e "@com_github_googlecloudplatform_google_cloud_cpp//google" \ -e "@com_github_grpc_grpc//src/compiler" \ -e "@platforms//os" \ + -e "@ruy//" \ -v ${MISSING_LICENSES_FILE} > temp.txt mv temp.txt ${MISSING_LICENSES_FILE} @@ -383,6 +384,7 @@ do_external_licenses_check(){ -e "@com_github_googlecloudplatform_google_cloud_cpp//" \ -e "@embedded_jdk//" \ -e "^//$" \ + -e "@ruy//" \ -v ${EXTRA_LICENSES_FILE} > temp.txt mv temp.txt ${EXTRA_LICENSES_FILE} diff --git a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD index e5186bc1f13..e152f9b6a22 100644 --- a/tensorflow/tools/pip_package/BUILD +++ b/tensorflow/tools/pip_package/BUILD @@ -171,6 +171,7 @@ filegroup( "//third_party/fft2d:LICENSE", "//third_party/hadoop:LICENSE.txt", "//third_party/icu/data:LICENSE", + "@ruy//:LICENSE", "@arm_neon_2_x86_sse//:LICENSE", "@astunparse_archive//:LICENSE", 
"@astor_archive//:LICENSE", diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl index bdfc46411c8..9c0b8e920f8 100755 --- a/tensorflow/workspace.bzl +++ b/tensorflow/workspace.bzl @@ -39,6 +39,7 @@ load("//third_party/kissfft:workspace.bzl", kissfft = "repo") load("//third_party/pasta:workspace.bzl", pasta = "repo") load("//third_party/psimd:workspace.bzl", psimd = "repo") load("//third_party/pthreadpool:workspace.bzl", pthreadpool = "repo") +load("//third_party/ruy:workspace.bzl", ruy = "repo") load("//third_party/sobol_data:workspace.bzl", sobol_data = "repo") load("//third_party/vulkan_headers:workspace.bzl", vulkan_headers = "repo") load("//third_party/toolchains/remote_config:configs.bzl", "initialize_rbe_configs") @@ -65,6 +66,7 @@ def initialize_third_party(): pthreadpool() sobol_data() vulkan_headers() + ruy() # Sanitize a dependency so that it works correctly from code that includes # TensorFlow as a submodule. diff --git a/third_party/ruy/BUILD b/third_party/ruy/BUILD new file mode 100644 index 00000000000..3ded6314938 --- /dev/null +++ b/third_party/ruy/BUILD @@ -0,0 +1,8 @@ +# Ruy is not BLAS + +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], # Apache 2.0 +) + +exports_files(["LICENSE"]) diff --git a/third_party/ruy/workspace.bzl b/third_party/ruy/workspace.bzl new file mode 100644 index 00000000000..203b89aa7e9 --- /dev/null +++ b/third_party/ruy/workspace.bzl @@ -0,0 +1,15 @@ +"""Loads the ruy library, used by TensorFlow Lite.""" + +load("//third_party:repo.bzl", "third_party_http_archive") + +def repo(): + third_party_http_archive( + name = "ruy", + sha256 = "ac6d71df496a20043252f451d82a01636bb8bba9c3d6b5dc9fadadaffa392751", + strip_prefix = "ruy-91d62808498cea7ccb48aa59181e218b4ad05701", + urls = [ + "https://storage.googleapis.com/mirror.tensorflow.org/github.com/google/ruy/archive/91d62808498cea7ccb48aa59181e218b4ad05701.zip", + "https://github.com/google/ruy/archive/91d62808498cea7ccb48aa59181e218b4ad05701.zip", + ], + build_file = "//third_party/ruy:BUILD", + )